from __future__ import print_function, division
from typing import Optional, List
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from ..backbone import build_backbone
from ..block import conv3d_norm_act
from ..utils import model_init
class FPN3D(nn.Module):
    """3D feature pyramid network (FPN).

    This design is flexible in handling both isotropic data and anisotropic
    data. A backbone produces one feature map per stage; each map is projected
    to ``filters[0]`` channels by a 1x1x1 lateral convolution, the maps are
    merged top-down (upsample, smooth, add), and a final convolution produces
    the output channels.

    Args:
        backbone_type (str): the backbone type passed to ``build_backbone``. Default: ``'resnet'``
        block_type (str): the block type in the backbone. Default: ``'residual'``
        feature_keys (List[str]): keys used to read the per-stage feature maps from the
            backbone output, ordered from the finest to the coarsest stage.
            Default: ``['feat1', 'feat2', 'feat3', 'feat4', 'feat5']``
        in_channel (int): number of input channels. Default: 1
        out_channel (int): number of output channels. Default: 3
        filters (List[int]): number of filters at each FPN stage. Default: [28, 36, 48, 64, 80]
        ks (List[int]): forwarded to the backbone builder as ``ks`` (presumably the kernel
            size of each backbone stage). Default: [3, 3, 5, 3, 3]
        blocks (List[int]): forwarded to the backbone builder as ``blocks`` (presumably the
            number of blocks in each backbone stage). Default: [2, 2, 2, 2, 2]
        attn (str): forwarded to the backbone builder as ``attention``.
            Default: ``'squeeze_excitation'``
        is_isotropic (bool): whether the whole model is isotropic. Default: False
        isotropy (List[bool]): specify each FPN stage is isotropic or anisotropic. All elements will
            be `True` if :attr:`is_isotropic` is `True`. Default: [False, False, False, True, True]
        pad_mode (str): one of ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'replicate'``
        act_mode (str): one of ``'relu'``, ``'leaky_relu'``, ``'elu'``, ``'gelu'``,
            ``'swish'``, ``'efficient_swish'`` or ``'none'``. Default: ``'elu'``
        norm_mode (str): one of ``'bn'``, ``'sync_bn'`` ``'in'`` or ``'gn'``. Default: ``'bn'``
        init_mode (str): one of ``'xavier'``, ``'kaiming'``, ``'selu'`` or ``'orthogonal'``. Default: ``'orthogonal'``
        deploy (bool): build backbone in deploy mode (exclusive for RepVGG backbone). Default: False
        fmap_size (List[int]): forwarded to the backbone builder as ``fmap_size``.
            Default: [17, 129, 129]
        **kwargs: accepted for call-site compatibility and ignored.

    Raises:
        ValueError: if :attr:`is_isotropic` is `False` and ``len(isotropy)``
            differs from ``len(filters)``.
    """

    def __init__(self,
                 backbone_type: str = 'resnet',
                 block_type: str = 'residual',
                 feature_keys: Optional[List[str]] = None,
                 in_channel: int = 1,
                 out_channel: int = 3,
                 filters: Optional[List[int]] = None,
                 ks: Optional[List[int]] = None,
                 blocks: Optional[List[int]] = None,
                 attn: str = 'squeeze_excitation',
                 is_isotropic: bool = False,
                 isotropy: Optional[List[bool]] = None,
                 pad_mode: str = 'replicate',
                 act_mode: str = 'elu',
                 norm_mode: str = 'bn',
                 init_mode: str = 'orthogonal',
                 deploy: bool = False,
                 fmap_size=None,
                 **kwargs):
        super().__init__()

        # Resolve list defaults per call so instances never share (and cannot
        # accidentally mutate) a single default list object.
        if feature_keys is None:
            feature_keys = ['feat1', 'feat2', 'feat3', 'feat4', 'feat5']
        if filters is None:
            filters = [28, 36, 48, 64, 80]
        if ks is None:
            ks = [3, 3, 5, 3, 3]
        if blocks is None:
            blocks = [2, 2, 2, 2, 2]
        if isotropy is None:
            isotropy = [False, False, False, True, True]
        if fmap_size is None:
            fmap_size = [17, 129, 129]

        self.filters = filters
        self.depth = len(filters)

        # The global flag is applied before validation: when it is set the
        # per-stage list is discarded, so its length must not matter.
        isotropy = self._resolve_isotropy(isotropy, is_isotropic, self.depth)
        self.isotropy = isotropy

        self.shared_kwargs = {
            'pad_mode': pad_mode,
            'act_mode': act_mode,
            'norm_mode': norm_mode
        }

        backbone_kwargs = {
            'block_type': block_type,
            'in_channel': in_channel,
            'filters': filters,
            'isotropy': isotropy,
            'blocks': blocks,
            'deploy': deploy,
            'fmap_size': fmap_size,
            'ks': ks,
            'attention': attn,
        }
        backbone_kwargs.update(self.shared_kwargs)

        self.backbone = build_backbone(
            backbone_type, feature_keys, **backbone_kwargs)
        self.feature_keys = feature_keys

        # Lateral 1x1x1 convolutions bring every stage to a common width.
        self.latplanes = filters[0]
        self.latlayers = nn.ModuleList([
            conv3d_norm_act(x, self.latplanes, kernel_size=1, padding=0,
                            **self.shared_kwargs) for x in filters])

        # One smoothing convolution per stage; its kernel follows the
        # isotropy of that stage.
        self.smooth = nn.ModuleList()
        for stage_isotropic in isotropy:
            kernel_size, padding = self._get_kernel_size(stage_isotropic)
            self.smooth.append(conv3d_norm_act(
                self.latplanes, self.latplanes, kernel_size=kernel_size,
                padding=padding, **self.shared_kwargs))

        self.conv_out = self._get_io_conv(out_channel, isotropy[0])

        # initialization
        model_init(self, init_mode)

    @staticmethod
    def _resolve_isotropy(isotropy, is_isotropic, depth):
        """Return the per-stage isotropy list for a model with ``depth`` stages.

        Args:
            isotropy: per-stage flags supplied by the caller.
            is_isotropic (bool): if `True`, every stage is isotropic and
                ``isotropy`` is ignored.
            depth (int): number of stages.

        Returns:
            ``[True] * depth`` when ``is_isotropic`` is set, otherwise
            ``isotropy`` itself (not copied).

        Raises:
            ValueError: if ``is_isotropic`` is `False` and ``len(isotropy)``
                is not ``depth``.
        """
        if is_isotropic:
            return [True] * depth
        if len(isotropy) != depth:
            raise ValueError(
                "isotropy must have one entry per stage: got %d entries "
                "for %d stages" % (len(isotropy), depth))
        return isotropy

    def forward(self, x):
        """Run the backbone on ``x`` and merge its features into the output."""
        z = self.backbone(x)
        return self._forward_main(z)

    def _forward_main(self, z):
        """Merge backbone features top-down and apply the output convolution.

        Args:
            z: mapping from each entry of ``self.feature_keys`` to the feature
                map of the corresponding stage.

        Returns:
            The output of ``self.conv_out`` applied to the merged finest-stage
            feature map.
        """
        features = [lateral(z[key])
                    for lateral, key in zip(self.latlayers, self.feature_keys)]

        # Start at the coarsest stage and walk towards the finest one; the
        # upsampled map from stage i is smoothed with smooth[i] before being
        # added to the lateral feature of stage i - 1.
        out = features[self.depth - 1]
        for i in range(self.depth - 1, 0, -1):
            out = self._up_smooth_add(out, features[i - 1], self.smooth[i])
        out = self.smooth[0](out)
        out = self.conv_out(out)
        return out

    def _up_smooth_add(self, x, y, smooth):
        """Upsample, smooth and add two feature maps.

        ``x`` is resized to the spatial size of ``y`` (``y.shape[2:]``) with
        trilinear interpolation, passed through ``smooth`` and added to ``y``.
        """
        x = F.interpolate(x, size=y.shape[2:], mode='trilinear',
                          align_corners=True)
        return smooth(x) + y

    def _get_kernel_size(self, is_isotropic, io_layer=False):
        """Return ``(kernel_size, padding)`` tuples for a 3D convolution.

        I/O layers use a size-5 kernel and other layers a size-3 kernel;
        anisotropic layers use size 1 (padding 0) along the first axis.
        """
        if io_layer:  # kernel and padding size of I/O layers
            if is_isotropic:
                return (5, 5, 5), (2, 2, 2)
            return (1, 5, 5), (0, 2, 2)

        if is_isotropic:
            return (3, 3, 3), (1, 1, 1)
        return (1, 3, 3), (0, 1, 1)

    def _get_io_conv(self, out_channel, is_isotropic):
        """Build the output convolution from ``filters[0]`` to ``out_channel``.

        The layer uses the I/O kernel size, the model's padding mode and a
        bias, with activation and normalization both set to ``'none'``.
        """
        kernel_size_io, padding_io = self._get_kernel_size(
            is_isotropic, io_layer=True)
        return conv3d_norm_act(
            self.filters[0], out_channel, kernel_size_io, padding=padding_io,
            pad_mode=self.shared_kwargs['pad_mode'], bias=True,
            act_mode='none', norm_mode='none')