Source code for clinicadl.networks.nn.conv_decoder

from typing import Callable, List, Optional, Sequence, Tuple

import torch.nn as nn
from monai.networks.blocks import Convolution
from monai.networks.layers.utils import get_act_layer
from pydantic import NonNegativeFloat, PositiveInt, model_validator

from clinicadl.utils.factories import get_defaults_from

from .conv_encoder import _BaseConvOptions
from .layers.unpool import get_unpool_layer
from .layers.utils import (
    ActFunction,
    ActivationParameters,
    ConvNormalizationParameters,
    ConvNormLayer,
    ConvParameters,
    SingleLayerUnpoolingParameters,
    UnpoolingLayer,
    UnpoolingParameters,
)
from .utils import (
    calculate_convtranspose_out_shape,
    calculate_unpool_out_shape,
    ensure_list_of_tuples,
)
from .utils.config import NetworkConfig, _SpatialDimsConfig


[docs] class ConvDecoder(nn.Sequential): """ Fully convolutional decoder network with transposed convolutions, unpooling, normalization, activation and dropout layers. It is the symmetric of :py:class:`~clinicadl.networks.nn.ConvEncoder`, where convolutions are replaced by transposed convolutions, and pooling layers by unpooling layers. Works with 2D or 3D images (with additional batch and channel dimensions). Parameters ---------- spatial_dims : int Number of spatial dimensions of the input image. in_channels : int Number of channels in the input image. channels : Sequence[int] Number of output channels of each transposed convolution. Thus, this parameter also controls the number of transposed convolutions (equal to the length of the sequence). kernel_size : ConvParameters, default=3 Kernel size of the transposed convolutions. Can be an ``int``, a ``tuple``, or a ``list``: - ``int``: the value will be used for all layers and all dimensions; - ``tuple`` (e.g. ``(3, 3, 2)``): it will be interpreted as the values for each dimension. These values will be used for all the layers; - ``list`` (e.g. ``[(3, 3, 2), 3]``): it will be interpreted as the kernel sizes for each layer. The length of the list must be equal to the number of transposed convolutions (i.e. ``len(channels)``). stride : ConvParameters, default=1 Stride of the transposed convolutions. Can be an ``int``, a ``tuple``, or a ``list``, and is passed in the same way as ``kernel_size``.\n padding : ConvParameters, default=0 Padding of the transposed convolutions. Can be an ``int``, a ``tuple``, or a ``list``, and is passed in the same way as ``kernel_size``.\n output_padding : ConvParameters, default=0 Output padding of the transposed convolutions. Can be an ``int``, a ``tuple``, or a ``list``, and is passed in the same way as ``kernel_size``.\n dilation : ConvParameters, default=1 Dilation factor of the transposed convolutions. Can be an ``int``, a ``tuple``, or a ``list``, and is passed in the same way as ``kernel_size``.\n unpooling : Optional[UnpoolingParameters], default=("upsample", {"scale_factor": 2}) The unpooling mode and the arguments of the unpooling layer, passed as ``(unpooling_mode, arguments)``, where ``arguments`` is a dictionary. If ``None``, no unpooling will be performed in the network.\n ``unpooling_mode`` can be either ``upsample`` or ``convtranspose``. Please refer to :py:class:`torch.nn.Upsample` or :py:class:`torch.nn.ConvTranspose3d` to know their arguments.\n If a ``list`` is passed, it will be understood as the unpooling for each unpooling layer. .. note:: No need to pass ``in_channels`` and ``out_channels`` for ``convtranspose``, because the unpooling layers are not intended to modify the number of channels here. unpooling_indices : Optional[Sequence[int]], default=None Indices of the transposed convolutions after which unpooling should be performed. If ``None``, no unpooling will be performed. An index equal to ``-1`` will be understood as a pooling layer before the first transposed convolution. act : Optional[ActivationParameters], default="prelu" The activation function used after a transposed convolution, and optionally its arguments. Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is a dictionary. If ``None``, no activation will be used.\n ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``, ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``, ``"softmax"``, ``"tanh"``}. Please refer to :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>` to know the arguments for each of them. output_act : Optional[ActivationParameters], default=None A potential activation layer applied to the output of the network. Must be passed in the same way as ``act``. If ``None``, no last activation will be applied. norm : Optional[ConvNormalizationParameters], default="instance" The normalization layer used after a transposed convolution, and optionally its arguments. Must be passed as ``norm_type`` or ``(norm_type, arguments)`` where ``arguments`` is a dictionary. If ``None``, no normalization will be performed.\n ``norm_type`` can be any value in {``"batch"``, ``"group"``, ``"instance"``, ``"syncbatch"``}. Please refer to :torch:`PyTorch normalization layers <nn.html#normalization-layers>` to know the arguments for each of them. .. note:: Please note that there's no need to pass the arguments ``num_channels`` and ``num_features`` of the normalization layer, as they are automatically inferred from the output of the previous layer in the network. dropout : Optional[float], default=None Dropout ratio. If ``None``, no dropout. bias : bool, default=True Whether to have a bias term in linear layers. adn_ordering : str, default="NDA" Order of operations Activation, Dropout and Normalization, after a linear layer (except the last one). **Cannot contain duplicated letters**. For example if ``"ND"`` is passed, Normalization and then Dropout will be performed (without Activation).\n .. note:: ADN will not be applied after the last linear layer. Raises ------ ValueError If a ``list`` is passed for ``kernel_size``, ``stride``, ``padding``, ``output_padding``, or ``dilation``, and the size of this list in not equal to ``len(channels)``. ValueError If indices in ``unpooling_indices`` are greater than ``len(channels)-1`` (``len(channels)-1`` being the index of the last transposed convolution). ValueError If a ``list`` is passed for ``unpooling``, and ``len(unpooling)!=len(unpooling_indices)``. ValueError If the activation or normalization layer requires a mandatory argument, which is not passed by the user (via a dictionary in ``act`` or ``norm``). See Also -------- :py:class:`torch.nn.Module` To see all the methods of this neural network. :py:class:`~clinicadl.networks.nn.ConvEncoder` Examples -------- .. code-block:: python >>> ConvDecoder( in_channels=16, spatial_dims=2, channels=[8, 4, 1], kernel_size=(3, 5), stride=2, padding=[1, 0, 0], output_padding=[0, 0, (1, 2)], dilation=1, unpooling=[("upsample", {"scale_factor": 2}), ("upsample", {"size": (32, 32)})], unpooling_indices=[0, 1], act="elu", output_act="relu", norm=("batch", {"eps": 1e-05}), dropout=0.1, bias=True, adn_ordering="NDA", ) ConvDecoder( (layer0): Convolution( (conv): ConvTranspose2d(16, 8, kernel_size=(3, 5), stride=(2, 2), padding=(1, 1)) (adn): ADN( (N): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (D): Dropout(p=0.1, inplace=False) (A): ELU(alpha=1.0) ) ) (unpool0): Upsample(scale_factor=2.0, mode='nearest') (layer1): Convolution( (conv): ConvTranspose2d(8, 4, kernel_size=(3, 5), stride=(2, 2)) (adn): ADN( (N): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (D): Dropout(p=0.1, inplace=False) (A): ELU(alpha=1.0) ) ) (unpool1): Upsample(size=(32, 32), mode='nearest') (layer2): Convolution( (conv): ConvTranspose2d(4, 1, kernel_size=(3, 5), stride=(2, 2), output_padding=(1, 2)) ) (output_act): ReLU() ) """ def __init__( self, spatial_dims: int, in_channels: int, channels: Sequence[int], kernel_size: ConvParameters = 3, stride: ConvParameters = 1, padding: ConvParameters = 0, output_padding: ConvParameters = 0, dilation: ConvParameters = 1, unpooling: Optional[UnpoolingParameters] = ( UnpoolingLayer.UPSAMPLE, {"scale_factor": 2}, ), unpooling_indices: Optional[Sequence[int]] = None, act: Optional[ActivationParameters] = ActFunction.PRELU, output_act: Optional[ActivationParameters] = None, norm: Optional[ConvNormalizationParameters] = ConvNormLayer.INSTANCE, dropout: Optional[float] = None, bias: bool = True, adn_ordering: str = "NDA", _input_size: Optional[Sequence[int]] = None, ) -> None: super().__init__() self.config = ConvDecoderConfig( spatial_dims=spatial_dims, in_channels=in_channels, channels=channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, dilation=dilation, unpooling=unpooling, unpooling_indices=unpooling_indices, act=act, output_act=output_act, norm=norm, dropout=dropout, bias=bias, adn_ordering=adn_ordering, ) self._current_size = _input_size if _input_size else None self.n_layers = len(self.config.channels) n_unpoolings = 0 if self.config.unpooling and -1 in self.config.unpooling_indices: unpooling_layer = self._get_unpool_layer( self.config.unpooling[n_unpoolings], n_channels=self.config.in_channels ) self.add_module("init_unpool", unpooling_layer) n_unpoolings += 1 echannel = self.config.in_channels for i, (c, k, s, p, o_p, d) in enumerate( zip( self.config.channels, self.config.kernel_size, self.config.stride, self.config.padding, self.config.output_padding, self.config.dilation, ) ): conv_layer = self._get_convtranspose_layer( in_channels=echannel, out_channels=c, kernel_size=k, stride=s, padding=p, output_padding=o_p, dilation=d, is_last=(i == len(channels) - 1), ) self.add_module(f"layer{i}", conv_layer) echannel = c # use the output channel number as the input for the next loop if self.config.unpooling and i in self.config.unpooling_indices: unpooling_layer = self._get_unpool_layer( self.config.unpooling[n_unpoolings], n_channels=c ) self.add_module(f"unpool{i}", unpooling_layer) n_unpoolings += 1 self.output_act = get_act_layer(output_act) if output_act else None @property def _final_size(self): """ To know the size of an image at the end of the network. """ return self._current_size @_final_size.setter def _final_size(self, fct: Callable[[Tuple[int, ...]], Tuple[int, ...]]): """ Takes as input the function used to update the current image size. """ if self._current_size is not None: self._current_size = fct(self._current_size) def _get_convtranspose_layer( self, in_channels: int, out_channels: int, kernel_size: Tuple[int, ...], stride: Tuple[int, ...], padding: Tuple[int, ...], output_padding: Tuple[int, ...], dilation: Tuple[int, ...], is_last: bool, ) -> Convolution: """ Gets the parametrized TransposedConvolution-ADN block and updates the current output size. """ self._final_size = lambda size: calculate_convtranspose_out_shape( size, kernel_size, stride, padding, output_padding, dilation ) return Convolution( is_transposed=True, conv_only=is_last, spatial_dims=self.config.spatial_dims, in_channels=in_channels, out_channels=out_channels, strides=stride, kernel_size=kernel_size, padding=padding, output_padding=output_padding, dilation=dilation, act=self.config.act, norm=self.config.norm, dropout=self.config.dropout, bias=self.config.bias, adn_ordering=self.config.adn_ordering, ) def _get_unpool_layer( self, unpooling: SingleLayerUnpoolingParameters, n_channels: int ) -> nn.Module: """ Gets the parametrized unpooling layer and updates the current output size. """ unpool_layer = get_unpool_layer( unpooling, spatial_dims=self.config.spatial_dims, in_channels=n_channels, out_channels=n_channels, ) self._final_size = lambda size: calculate_unpool_out_shape( unpool_mode=unpooling[0], in_shape=size, **unpool_layer.__dict__, ) return unpool_layer
CONV_DECODER_DEFAULTS = get_defaults_from(ConvDecoder)
[docs] class ConvDecoderOptions(_BaseConvOptions): """ Config class for ``conv_args`` of :py:class:`~clinicadl.networks.config.GeneratorConfig`. """ channels: Sequence[PositiveInt] kernel_size: ConvParameters = CONV_DECODER_DEFAULTS["kernel_size"] stride: ConvParameters = CONV_DECODER_DEFAULTS["stride"] padding: ConvParameters = CONV_DECODER_DEFAULTS["padding"] output_padding: ConvParameters = CONV_DECODER_DEFAULTS["output_padding"] dilation: ConvParameters = CONV_DECODER_DEFAULTS["dilation"] unpooling: Optional[UnpoolingParameters] = CONV_DECODER_DEFAULTS["unpooling"] unpooling_indices: Optional[Sequence[int]] = CONV_DECODER_DEFAULTS[ "unpooling_indices" ] act: Optional[ActivationParameters] = CONV_DECODER_DEFAULTS["act"] output_act: Optional[ActivationParameters] = CONV_DECODER_DEFAULTS["output_act"] norm: Optional[ConvNormalizationParameters] = CONV_DECODER_DEFAULTS["norm"] dropout: Optional[NonNegativeFloat] = CONV_DECODER_DEFAULTS["dropout"] bias: bool = CONV_DECODER_DEFAULTS["bias"] adn_ordering: str = CONV_DECODER_DEFAULTS["adn_ordering"] @model_validator(mode="after") def check_unpooling(self): checked_indices = self._check_pool_indices(self.unpooling_indices) self.__dict__["unpooling"] = self._check_unpool_layers( self.unpooling, unpooling_indices=checked_indices ) return self def _check_args_dim(self, dim: int) -> None: super()._check_args_dim(dim) self.__dict__["output_padding"] = ensure_list_of_tuples( self.output_padding, dim, len(self.channels), "output_padding" ) @classmethod def _check_unpool_layers( cls, unpooling: UnpoolingParameters, unpooling_indices: Sequence[int] ) -> List[SingleLayerUnpoolingParameters]: """ Checks argument unpooling. """ if isinstance(unpooling, list): if len(unpooling) != len(unpooling_indices): raise ValueError( "If you pass a list for unpooling, the size of that list must match " f"the size of unpooling_indices. Got: unpooling={unpooling} and " f"unpooling_indices={unpooling_indices}" ) elif isinstance(unpooling, tuple): unpooling = [unpooling] * len(unpooling_indices) return unpooling
[docs] class ConvDecoderConfig(NetworkConfig, ConvDecoderOptions, _SpatialDimsConfig): """ Config class for :py:class:`clinicadl.networks.nn.ConvDecoder`. """ spatial_dims: PositiveInt in_channels: PositiveInt @model_validator(mode="after") def check_dim(self): self._check_args_dim(self.spatial_dims) return self @classmethod def _get_class(cls) -> type[nn.Module]: """Returns the network associated to this config class.""" return ConvDecoder