Shortcuts

Source code for connectomics.model.arch.unet

from __future__ import print_function, division
from typing import Optional, List
from collections import OrderedDict

import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from ..block import *
from ..utils import model_init


class UNet3D(nn.Module):
    """3D residual U-Net architecture.

    This design is flexible in handling both isotropic data and anisotropic data.

    Args:
        block_type (str): the block type at each U-Net stage. Default: ``'residual'``
        in_channel (int): number of input channels. Default: 1
        out_channel (int): number of output channels. Default: 3
        filters (List[int]): number of filters at each U-Net stage.
            Default: [28, 36, 48, 64, 80]
        is_isotropic (bool): whether the whole model is isotropic. Default: False
        isotropy (List[bool]): specify each U-Net stage is isotropic or anisotropic.
            All elements will be `True` if :attr:`is_isotropic` is `True`.
            Default: [False, False, False, True, True]
        pad_mode (str): one of ``'zeros'``, ``'reflect'``, ``'replicate'`` or
            ``'circular'``. Default: ``'replicate'``
        act_mode (str): one of ``'relu'``, ``'leaky_relu'``, ``'elu'``, ``'gelu'``,
            ``'swish'``, ``'efficient_swish'`` or ``'none'``. Default: ``'elu'``
        norm_mode (str): one of ``'bn'``, ``'sync_bn'`` ``'in'`` or ``'gn'``.
            Default: ``'bn'``
        init_mode (str): one of ``'xavier'``, ``'kaiming'``, ``'selu'`` or
            ``'orthogonal'``. Default: ``'orthogonal'``
        pooling (bool): downsample by max-pooling if `True` else using stride.
            Default: `False`
        blurpool (bool): apply blurpool as in Zhang 2019
            (https://arxiv.org/abs/1904.11486). Default: `False`
        return_feats (Optional[list]): indices of the decoder pyramid levels whose
            feature maps are collected and returned alongside the output.
            Default: `None`
    """
    # Mapping from the ``block_type`` string to the residual block class.
    block_dict = {
        'residual': BasicBlock3d,
        'residual_pa': BasicBlock3dPA,
        'residual_se': BasicBlock3dSE,
        'residual_se_pa': BasicBlock3dPASE,
    }

    def __init__(self,
                 block_type='residual',
                 in_channel: int = 1,
                 out_channel: int = 3,
                 filters: Optional[List[int]] = None,
                 is_isotropic: bool = False,
                 isotropy: Optional[List[bool]] = None,
                 pad_mode: str = 'replicate',
                 act_mode: str = 'elu',
                 norm_mode: str = 'bn',
                 init_mode: str = 'orthogonal',
                 pooling: bool = False,
                 blurpool: bool = False,
                 return_feats: Optional[list] = None,
                 **kwargs):
        super().__init__()
        # Avoid mutable default arguments; fall back to the documented defaults.
        if filters is None:
            filters = [28, 36, 48, 64, 80]
        if isotropy is None:
            isotropy = [False, False, False, True, True]

        self.depth = len(filters)
        self.do_return_feats = (return_feats is not None)
        self.return_feats = return_feats
        print(f"Return feature maps from 3D U-Net? {self.do_return_feats}")

        if is_isotropic:
            isotropy = [True] * self.depth
        assert len(filters) == len(isotropy)

        block = self.block_dict[block_type]
        # NOTE(review): ``blurpool`` is stored but no blurpool layer is created
        # anywhere in this class -- confirm whether downstream code uses it.
        self.pooling, self.blurpool = pooling, blurpool
        self.shared_kwargs = {
            'pad_mode': pad_mode,
            'act_mode': act_mode,
            'norm_mode': norm_mode}

        # input and output layers
        kernel_size_io, padding_io = self._get_kernal_size(
            is_isotropic, io_layer=True)
        self.conv_in = conv3d_norm_act(in_channel, filters[0], kernel_size_io,
                                       padding=padding_io, **self.shared_kwargs)
        # Output head: no normalization/activation so raw logits are emitted.
        self.conv_out = conv3d_norm_act(filters[0], out_channel, kernel_size_io,
                                        bias=True, padding=padding_io,
                                        pad_mode=pad_mode, act_mode='none',
                                        norm_mode='none')

        # encoding path
        self.down_layers = nn.ModuleList()
        for i in range(self.depth):
            kernel_size, padding = self._get_kernal_size(isotropy[i])
            previous = max(0, i - 1)
            stride = self._get_stride(isotropy[i], previous, i)
            layer = nn.Sequential(
                self._make_pooling_layer(isotropy[i], previous, i),
                conv3d_norm_act(filters[previous], filters[i], kernel_size,
                                stride=stride, padding=padding,
                                **self.shared_kwargs),
                block(filters[i], filters[i], **self.shared_kwargs))
            self.down_layers.append(layer)

        # decoding path
        self.up_layers = nn.ModuleList()
        for j in range(1, self.depth):
            kernel_size, padding = self._get_kernal_size(isotropy[j])
            layer = nn.ModuleList([
                conv3d_norm_act(filters[j], filters[j - 1], kernel_size,
                                padding=padding, **self.shared_kwargs),
                block(filters[j - 1], filters[j - 1], **self.shared_kwargs)])
            self.up_layers.append(layer)

        # initialization
        model_init(self, mode=init_mode)

    def forward(self, x):
        """Run the encoder-decoder; returns the output tensor, plus the
        collected pyramid features when :attr:`return_feats` was given."""
        x = self.conv_in(x)

        down_x = [None] * (self.depth - 1)
        for i in range(self.depth - 1):
            x = self.down_layers[i](x)
            down_x[i] = x

        x = self.down_layers[-1](x)
        self._maybe_collect_feat(x, restart=True)  # the first one to collect

        for j in range(self.depth - 1):
            i = self.depth - 2 - j
            x = self.up_layers[i][0](x)
            x = self._upsample_add(x, down_x[i])
            x = self.up_layers[i][1](x)
            self._maybe_collect_feat(x)

        x = self.conv_out(x)
        if self.do_return_feats:
            return x, self.feats
        return x

    def _maybe_collect_feat(self, x, restart: bool = False):
        """Collect U-Net features at different pyramid levels.

        Args:
            x: feature map of the current pyramid level.
            restart: reset the feature dict/counter; must be `True` on the
                first call of each forward pass.
        """
        if not self.do_return_feats:
            return
        if restart:
            self.feats = OrderedDict()
            self.feat_index = -1
        self.feat_index += 1
        if self.feat_index in self.return_feats:
            self.feats[self.feat_index] = x

    def _upsample_add(self, x, y):
        """Upsample and add two feature maps.

        When pooling layer is used, the input size is assumed to be even,
        therefore :attr:`align_corners` is set to `False` to avoid feature
        mis-match. When downsampling by stride, the input size is assumed
        to be 2n+1, and :attr:`align_corners` is set to `True`.
        """
        align_corners = False if self.pooling else True
        x = F.interpolate(x, size=y.shape[2:], mode='trilinear',
                          align_corners=align_corners)
        return x + y

    def _get_kernal_size(self, is_isotropic, io_layer=False):
        # Returns (kernel_size, padding). Anisotropic stages keep the depth
        # axis untouched (kernel 1, padding 0 along z).
        if io_layer:  # kernel and padding size of I/O layers
            if is_isotropic:
                return (5, 5, 5), (2, 2, 2)
            return (1, 5, 5), (0, 2, 2)

        if is_isotropic:
            return (3, 3, 3), (1, 1, 1)
        return (1, 3, 3), (0, 1, 1)

    def _get_stride(self, is_isotropic, previous, i):
        # Stride 1 when max-pooling downsamples instead, or at the first
        # stage (previous == i only when i == 0).
        if self.pooling or previous == i:
            return 1
        return self._get_downsample(is_isotropic)

    def _get_downsample(self, is_isotropic):
        # Anisotropic stages only downsample the two lateral dimensions.
        if not is_isotropic:
            return (1, 2, 2)
        return 2

    def _make_pooling_layer(self, is_isotropic, previous, i):
        # Identity when downsampling is done by stride or at the first stage.
        if self.pooling and previous != i:
            kernel_size = stride = self._get_downsample(is_isotropic)
            return nn.MaxPool3d(kernel_size, stride)
        return nn.Identity()
