Source code for clinicadl.networks.nn.autoencoder

from copy import deepcopy
from typing import Any, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import torch.nn as nn
from pydantic import PositiveInt, field_validator, model_validator

from clinicadl.utils.factories import get_defaults_from

from .cnn import CNN
from .conv_decoder import ConvDecoderOptions
from .conv_encoder import ConvEncoder, ConvEncoderOptions
from .generator import Generator
from .layers.utils import (
    ActivationParameters,
    PoolingLayer,
    SingleLayerPoolingParameters,
    SingleLayerUnpoolingParameters,
    UnpoolingLayer,
    UnpoolingMode,
)
from .mlp import MLPOptions
from .utils import (
    calculate_conv_out_shape,
    calculate_convtranspose_out_shape,
    calculate_pool_out_shape,
)
from .utils.config import NetworkConfig, _InShapeConfig


[docs] class AutoEncoder(nn.Sequential): """ An AutoEncoder with convolutional and fully connected layers. The user must pass the arguments to build an encoder, from its convolutional and fully connected parts, and the decoder will be automatically built by taking the symmetrical network. More precisely, to build the decoder, the order of the encoding layers is reverted, convolutions are replaced by transposed convolutions, and pooling layers are replaced by either upsampling or transposed convolution layers. An ``AutoEncoder`` is an aggregation of a :py:class:`~clinicadl.networks.nn.CNN` and a :py:class:`~clinicadl.networks.nn.Generator`. Works with 2D or 3D images (with additional batch and channel dimensions). .. note:: Please note that the order of Activation, Dropout and Normalization, defined with the argument ``adn_ordering`` in ``conv_args``, is the same for the encoder and the decoder. Parameters ---------- in_shape : Sequence[int] Dimensions of the input tensor (without batch dimension). latent_size : int Size of the latent vector. conv_args : dict[str, Any] The arguments for the convolutional part. The arguments are those accepted by :py:class:`~clinicadl.networks.nn.ConvEncoder`, except ``spatial_dims`` and ``in_channels`` that are specified here via ``in_shape``. So, the only **mandatory argument is** ``channels``. mlp_args : Optional[dict[str, Any]], default=None The arguments for the MLP part. The arguments are those accepted by :py:class:`~clinicadl.networks.nn.MLP`, except ``num_inputs`` that is inferred from the output of the convolutional part, and ``num_outputs`` that is equal to ``latent_size`` here. So, the only **mandatory argument is** ``hidden_dims``.\n If ``None``, the MLP part will be reduced to a single linear layer. out_channels : Optional[int], default=None Number of output channels. If ``None``, the output will have the same number of channels as the input. output_act : Optional[ActivationParameters], default=None A potential activation layer applied to the output of the network, and optionally its arguments. Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is a dictionary. If ``None``, no activation will be used.\n ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``, ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``, ``"softmax"``, ``"tanh"``}. Please refer to :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>` to know the arguments for each of them. unpooling_mode : Union[str, UnpoolingMode], default=UnpoolingMode.NEAREST Type of unpooling. Can be any value in {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"`` or ``"convtranspose"``}: - ``nearest``: unpooling is performed by upsampling with the `nearest` algorithm (see :py:class:`torch.nn.Upsample`); - ``linear``: unpooling is performed by upsampling with the `linear` algorithm. Only works with 1D images (excluding the channel dimension); - ``bilinear``: unpooling is performed by upsampling with the `bilinear` algorithm. Only works with 2D images; - ``bicubic``: unpooling is performed by upsampling with the `bicubic` algorithm. Only works with 2D images; - ``trilinear``: unpooling is performed by upsampling with the `trilinear` algorithm. Only works with 3D images; - ``convtranspose``: unpooling is performed with a transposed convolution (see :py:class:`torch.nn.ConvTranspose3d`), whose parameters (kernel size, stride, etc.) are computed to reverse the pooling operation. See Also -------- :py:class:`torch.nn.Module` To see all the methods of this neural network. :py:class:`~clinicadl.networks.nn.CNN` :py:class:`~clinicadl.networks.nn.Generator` Examples -------- .. code-block:: python >>> AutoEncoder( in_shape=(1, 16, 16), latent_size=8, conv_args={ "channels": [2, 4], "pooling_indices": [0], "pooling": ("avg", {"kernel_size": 2}), }, mlp_args={"hidden_dims": [32], "output_act": "relu"}, out_channels=2, output_act="sigmoid", unpooling_mode="bilinear", ) AutoEncoder( (encoder): CNN( (convolutions): ConvEncoder( (layer0): Convolution( (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1)) (adn): ADN( (N): InstanceNorm2d(2, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False) (A): PReLU(num_parameters=1) ) ) (pool0): AvgPool2d(kernel_size=2, stride=2, padding=0) (layer1): Convolution( (conv): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1)) ) ) (mlp): MLP( (flatten): Flatten(start_dim=1, end_dim=-1) (hidden0): Sequential( (linear): Linear(in_features=100, out_features=32, bias=True) (adn): ADN( (N): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (A): PReLU(num_parameters=1) ) ) (output): Sequential( (linear): Linear(in_features=32, out_features=8, bias=True) (output_act): ReLU() ) ) ) (decoder): Generator( (mlp): MLP( (flatten): Flatten(start_dim=1, end_dim=-1) (hidden0): Sequential( (linear): Linear(in_features=8, out_features=32, bias=True) (adn): ADN( (N): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (A): PReLU(num_parameters=1) ) ) (output): Sequential( (linear): Linear(in_features=32, out_features=100, bias=True) (output_act): ReLU() ) ) (reshape): Reshape() (convolutions): ConvDecoder( (layer0): Convolution( (conv): ConvTranspose2d(4, 4, kernel_size=(3, 3), stride=(1, 1)) (adn): ADN( (N): InstanceNorm2d(4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False) (A): PReLU(num_parameters=1) ) ) (unpool0): Upsample(size=(14, 14), mode=<UpsamplingMode.BILINEAR: 'bilinear'>) (layer1): Convolution( (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1)) ) (output_act): Sigmoid() ) ) ) """ def __init__( self, in_shape: Sequence[int], latent_size: int, conv_args: dict[str, Any], mlp_args: Optional[dict[str, Any]] = None, out_channels: Optional[int] = None, output_act: Optional[ActivationParameters] = None, unpooling_mode: Union[str, UnpoolingMode] = UnpoolingMode.NEAREST, ) -> None: super().__init__() self.config = AutoEncoderConfig( in_shape=in_shape, latent_size=latent_size, conv_args=conv_args, mlp_args=mlp_args, out_channels=out_channels, output_act=output_act, unpooling_mode=unpooling_mode, ) self.encoder = CNN( in_shape=self.config.in_shape, num_outputs=self.config.latent_size, conv_args=self.config.conv_args.to_raw_dict(), mlp_args=self.config.mlp_args.to_raw_dict(), ) inter_channels = ( self.config.conv_args.channels[-1] if len(self.config.conv_args.channels) > 0 else self.config.in_shape[0] ) inter_shape = (inter_channels, *self.encoder.convolutions._final_size) self.decoder = Generator( latent_size=self.config.latent_size, start_shape=inter_shape, conv_args=self._invert_conv_args( self.config.conv_args, self.encoder.convolutions ).to_raw_dict(), mlp_args=self._invert_mlp_args(self.config.mlp_args).to_raw_dict(), ) @classmethod def _invert_mlp_args( cls, args: MLPOptions, ) -> MLPOptions: """ Inverts arguments passed for the MLP part of the encoder, to get the MLP part of the decoder. """ args = deepcopy(args) args.hidden_dims = cls._invert_list_arg(args.hidden_dims) return args def _invert_conv_args( self, encoder_args: ConvEncoderOptions, conv: ConvEncoder ) -> ConvDecoderOptions: """ Inverts arguments passed for the convolutional part of the encoder, to get the convolutional part of the decoder. """ channels = ( self._invert_list_arg(encoder_args.channels[:-1]) + [self.config.out_channels] if len(encoder_args.channels) > 0 else [] ) kernel_size = self._invert_list_arg(encoder_args.kernel_size) stride = self._invert_list_arg(encoder_args.stride) dilation = self._invert_list_arg(encoder_args.dilation) padding, output_padding = self._get_paddings_list(conv) unpooling_indices = list( ( len(encoder_args.channels) - np.array(encoder_args.pooling_indices) - 2 ).astype(int) ) unpooling = [] sizes_before_pooling = [ size for size, (layer_name, _) in zip(conv._size_details, conv.named_children()) if "pool" in layer_name ] for size, pooling in zip( sizes_before_pooling[::-1], encoder_args.pooling[::-1] ): unpooling.append(self._invert_pooling_layer(size, pooling)) decoder_args = ConvDecoderOptions( channels=channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, dilation=dilation, unpooling_indices=unpooling_indices, unpooling=unpooling, act=encoder_args.act, norm=encoder_args.norm, output_act=self.config.output_act, dropout=encoder_args.dropout, bias=encoder_args.bias, adn_ordering=encoder_args.adn_ordering, ) return decoder_args @staticmethod def _invert_list_arg(arg: Union[Any, list[Any]]) -> Union[Any, list[Any]]: """ Reverses lists. """ return list(arg[::-1]) if isinstance(arg, Sequence) else arg def _invert_pooling_layer( self, size_before_pool: Sequence[int], pooling: SingleLayerPoolingParameters, ) -> SingleLayerUnpoolingParameters: """ Gets the unpooling layer. """ if self.config.unpooling_mode == UnpoolingMode.CONV_TRANS: return ( UnpoolingLayer.CONV_TRANS, self._invert_pooling_with_convtranspose(size_before_pool, pooling), ) else: return ( UnpoolingLayer.UPSAMPLE, {"size": size_before_pool, "mode": self.config.unpooling_mode}, ) @classmethod def _invert_pooling_with_convtranspose( cls, size_before_pool: Sequence[int], pooling: SingleLayerPoolingParameters, ) -> dict[str, Any]: """ Computes the arguments of the transposed convolution, based on the pooling layer. """ pooling_mode, pooling_args = pooling if ( pooling_mode == PoolingLayer.ADAPT_AVG or pooling_mode == PoolingLayer.ADAPT_MAX ): input_size_np = np.array(size_before_pool) output_size_np = np.array(pooling_args["output_size"]) stride_np = input_size_np // output_size_np # adaptive pooling formulas kernel_size_np = ( input_size_np - (output_size_np - 1) * stride_np ) # adaptive pooling formulas args = { "kernel_size": tuple(int(k) for k in kernel_size_np), "stride": tuple(int(s) for s in stride_np), } padding, output_padding = cls._find_convtranspose_paddings( pooling_mode, size_before_pool, output_size=pooling_args["output_size"], **args, ) elif pooling_mode == PoolingLayer.MAX or pooling_mode == PoolingLayer.AVG: if "stride" not in pooling_args: pooling_args["stride"] = pooling_args["kernel_size"] args = { arg: value for arg, value in pooling_args.items() if arg in ["kernel_size", "stride", "padding", "dilation"] } padding, output_padding = cls._find_convtranspose_paddings( pooling_mode, size_before_pool, **pooling_args, ) args["padding"] = padding # pylint: disable=possibly-used-before-assignment args["output_padding"] = output_padding # pylint: disable=possibly-used-before-assignment return args @classmethod def _get_paddings_list(cls, conv: ConvEncoder) -> list[tuple[int, ...]]: """ Finds output padding list. """ padding = [] output_padding = [] size_before_convs = [ size for size, (layer_name, _) in zip(conv._size_details, conv.named_children()) if "layer" in layer_name ] for size, k, s, p, d in zip( size_before_convs, conv.config.kernel_size, conv.config.stride, conv.config.padding, conv.config.dilation, ): p, out_p = cls._find_convtranspose_paddings( "conv", size, kernel_size=k, stride=s, padding=p, dilation=d ) padding.append(p) output_padding.append(out_p) return cls._invert_list_arg(padding), cls._invert_list_arg(output_padding) @classmethod def _find_convtranspose_paddings( cls, layer_type: Union[Literal["conv"], PoolingLayer], in_shape: Union[Sequence[int], int], padding: Union[Sequence[int], int] = 0, **kwargs, ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: """ Finds padding and output padding necessary to recover the right image size after a transposed convolution. """ if layer_type == "conv": layer_out_shape = calculate_conv_out_shape(in_shape, **kwargs) elif layer_type in list(PoolingLayer): layer_out_shape = calculate_pool_out_shape(layer_type, in_shape, **kwargs) convt_out_shape = calculate_convtranspose_out_shape(layer_out_shape, **kwargs) # pylint: disable=possibly-used-before-assignment output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape) if ( output_padding < 0 ).any(): # can happen with ceil_mode=True for maxpool. Then, add some padding padding = np.atleast_1d(padding) * np.ones_like( output_padding ) # to have the same shape as output_padding padding[output_padding < 0] += np.maximum(np.abs(output_padding) // 2, 1)[ output_padding < 0 ] # //2 because 2*padding pixels are removed convt_out_shape = calculate_convtranspose_out_shape( layer_out_shape, padding=padding, **kwargs ) output_padding = np.atleast_1d(in_shape) - np.atleast_1d(convt_out_shape) padding = tuple(int(s) for s in padding) return padding, tuple(int(s) for s in output_padding)
AUTOENCODER_DEFAULTS = get_defaults_from(AutoEncoder)
[docs] class AutoEncoderConfig(NetworkConfig, _InShapeConfig): """ Config class for :py:class:`clinicadl.networks.nn.AutoEncoder`. """ in_shape: Sequence[PositiveInt] latent_size: PositiveInt conv_args: ConvEncoderOptions mlp_args: MLPOptions = AUTOENCODER_DEFAULTS["mlp_args"] out_channels: Optional[PositiveInt] = AUTOENCODER_DEFAULTS["out_channels"] output_act: Optional[ActivationParameters] = AUTOENCODER_DEFAULTS["output_act"] unpooling_mode: UnpoolingMode = AUTOENCODER_DEFAULTS["unpooling_mode"] @field_validator("mlp_args", mode="before") @classmethod def _handle_none_mlp_args(cls, v): """ To accept None value for 'mlp_args'. """ if v is None: return MLPOptions(hidden_dims=[]) return v @model_validator(mode="after") def _check_dim(self): _, *input_size = self.in_shape spatial_dims = len(input_size) self.conv_args._check_args_dim(spatial_dims) self._check_unpooling_mode(self.unpooling_mode, spatial_dims) self.__dict__["out_channels"] = ( self.out_channels if self.out_channels else self.in_shape[0] ) return self @classmethod def _get_class(cls) -> type[nn.Module]: """Returns the network associated to this config class.""" return AutoEncoder @staticmethod def _check_unpooling_mode(unpooling_mode: UnpoolingMode, dim: int) -> None: """ Checks consistency between data shape and unpooling mode. """ if unpooling_mode == UnpoolingMode.LINEAR and dim != 1: raise ValueError( f"unpooling mode `linear` only works with 1D data (spatial dimensions). " f"Got {dim}D data." ) elif unpooling_mode == UnpoolingMode.BILINEAR and dim != 2: raise ValueError( f"unpooling mode `bilinear` only works with 2D data (spatial dimensions). " f"Got {dim}D data." ) elif unpooling_mode == UnpoolingMode.BICUBIC and dim != 2: raise ValueError( f"unpooling mode `bicubic` only works with 2D data (spatial dimensions). " f"Got {dim}D data." ) elif unpooling_mode == UnpoolingMode.TRILINEAR and dim != 3: raise ValueError( f"unpooling mode `trilinear` only works with 3D data (spatial dimensions). " f"Got {dim}D data." )