Source code for clinicadl.networks.nn.resnet

import re
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union

import torch
import torch.nn as nn
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_act_layer
from pydantic import PositiveInt, model_validator
from torch.hub import load_state_dict_from_url
from torchvision.models.resnet import (
    ResNet18_Weights,
    ResNet34_Weights,
    ResNet50_Weights,
    ResNet101_Weights,
    ResNet152_Weights,
)

from clinicadl.utils.factories import get_defaults_from

from .layers.resnet import ResNetBlock, ResNetBottleneck
from .layers.senet import SEResNetBlock, SEResNetBottleneck
from .layers.utils import ActivationParameters
from .utils import ensure_tuple
from .utils.config import (
    NetworkConfig,
    _SpatialDimsConfig,
)

__all__ = [
    "ResNet",
    "ResNet18",
    "ResNet34",
    "ResNet50",
    "ResNet101",
    "ResNet152",
]


class ResNetBlockType(str, Enum):
    """Supported ResNet blocks."""

    BASIC = "basic"
    BOTTLENECK = "bottleneck"


class GeneralResNet(nn.Module):
    """Common base class for ResNet and SEResNet."""

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_outputs: Optional[int],
        block_type: Union[str, ResNetBlockType],
        n_res_blocks: Sequence[int],
        n_features: Sequence[int],
        init_conv_size: Union[Sequence[int], int],
        init_conv_stride: Union[Sequence[int], int],
        bottleneck_reduction: int,
        se_reduction: Optional[int],
        act: ActivationParameters,
        output_act: ActivationParameters,
    ) -> None:
        super().__init__()
        self.squeeze_excitation = True if se_reduction else False
        self.se_reduction = se_reduction
        self.spatial_dims = spatial_dims
        self.n_features = n_features
        self.bottleneck_reduction = bottleneck_reduction

        block, in_planes = self._get_block(block_type)

        conv_type, norm_type, pool_type, avgp_type = self._get_layers()

        block_avgpool = [0, 1, (1, 1), (1, 1, 1)]

        self.in_planes = in_planes[0]
        self.n_layers = len(in_planes)
        self.bias_downsample = False

        self.conv0 = conv_type(  # pylint: disable=not-callable
            in_channels,
            self.in_planes,
            kernel_size=init_conv_size,
            stride=init_conv_stride,
            padding=tuple(k // 2 for k in init_conv_size),
            bias=False,
        )
        self.norm0 = norm_type(self.in_planes)  # pylint: disable=not-callable
        self.act0 = get_act_layer(name=act)
        self.pool0 = pool_type(kernel_size=3, stride=2, padding=1)  # pylint: disable=not-callable
        self.layer1 = self._make_resnet_layer(
            block, in_planes[0], n_res_blocks[0], spatial_dims, act
        )
        for i, (n_blocks, n_feats) in enumerate(
            zip(n_res_blocks[1:], in_planes[1:]), start=2
        ):
            self.add_module(
                f"layer{i}",
                self._make_resnet_layer(
                    block,
                    planes=n_feats,
                    blocks=n_blocks,
                    spatial_dims=spatial_dims,
                    stride=2,
                    act=act,
                ),
            )
        self.fc = (
            nn.Sequential(
                OrderedDict(
                    [
                        ("pool", avgp_type(block_avgpool[spatial_dims])),  # pylint: disable=not-callable
                        ("flatten", nn.Flatten(1)),
                        ("out", nn.Linear(n_features[-1], num_outputs)),
                    ]
                )
            )
            if num_outputs
            else None
        )
        if self.fc:
            self.fc.output_act = get_act_layer(output_act) if output_act else None

        self._init_module(conv_type, norm_type)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv0(x)
        x = self.norm0(x)
        x = self.act0(x)
        x = self.pool0(x)

        for i in range(1, self.n_layers + 1):
            x = self.get_submodule(f"layer{i}")(x)

        if self.fc is not None:
            x = self.fc(x)

        return x

    def _get_block(self, block_type: Union[str, ResNetBlockType]) -> nn.Module:
        """
        Gets the residual block, depending on the block choice made by the user and depending
        on whether squeeze-excitation mode or not.
        """
        block_type = ResNetBlockType(block_type)
        if block_type == ResNetBlockType.BASIC:
            in_planes = self.n_features
            if self.squeeze_excitation:
                block = SEResNetBlock
                block.reduction = self.se_reduction
            else:
                block = ResNetBlock
        elif block_type == ResNetBlockType.BOTTLENECK:
            in_planes = self._bottleneck_reduce()
            if self.squeeze_excitation:
                block = SEResNetBottleneck
                block.reduction = self.se_reduction
            else:
                block = ResNetBottleneck
            block.expansion = self.bottleneck_reduction

        return block, in_planes  # pylint: disable=possibly-used-before-assignment

    def _get_layers(self):
        """
        Gets convolution, normalization, pooling and adaptative average pooling layers.
        """
        conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[
            Conv.CONV, self.spatial_dims
        ]
        norm_type: Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]] = Norm[
            Norm.BATCH, self.spatial_dims
        ]
        pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[
            Pool.MAX, self.spatial_dims
        ]
        avgp_type: Type[
            Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]
        ] = Pool[Pool.ADAPTIVEAVG, self.spatial_dims]

        return conv_type, norm_type, pool_type, avgp_type

    def _make_resnet_layer(
        self,
        block: Type[Union[ResNetBlock, ResNetBottleneck]],
        planes: int,
        blocks: int,
        spatial_dims: int,
        act: ActivationParameters,
        stride: int = 1,
    ) -> nn.Sequential:
        """
        Builds a ResNet layer.
        """
        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]

        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                conv_type(  # pylint: disable=not-callable
                    self.in_planes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=self.bias_downsample,
                ),
                norm_type(planes * block.expansion),  # pylint: disable=not-callable
            )

        layers = [
            block(
                in_planes=self.in_planes,
                planes=planes,
                spatial_dims=spatial_dims,
                stride=stride,
                downsample=downsample,
                act=act,
            )
        ]

        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.in_planes, planes, spatial_dims=spatial_dims, act=act)
            )

        return nn.Sequential(*layers)

    def _init_module(
        self, conv_type: Type[nn.Module], norm_type: Type[nn.Module]
    ) -> None:
        """
        Initializes the parameters.
        """
        for m in self.modules():
            if isinstance(m, conv_type):
                nn.init.kaiming_normal_(
                    torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, norm_type):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(torch.as_tensor(m.bias), 0)

    def _bottleneck_reduce(self) -> Sequence[int]:
        """
        Finds number of feature maps for the bottleneck layers.
        """
        reduced_features = []
        for n in self.n_features:
            reduced_features.append(n // self.bottleneck_reduction)

        return reduced_features


[docs] class ResNet(GeneralResNet): """ ResNet, based on `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`_. Adapted from :py:class:`MONAI's implementation <monai.networks.nets.ResNet>`. The user can customize the number of residual blocks, the number of downsampling blocks, the number of channels in each block, as well as other parameters like the type of residual block used. ResNet is a fully convolutional network that can work with an input of any size, provided that it is large enough not to be reduced to a 1-pixel image (before the adaptative average pooling). Works with 2D or 3D images (with additional batch and channel dimensions). Parameters ---------- spatial_dims : int Number of spatial dimensions of the input image. in_channels : int Number of channels in the input image. num_outputs : Optional[int] Number of output variables after the last linear layer. If ``None``, the feature map before the last fully connected layer will be returned. block_type : Union[str, ResNetBlockType], default="basic" Type of residual block. Either ``basic`` or ``bottleneck``. Default to ``basic``, as in ``ResNet-18``. n_res_blocks : Sequence[int], default=(2, 2, 2, 2) Number of residual block in each ResNet layer. A ResNet layer refers here to a set of residual blocks between two downsamplings. The length of ``n_res_blocks`` thus determines the number of ResNet layers. Default to ``(2, 2, 2, 2)``, as in ``ResNet-18``. n_features : Sequence[int], default=(64, 128, 256, 512) Number of output feature maps for each ResNet layer. The length of ``n_features`` must be equal to the length of ``n_res_blocks``. All elements of ``n_features`` must be divisible by ``bottleneck_reduction``.\n Default to ``(64, 128, 256, 512)``, as in ``ResNet-18``. init_conv_size : Union[Sequence[int], int], default=7 Kernel size for the first convolution. If ``tuple``, it will be understood as the values for each dimension. Default to ``7``, as in the original paper. init_conv_stride : Union[Sequence[int], int], default=2 Stride for the first convolution. If ``tuple``, it will be understood as the values for each dimension. Default to ``2``, as in the original paper. bottleneck_reduction : int, default=4 If ``block_type="bottleneck"``, ``bottleneck_reduction`` determines the reduction factor for the number of feature maps in bottleneck layers (1x1 convolutions). Default to ``4``, as in the original paper. act : ActivationParameters, default=("relu", {"inplace": True}) The activation function used after a convolutional layer, and optionally its arguments. Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is a dictionary.\n ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``, ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``, ``"softmax"``, ``"tanh"``}. Please refer to :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>` to know the arguments for each of them.\n Default is ``relu``, as in the original paper. output_act : Optional[ActivationParameters], default=None A potential activation layer applied to the output of the network. Must be passed in the same way as ``act``. If ``None``, no last activation will be applied. Raises ------ ValueError If ``len(n_features)!=len(n_res_blocks)``. ValueError If some elements of ``n_features`` are not divisible by ``bottleneck_reduction``. See Also -------- :py:class:`torch.nn.Module` To see all the methods of this neural network. :py:class:`~clinicadl.networks.nn.SEResNet` Examples -------- .. code-block:: >>> ResNet( spatial_dims=2, in_channels=1, num_outputs=2, block_type="bottleneck", bottleneck_reduction=4, n_features=(8, 16), n_res_blocks=(2, 2), output_act="softmax", init_conv_size=5, ) ResNet( (conv0): Conv2d(1, 2, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False) (norm0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act0): ReLU(inplace=True) (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): ResNetBottleneck( (conv1): Conv2d(2, 2, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (conv2): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (norm2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (conv3): Conv2d(2, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(2, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (act3): ReLU(inplace=True) ) (1): ResNetBottleneck( (conv1): Conv2d(8, 2, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (conv2): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (norm2): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (conv3): Conv2d(2, 8, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act3): ReLU(inplace=True) ) ) (layer2): Sequential( (0): ResNetBottleneck( (conv1): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (conv2): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (norm2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (conv3): Conv2d(4, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(8, 16, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (act3): ReLU(inplace=True) ) (1): ResNetBottleneck( (conv1): Conv2d(16, 4, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): ReLU(inplace=True) (conv2): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (norm2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): ReLU(inplace=True) (conv3): Conv2d(4, 16, kernel_size=(1, 1), stride=(1, 1), bias=False) (norm3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act3): ReLU(inplace=True) ) ) (fc): Sequential( (pool): AdaptiveAvgPool2d(output_size=(1, 1)) (flatten): Flatten(start_dim=1, end_dim=-1) (out): Linear(in_features=16, out_features=2, bias=True) (output_act): Softmax(dim=None) ) ) """ def __init__( self, spatial_dims: int, in_channels: int, num_outputs: Optional[int], block_type: Union[str, ResNetBlockType] = ResNetBlockType.BASIC, n_res_blocks: Sequence[int] = (2, 2, 2, 2), n_features: Sequence[int] = (64, 128, 256, 512), init_conv_size: Union[Sequence[int], int] = 7, init_conv_stride: Union[Sequence[int], int] = 2, bottleneck_reduction: int = 4, act: ActivationParameters = ("relu", {"inplace": True}), output_act: Optional[ActivationParameters] = None, ) -> None: config = ResNetConfig( spatial_dims=spatial_dims, in_channels=in_channels, num_outputs=num_outputs, block_type=block_type, n_res_blocks=n_res_blocks, n_features=n_features, init_conv_size=init_conv_size, init_conv_stride=init_conv_stride, bottleneck_reduction=bottleneck_reduction, act=act, output_act=output_act, ) super().__init__( se_reduction=None, **config.to_raw_dict(), ) def _load_weights(self, url: str) -> None: """To load weights from torchvision.""" fc_layers = deepcopy(self.fc) self.fc = None pretrained_dict = load_state_dict_from_url(url, progress=True) self.load_state_dict(_state_dict_adapter(pretrained_dict)) self.fc = fc_layers
[docs] class ResNet18(ResNet): """ ResNet-18, from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__. Only the last fully connected layer will be changed to match ``num_outputs``. The user can use the pretrained models from ``torchvision``. Note that the last fully connected layer will not use pretrained weights, as it is task specific. .. warning:: Only works with **2D images with 3 channels**. Parameters ---------- num_outputs : Optional[int] Number of output variables after the last linear layer. If ``None``, the feature map before the last fully connected layer will be returned. output_act : Optional[ActivationParameters], default=None A potential activation layer applied to the output of the network, and optionally its arguments. Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is a dictionary. If ``None``, no activation will be used.\n ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``, ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``, ``"softmax"``, ``"tanh"``}. Please refer to :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>` to know the arguments for each of them. pretrained : bool, default=False Whether to use pretrained weights. The pretrained weights used are the default ones from :py:func:`torchvision.models.resnet18`. See Also -------- :py:class:`torch.nn.Module` To see all the methods of this neural network. :py:class:`~clinicadl.networks.nn.ResNet` """ def __init__( self, num_outputs: Optional[int], output_act: Optional[ActivationParameters] = None, pretrained: bool = False, ) -> None: config = ResNet18Config( num_outputs=num_outputs, output_act=output_act, pretrained=pretrained ) super().__init__( spatial_dims=2, in_channels=3, num_outputs=config.num_outputs, n_res_blocks=(2, 2, 2, 2), block_type=ResNetBlockType.BASIC, n_features=(64, 128, 256, 512), output_act=config.output_act, ) if config.pretrained: self._load_weights(ResNet18_Weights.DEFAULT.url)
[docs] class ResNet34(ResNet): """ ResNet-34, from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__. Only the last fully connected layer will be changed to match ``num_outputs``. The user can use the pretrained models from ``torchvision``. Note that the last fully connected layer will not use pretrained weights, as it is task specific. .. warning:: Only works with **2D images with 3 channels**. Parameters ---------- num_outputs : Optional[int] Number of output variables after the last linear layer. If ``None``, the feature map before the last fully connected layer will be returned. output_act : Optional[ActivationParameters], default=None A potential activation layer applied to the output of the network, and optionally its arguments. Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is a dictionary. If ``None``, no activation will be used.\n ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``, ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``, ``"softmax"``, ``"tanh"``}. Please refer to :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>` to know the arguments for each of them. pretrained : bool, default=False Whether to use pretrained weights. The pretrained weights used are the default ones from :py:func:`torchvision.models.resnet34`. See Also -------- :py:class:`torch.nn.Module` To see all the methods of this neural network. :py:class:`~clinicadl.networks.nn.ResNet` """ def __init__( self, num_outputs: Optional[int], output_act: Optional[ActivationParameters] = None, pretrained: bool = False, ) -> None: config = ResNet34Config( num_outputs=num_outputs, output_act=output_act, pretrained=pretrained ) super().__init__( spatial_dims=2, in_channels=3, num_outputs=config.num_outputs, n_res_blocks=(3, 4, 6, 3), block_type=ResNetBlockType.BASIC, n_features=(64, 128, 256, 512), output_act=config.output_act, ) if config.pretrained: self._load_weights(ResNet34_Weights.DEFAULT.url)
[docs] class ResNet50(ResNet): """ ResNet-50, from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__. Only the last fully connected layer will be changed to match ``num_outputs``. The user can use the pretrained models from ``torchvision``. Note that the last fully connected layer will not use pretrained weights, as it is task specific. .. warning:: Only works with **2D images with 3 channels**. Parameters ---------- num_outputs : Optional[int] Number of output variables after the last linear layer. If ``None``, the feature map before the last fully connected layer will be returned. output_act : Optional[ActivationParameters], default=None A potential activation layer applied to the output of the network, and optionally its arguments. Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is a dictionary. If ``None``, no activation will be used.\n ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``, ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``, ``"softmax"``, ``"tanh"``}. Please refer to :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>` to know the arguments for each of them. pretrained : bool, default=False Whether to use pretrained weights. The pretrained weights used are the default ones from :py:func:`torchvision.models.resnet50`. See Also -------- :py:class:`torch.nn.Module` To see all the methods of this neural network. :py:class:`~clinicadl.networks.nn.ResNet` """ def __init__( self, num_outputs: Optional[int], output_act: Optional[ActivationParameters] = None, pretrained: bool = False, ) -> None: config = ResNet50Config( num_outputs=num_outputs, output_act=output_act, pretrained=pretrained ) super().__init__( spatial_dims=2, in_channels=3, num_outputs=config.num_outputs, n_res_blocks=(3, 4, 6, 3), block_type=ResNetBlockType.BOTTLENECK, n_features=(256, 512, 1024, 2048), output_act=config.output_act, ) if config.pretrained: self._load_weights(ResNet50_Weights.DEFAULT.url)
[docs] class ResNet101(ResNet): """ ResNet-101, from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__. Only the last fully connected layer will be changed to match ``num_outputs``. The user can use the pretrained models from ``torchvision``. Note that the last fully connected layer will not use pretrained weights, as it is task specific. .. warning:: Only works with **2D images with 3 channels**. Parameters ---------- num_outputs : Optional[int] Number of output variables after the last linear layer. If ``None``, the feature map before the last fully connected layer will be returned. output_act : Optional[ActivationParameters], default=None A potential activation layer applied to the output of the network, and optionally its arguments. Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is a dictionary. If ``None``, no activation will be used.\n ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``, ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``, ``"softmax"``, ``"tanh"``}. Please refer to :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>` to know the arguments for each of them. pretrained : bool, default=False Whether to use pretrained weights. The pretrained weights used are the default ones from :py:func:`torchvision.models.resnet101`. See Also -------- :py:class:`torch.nn.Module` To see all the methods of this neural network. :py:class:`~clinicadl.networks.nn.ResNet` """ def __init__( self, num_outputs: Optional[int], output_act: Optional[ActivationParameters] = None, pretrained: bool = False, ) -> None: config = ResNet101Config( num_outputs=num_outputs, output_act=output_act, pretrained=pretrained ) super().__init__( spatial_dims=2, in_channels=3, num_outputs=config.num_outputs, n_res_blocks=(3, 4, 23, 3), block_type=ResNetBlockType.BOTTLENECK, n_features=(256, 512, 1024, 2048), output_act=config.output_act, ) if config.pretrained: self._load_weights(ResNet101_Weights.DEFAULT.url)
[docs] class ResNet152(ResNet): """ ResNet-152, from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__. Only the last fully connected layer will be changed to match ``num_outputs``. The user can use the pretrained models from ``torchvision``. Note that the last fully connected layer will not use pretrained weights, as it is task specific. .. warning:: Only works with **2D images with 3 channels**. Parameters ---------- num_outputs : Optional[int] Number of output variables after the last linear layer. If ``None``, the feature map before the last fully connected layer will be returned. output_act : Optional[ActivationParameters], default=None A potential activation layer applied to the output of the network, and optionally its arguments. Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is a dictionary. If ``None``, no activation will be used.\n ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``, ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``, ``"softmax"``, ``"tanh"``}. Please refer to :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>` to know the arguments for each of them. pretrained : bool, default=False Whether to use pretrained weights. The pretrained weights used are the default ones from :py:func:`torchvision.models.resnet152`. See Also -------- :py:class:`torch.nn.Module` To see all the methods of this neural network. :py:class:`~clinicadl.networks.nn.ResNet` """ def __init__( self, num_outputs: Optional[int], output_act: Optional[ActivationParameters] = None, pretrained: bool = False, ) -> None: config = ResNet152Config( num_outputs=num_outputs, output_act=output_act, pretrained=pretrained ) super().__init__( spatial_dims=2, in_channels=3, num_outputs=config.num_outputs, n_res_blocks=(3, 8, 36, 3), block_type=ResNetBlockType.BOTTLENECK, n_features=(256, 512, 1024, 2048), output_act=config.output_act, ) if config.pretrained: self._load_weights(ResNet152_Weights.DEFAULT.url)
RES_NET_DEFAULTS = get_defaults_from(ResNet) RES_NET_18_DEFAULTS = get_defaults_from(ResNet18) RES_NET_34_DEFAULTS = get_defaults_from(ResNet34) RES_NET_50_DEFAULTS = get_defaults_from(ResNet50) RES_NET_101_DEFAULTS = get_defaults_from(ResNet101) RES_NET_152_DEFAULTS = get_defaults_from(ResNet152)
[docs] class ResNetConfig(NetworkConfig, _SpatialDimsConfig): """ Config class for :py:class:`clinicadl.networks.nn.ResNet`. """ spatial_dims: PositiveInt in_channels: PositiveInt num_outputs: Optional[PositiveInt] block_type: ResNetBlockType = RES_NET_DEFAULTS["block_type"] n_res_blocks: Sequence[PositiveInt] = RES_NET_DEFAULTS["n_res_blocks"] n_features: Sequence[PositiveInt] = RES_NET_DEFAULTS["n_features"] init_conv_size: Union[Sequence[PositiveInt], PositiveInt] = RES_NET_DEFAULTS[ "init_conv_size" ] init_conv_stride: Union[Sequence[PositiveInt], PositiveInt] = RES_NET_DEFAULTS[ "init_conv_stride" ] bottleneck_reduction: PositiveInt = RES_NET_DEFAULTS["bottleneck_reduction"] act: ActivationParameters = RES_NET_DEFAULTS["act"] output_act: Optional[ActivationParameters] = RES_NET_DEFAULTS["output_act"] @model_validator(mode="after") def make_checks(self): self._check_res_blocks(self.n_res_blocks, self.n_features) if self.block_type == ResNetBlockType.BOTTLENECK: self._check_bottleneck_reduction(self.n_features, self.bottleneck_reduction) self.__dict__["init_conv_size"] = ensure_tuple( self.init_conv_size, self.spatial_dims, "init_conv_size" ) self.__dict__["init_conv_stride"] = ensure_tuple( self.init_conv_stride, self.spatial_dims, "init_conv_stride" ) return self @classmethod def _get_class(cls) -> type[nn.Module]: """Returns the network associated to this config class.""" return ResNet @staticmethod def _check_bottleneck_reduction( n_features: Sequence[int], bottleneck_reduction: int ) -> Sequence[int]: """ Checks bottleneck_reduction. """ for n in n_features: if n % bottleneck_reduction != 0: raise ValueError( "All elements of n_features must be divisible by bottleneck_reduction. " f"Got {n} in n_features and bottleneck_reduction={bottleneck_reduction}" ) @staticmethod def _check_res_blocks( n_res_blocks: Sequence[int], n_features: Sequence[int] ) -> None: """ Checks consistency between `n_res_blocks` and `n_features`. """ if len(n_features) != len(n_res_blocks): raise ValueError( f"n_features and n_res_blocks must have the same length, got n_features={n_features} " f"and n_res_blocks={n_res_blocks}" )
[docs] class ResNet18Config(NetworkConfig): """ Config class for :py:class:`clinicadl.networks.nn.ResNet18`. """ num_outputs: Optional[PositiveInt] output_act: Optional[ActivationParameters] = RES_NET_18_DEFAULTS["output_act"] pretrained: bool = RES_NET_18_DEFAULTS["pretrained"] @classmethod def _get_class(cls) -> type[nn.Module]: """Returns the network associated to this config class.""" return ResNet18
[docs] class ResNet34Config(NetworkConfig): """ Config class for :py:class:`clinicadl.networks.nn.ResNet34`. """ num_outputs: Optional[PositiveInt] output_act: Optional[ActivationParameters] = RES_NET_34_DEFAULTS["output_act"] pretrained: bool = RES_NET_34_DEFAULTS["pretrained"] @classmethod def _get_class(cls) -> type[nn.Module]: """Returns the network associated to this config class.""" return ResNet34
[docs] class ResNet50Config(NetworkConfig): """ Config class for :py:class:`clinicadl.networks.nn.ResNet50`. """ num_outputs: Optional[PositiveInt] output_act: Optional[ActivationParameters] = RES_NET_50_DEFAULTS["output_act"] pretrained: bool = RES_NET_50_DEFAULTS["pretrained"] @classmethod def _get_class(cls) -> type[nn.Module]: """Returns the network associated to this config class.""" return ResNet50
[docs] class ResNet101Config(NetworkConfig): """ Config class for :py:class:`clinicadl.networks.nn.ResNet101`. """ num_outputs: Optional[PositiveInt] output_act: Optional[ActivationParameters] = RES_NET_101_DEFAULTS["output_act"] pretrained: bool = RES_NET_101_DEFAULTS["pretrained"] @classmethod def _get_class(cls) -> type[nn.Module]: """Returns the network associated to this config class.""" return ResNet101
[docs] class ResNet152Config(NetworkConfig): """ Config class for :py:class:`clinicadl.networks.nn.ResNet152`. """ num_outputs: Optional[PositiveInt] output_act: Optional[ActivationParameters] = RES_NET_152_DEFAULTS["output_act"] pretrained: bool = RES_NET_152_DEFAULTS["pretrained"] @classmethod def _get_class(cls) -> type[nn.Module]: """Returns the network associated to this config class.""" return ResNet152
def _state_dict_adapter(state_dict: Mapping[str, Any]) -> Mapping[str, Any]: """ A mapping between torchvision's layer names and ours. """ state_dict = {k: v for k, v in state_dict.items() if "fc" not in k} mappings = [ (r"(?<!\.)conv1", "conv0"), (r"(?<!\.)bn1", "norm0"), ("bn", "norm"), ] for key in list(state_dict.keys()): new_key = key for transform in mappings: new_key = re.sub(transform[0], transform[1], new_key) state_dict[new_key] = state_dict.pop(key) return state_dict