class UNetPlus3D(UNet3D):
    """3D U-Net with a non-local bottleneck and pyramid feature aggregation.

    Extends :class:`UNet3D` with a :class:`NonLocalBlock3D` applied at the
    lowest resolution and 1x1x1 lateral convolutions that project the
    bottleneck feature map onto every decoder stage.

    Args:
        filters (List[int]): number of filters at each U-Net stage.
            Default: [28, 36, 48, 64, 80]
        norm_mode (str): normalization mode, forwarded to the base class and
            the non-local block. Default: ``'bn'``
    """

    def __init__(self,
                 filters: Optional[List[int]] = None,
                 norm_mode: str = 'bn',
                 **kwargs):
        # Avoid a mutable default argument for the filter list.
        if filters is None:
            filters = [28, 36, 48, 64, 80]
        super().__init__(filters=filters, norm_mode=norm_mode, **kwargs)

        # One 1x1x1 projection per decoder stage, mapping the bottleneck
        # channel count to that stage's channel count.
        self.feat_layers = nn.ModuleList([
            conv3d_norm_act(filters[-1], filters[k - 1], 1,
                            **self.shared_kwargs)
            for k in range(1, self.depth)])
        self.non_local = NonLocalBlock3D(
            filters[-1], sub_sample=False, norm_mode=norm_mode)

    def forward(self, x):
        """Run the encoder, non-local block and feature-fused decoder."""
        x = self.conv_in(x)

        down_x = [None] * (self.depth - 1)
        for i in range(self.depth - 1):
            x = self.down_layers[i](x)
            down_x[i] = x

        x = self.down_layers[-1](x)
        x = self.non_local(x)
        feat = x  # lowest-res feature map
        self._maybe_collect_feat(x, restart=True)  # the first one to collect

        for j in range(self.depth - 1):
            i = self.depth - 2 - j
            x = self.up_layers[i][0](x)
            x = self._upsample_add(x, down_x[i])
            # Fuse the projected bottleneck features into this decoder stage.
            x = self._upsample_add(self.feat_layers[i](feat), x)
            x = self.up_layers[i][1](x)
            self._maybe_collect_feat(x)

        x = self.conv_out(x)
        if self.do_return_feats:
            return x, self.feats
        return x
