Source code for clinicadl.networks.nn.unet

from __future__ import annotations

from typing import Optional, Sequence

import torch
import torch.nn as nn
from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.utils import get_act_layer
from pydantic import NonNegativeFloat, PositiveInt, field_validator

from clinicadl.networks.nn.layers.utils import ActivationParameters
from clinicadl.utils.factories import get_defaults_from

from .layers.unet import ConvBlock, DownBlock, UpBlock
from .layers.utils import ActFunction, ActivationParameters
from .utils.config import NetworkConfig, _DropoutConfig, _SpatialDimsConfig


[docs] class UNet(nn.Module): """ UNet, based on `U-Net: Convolutional Networks for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_. The user can customize the number of encoding blocks, the number of channels in each block, as well as other parameters like the activation function. Works with 2D or 3D images (with additional batch and channel dimensions). .. warning:: ``UNet`` works only with images whose dimensions are high enough powers of 2. More precisely, if ``n`` is the number of max pooling operation in your ``UNet`` (which is equal to ``len(channels)-1``), the image must have :math:`2^{k}` pixels in each dimension, with :math:`k \\geq n` (e.g. shape (:math:`2^{n}`, :math:`2^{n+3}`, :math:`2^{n+1}`) for a 3D image). .. note:: The implementation proposed here is not exactly the one described in the original paper. Padding is added to convolutions so that the feature maps keep a constant size, batch normalization is used, and "up-conv" layers are here made with a :py:class:`torch.nn.Upsample` layer followed by a 3x3 convolution. Parameters ---------- spatial_dims : int Number of spatial dimensions of the input image. in_channels : int Number of channels in the input image. out_channels : int Number of output channels. channels : Sequence[int], default=(64, 128, 256, 512, 1024) Number of channels in each UNet block. Thus, this parameter also controls the number of UNet blocks (equal to the length of the sequence). The length ``channels`` should be no less than ``2``.\n Default to ``(64, 128, 256, 512, 1024)``, as in the original paper. act : ActivationParameters, default="relu" The activation function used, and optionally its arguments. Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is a dictionary.\n ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``, ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``, ``"softmax"``, ``"tanh"``}. Please refer to :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>` to know the arguments for each of them.\n Default is ``relu``, as in the original paper. output_act : Optional[ActivationParameters], default=None A potential activation layer applied to the output of the network. Must be passed in the same way as ``act``. If ``None``, no last activation will be applied. dropout : Optional[float], default=None Dropout ratio. If ``None``, no dropout. See Also -------- :py:class:`torch.nn.Module` To see all the methods of this neural network. Examples -------- .. code-block:: python # a UNet with 1 downsampling (instead of 4 in the original paper) >>> UNet( spatial_dims=2, in_channels=1, out_channels=2, channels=(4, 8), act="elu", output_act=("softmax", {"dim": 1}), dropout=0.1, ) UNet( (doubleconv): ConvBlock( (0): Convolution( (conv): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (adn): ADN( (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (D): Dropout(p=0.1, inplace=False) (A): ELU(alpha=1.0) ) ) (1): Convolution( (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (adn): ADN( (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (D): Dropout(p=0.1, inplace=False) (A): ELU(alpha=1.0) ) ) ) (down1): DownBlock( (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (doubleconv): ConvBlock( (0): Convolution( (conv): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (adn): ADN( (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (D): Dropout(p=0.1, inplace=False) (A): ELU(alpha=1.0) ) ) (1): Convolution( (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (adn): ADN( (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (D): Dropout(p=0.1, inplace=False) (A): ELU(alpha=1.0) ) ) ) ) (up1): UpBlock( (upsample): UpSample( (0): Upsample(scale_factor=2.0, mode='nearest') (1): Convolution( (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (adn): ADN( (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (D): Dropout(p=0.1, inplace=False) (A): ELU(alpha=1.0) ) ) ) (doubleconv): ConvBlock( (0): Convolution( (conv): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (adn): ADN( (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (D): Dropout(p=0.1, inplace=False) (A): ELU(alpha=1.0) ) ) (1): Convolution( (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (adn): ADN( (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (D): Dropout(p=0.1, inplace=False) (A): ELU(alpha=1.0) ) ) ) ) (reduce_channels): Convolution( (conv): Conv2d(4, 2, kernel_size=(1, 1), stride=(1, 1)) ) (output_act): Softmax(dim=1) ) """ def __init__( self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int] = (64, 128, 256, 512, 1024), act: ActivationParameters = ActFunction.RELU, output_act: Optional[ActivationParameters] = None, dropout: Optional[float] = None, ): super().__init__() self.config = UNetConfig( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, channels=channels, act=act, output_act=output_act, dropout=dropout, ) self.doubleconv = ConvBlock( spatial_dims=self.config.spatial_dims, in_channels=self.config.in_channels, out_channels=self.config.channels[0], act=self.config.act, dropout=self.config.dropout, ) self._build_encoder() self._build_decoder() self.reduce_channels = Convolution( spatial_dims=self.config.spatial_dims, in_channels=self.config.channels[0], out_channels=self.config.out_channels, kernel_size=1, strides=1, padding=0, conv_only=True, ) self.output_act = ( get_act_layer(self.config.output_act) if self.config.output_act else None )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: x_history = [self.doubleconv(x)] for i in range(1, len(self.config.channels)): x = self.get_submodule(f"down{i}")(x_history[-1]) x_history.append(x) x_history.pop() # the output of bottelneck is not used as a gating signal for i in range(len(self.config.channels) - 1, 0, -1): x = self.get_submodule(f"up{i}")(x, skip=x_history.pop()) out = self.reduce_channels(x) if self.output_act is not None: out = self.output_act(out) return out
def _build_encoder(self) -> None: for i in range(1, len(self.config.channels)): self.add_module( f"down{i}", DownBlock( spatial_dims=self.config.spatial_dims, in_channels=self.config.channels[i - 1], out_channels=self.config.channels[i], act=self.config.act, dropout=self.config.dropout, ), ) def _build_decoder(self): for i in range(len(self.config.channels) - 1, 0, -1): self.add_module( f"up{i}", self._decoding_block( spatial_dims=self.config.spatial_dims, in_channels=self.config.channels[i], out_channels=self.config.channels[i - 1], act=self.config.act, dropout=self.config.dropout, ), ) @property def _decoding_block(self) -> type[nn.Module]: return UpBlock
UNET_DEFAULTS = get_defaults_from(UNet)
[docs] class UNetConfig( NetworkConfig, _SpatialDimsConfig, _DropoutConfig, ): """ Config class for :py:class:`clinicadl.networks.nn.UNet`. """ spatial_dims: PositiveInt in_channels: PositiveInt out_channels: PositiveInt channels: Sequence[PositiveInt] = UNET_DEFAULTS["channels"] act: ActivationParameters = UNET_DEFAULTS["act"] output_act: Optional[ActivationParameters] = UNET_DEFAULTS["output_act"] dropout: Optional[NonNegativeFloat] = UNET_DEFAULTS["dropout"] @field_validator("channels") @classmethod def _channels_validator(cls, v): if isinstance(v, Sequence) and len(v) < 2: raise ValueError(f"length of channels must be no less than 2. Got {v}") return v @classmethod def _get_class(cls) -> type[nn.Module]: """Returns the network associated to this config class.""" return UNet