class UNet2D(nn.Module):
    """2D residual U-Net architecture.

    Args:
        block_type (str): the block type at each U-Net stage. Default: ``'residual'``
        in_channel (int): number of input channels. Default: 1
        out_channel (int): number of output channels. Default: 3
        filters (List[int]): number of filters at each U-Net stage.
            Default: [32, 64, 128, 256, 512]
        pad_mode (str): one of ``'zeros'``, ``'reflect'``, ``'replicate'`` or
            ``'circular'``. Default: ``'replicate'``
        act_mode (str): one of ``'relu'``, ``'leaky_relu'``, ``'elu'``, ``'gelu'``,
            ``'swish'``, ``'efficient_swish'`` or ``'none'``. Default: ``'leaky_relu'``
        norm_mode (str): one of ``'bn'``, ``'sync_bn'`` ``'in'`` or ``'gn'``.
            Default: ``'gn'``
        init_mode (str): one of ``'xavier'``, ``'kaiming'``, ``'selu'`` or
            ``'orthogonal'``. Default: ``'orthogonal'``
        pooling (bool): downsample by max-pooling if `True` else using stride.
            Default: `False`
    """
    # Mapping from the ``block_type`` string to the residual block class.
    block_dict = {
        'residual': BasicBlock2d,
        'residual_se': BasicBlock2dSE,
    }

    def __init__(self,
                 block_type='residual',
                 in_channel: int = 1,
                 out_channel: int = 3,
                 filters: Optional[List[int]] = None,
                 pad_mode: str = 'replicate',
                 act_mode: str = 'leaky_relu',
                 norm_mode: str = 'gn',
                 init_mode: str = 'orthogonal',
                 pooling: bool = False,
                 **kwargs):
        super().__init__()
        # Avoid a mutable default argument for the filter list.
        if filters is None:
            filters = [32, 64, 128, 256, 512]

        self.depth = len(filters)
        self.pooling = pooling
        block = self.block_dict[block_type]
        self.shared_kwargs = {
            'pad_mode': pad_mode,
            'act_mode': act_mode,
            'norm_mode': norm_mode}

        # input and output layers
        self.conv_in = conv2d_norm_act(
            in_channel, filters[0], 5, padding=2, **self.shared_kwargs)
        # Output head: no normalization/activation so raw logits are emitted.
        self.conv_out = conv2d_norm_act(filters[0], out_channel, 5, padding=2,
                                        bias=True, pad_mode=pad_mode,
                                        act_mode='none', norm_mode='none')

        # encoding path
        self.down_layers = nn.ModuleList()
        for i in range(self.depth):
            kernel_size, padding = 3, 1
            previous = max(0, i - 1)
            stride = self._get_stride(previous, i)
            layer = nn.Sequential(
                self._make_pooling_layer(previous, i),
                conv2d_norm_act(filters[previous], filters[i], kernel_size,
                                stride=stride, padding=padding,
                                **self.shared_kwargs),
                block(filters[i], filters[i], **self.shared_kwargs))
            self.down_layers.append(layer)

        # decoding path
        self.up_layers = nn.ModuleList()
        for j in range(1, self.depth):
            kernel_size, padding = 3, 1
            layer = nn.ModuleList([
                conv2d_norm_act(filters[j], filters[j - 1], kernel_size,
                                padding=padding, **self.shared_kwargs),
                block(filters[j - 1], filters[j - 1], **self.shared_kwargs)])
            self.up_layers.append(layer)

        # initialization
        model_init(self, mode=init_mode)

    def forward(self, x):
        """Run the 2D encoder-decoder and return the output tensor."""
        x = self.conv_in(x)

        down_x = [None] * (self.depth - 1)
        for i in range(self.depth - 1):
            x = self.down_layers[i](x)
            down_x[i] = x

        x = self.down_layers[-1](x)

        for j in range(self.depth - 1):
            i = self.depth - 2 - j
            x = self.up_layers[i][0](x)
            x = self._upsample_add(x, down_x[i])
            x = self.up_layers[i][1](x)

        x = self.conv_out(x)
        return x

    def _upsample_add(self, x, y):
        """Upsample and add two feature maps.

        When pooling layer is used, the input size is assumed to be even,
        therefore :attr:`align_corners` is set to `False` to avoid feature
        mis-match. When downsampling by stride, the input size is assumed
        to be 2n+1, and :attr:`align_corners` is set to `True`.
        """
        align_corners = False if self.pooling else True
        x = F.interpolate(x, size=y.shape[2:], mode='bilinear',
                          align_corners=align_corners)
        return x + y

    def _get_stride(self, previous, i):
        # Stride 1 when max-pooling downsamples instead, or at the first
        # stage (previous == i only when i == 0).
        if self.pooling or previous == i:
            return 1
        return 2

    def _make_pooling_layer(self, previous, i):
        # Identity when downsampling is done by stride or at the first stage.
        if self.pooling and previous != i:
            kernel_size = stride = 2
            return nn.MaxPool2d(kernel_size, stride)
        return nn.Identity()
class UNetPlus2D(UNet2D):
    """2D U-Net with a non-local bottleneck and pyramid feature aggregation.

    Extends :class:`UNet2D` with a :class:`NonLocalBlock2D` applied at the
    lowest resolution and 1x1 lateral convolutions that project the
    bottleneck feature map onto every decoder stage.

    Args:
        filters (List[int]): number of filters at each U-Net stage.
            Default: [32, 64, 128, 256, 512]
        norm_mode (str): normalization mode, forwarded to the base class and
            the non-local block. Default: ``'gn'``
    """

    def __init__(self,
                 filters: Optional[List[int]] = None,
                 norm_mode: str = 'gn',
                 **kwargs):
        # Avoid a mutable default argument for the filter list.
        if filters is None:
            filters = [32, 64, 128, 256, 512]
        super().__init__(filters=filters, norm_mode=norm_mode, **kwargs)

        # One 1x1 projection per decoder stage, mapping the bottleneck
        # channel count to that stage's channel count.
        self.feat_layers = nn.ModuleList([
            conv2d_norm_act(filters[-1], filters[k - 1], 1,
                            **self.shared_kwargs)
            for k in range(1, self.depth)])
        self.non_local = NonLocalBlock2D(
            filters[-1], sub_sample=False, norm_mode=norm_mode)

    def forward(self, x):
        """Run the encoder, non-local block and feature-fused decoder."""
        x = self.conv_in(x)

        down_x = [None] * (self.depth - 1)
        for i in range(self.depth - 1):
            x = self.down_layers[i](x)
            down_x[i] = x

        x = self.down_layers[-1](x)
        x = self.non_local(x)
        feat = x  # lowest-res feature map

        for j in range(self.depth - 1):
            i = self.depth - 2 - j
            x = self.up_layers[i][0](x)
            x = self._upsample_add(x, down_x[i])
            # Fuse the projected bottleneck features into this decoder stage.
            x = self._upsample_add(self.feat_layers[i](feat), x)
            x = self.up_layers[i][1](x)

        x = self.conv_out(x)
        return x

© Copyright 2019-2023, PyTorch Connectomics Contributors. Revision 9f6fcf8b.

Built with Sphinx using a theme provided by Read the Docs.