"""Source code for ``clinicadl.networks.nn.cnn``: the :class:`CNN` network and its :class:`CNNConfig`."""

from typing import Any, Optional, Sequence

import numpy as np
import torch.nn as nn
from pydantic import PositiveInt, field_validator, model_validator

from clinicadl.utils.factories import get_defaults_from

from .conv_encoder import ConvEncoder, ConvEncoderOptions
from .mlp import MLP, MLPOptions
from .utils.config import NetworkConfig, _InShapeConfig


class CNN(nn.Sequential):
    """
    A regressor/classifier made of convolutional layers followed by fully connected layers.

    The network simply chains a :py:class:`~clinicadl.networks.nn.ConvEncoder`
    (registered as ``convolutions``) and a :py:class:`~clinicadl.networks.nn.MLP`
    (registered as ``mlp``). It handles 2D or 3D images (plus the batch and
    channel dimensions).

    Parameters
    ----------
    in_shape : Sequence[int]
        Shape of the input tensor, batch dimension excluded
        (i.e. ``(channels, *spatial_size)``).
    num_outputs : int
        Number of variables to predict.
    conv_args : dict[str, Any]
        Arguments of the convolutional part, i.e. those accepted by
        :py:class:`~clinicadl.networks.nn.ConvEncoder`, except ``spatial_dims``
        and ``in_channels``, which are deduced from ``in_shape``. The only
        **mandatory argument is** ``channels``.
    mlp_args : Optional[dict[str, Any]], default=None
        Arguments of the MLP part, i.e. those accepted by
        :py:class:`~clinicadl.networks.nn.MLP`, except ``num_inputs`` (inferred
        from the output of the convolutional part) and ``num_outputs`` (set
        here). The only **mandatory argument is** ``hidden_dims``.

        If ``None``, the MLP part is reduced to a single linear layer.

    Raises
    ------
    ValueError
        If ``conv_args`` doesn't contain the key ``channels``.
    ValueError
        If ``mlp_args`` is not ``None`` and doesn't contain the key ``hidden_dims``.

    See Also
    --------
    :py:class:`torch.nn.Module`
        To see all the methods of this neural network.
    :py:class:`~clinicadl.networks.nn.ConvEncoder`
    :py:class:`~clinicadl.networks.nn.MLP`

    Examples
    --------
    .. code-block:: python

        # a classifier
        >>> CNN(
                in_shape=(1, 10, 10),
                num_outputs=2,
                conv_args={"channels": [2, 4], "norm": None, "act": None},
                mlp_args={"hidden_dims": [5], "act": "elu", "norm": None, "output_act": "softmax"},
            )
        CNN(
          (convolutions): ConvEncoder(
            (layer0): Convolution(
              (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
            )
            (layer1): Convolution(
              (conv): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1))
            )
          )
          (mlp): MLP(
            (flatten): Flatten(start_dim=1, end_dim=-1)
            (hidden0): Sequential(
              (linear): Linear(in_features=144, out_features=5, bias=True)
              (adn): ADN(
                (A): ELU(alpha=1.0)
              )
            )
            (output): Sequential(
              (linear): Linear(in_features=5, out_features=2, bias=True)
              (output_act): Softmax(dim=None)
            )
          )
        )

    .. code-block:: python

        # a regressor
        >>> CNN(
                in_shape=(1, 10, 10),
                num_outputs=2,
                conv_args={"channels": [2, 4], "norm": None, "act": None},
            )
        CNN(
          (convolutions): ConvEncoder(
            (layer0): Convolution(
              (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
            )
            (layer1): Convolution(
              (conv): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1))
            )
          )
          (mlp): MLP(
            (flatten): Flatten(start_dim=1, end_dim=-1)
            (output): Linear(in_features=144, out_features=2, bias=True)
          )
        )
    """

    def __init__(
        self,
        in_shape: Sequence[int],
        num_outputs: int,
        conv_args: dict[str, Any],
        mlp_args: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        # All user arguments go through the pydantic config first; the sub-networks
        # are then built exclusively from the validated values.
        self.config = CNNConfig(
            in_shape=in_shape,
            num_outputs=num_outputs,
            conv_args=conv_args,
            mlp_args=mlp_args,
        )
        cfg = self.config

        # First entry of in_shape is the channel count, the rest is the spatial size.
        channels_in, *spatial_size = cfg.in_shape

        self.convolutions = ConvEncoder(
            in_channels=channels_in,
            spatial_dims=len(spatial_size),
            _input_size=tuple(spatial_size),
            **cfg.conv_args.to_raw_dict(),
        )

        # Channel count at the encoder output: the last conv layer's width, or the
        # input channel count when no convolutional layer is configured.
        conv_channels = cfg.conv_args.channels
        out_channels = (
            cfg.in_shape[0] if len(conv_channels) == 0 else conv_channels[-1]
        )

        # The MLP consumes the flattened encoder output: spatial size x channels.
        n_features = int(np.prod(self.convolutions._final_size) * out_channels)
        self.mlp = MLP(
            num_inputs=n_features,
            num_outputs=cfg.num_outputs,
            **cfg.mlp_args.to_raw_dict(),
        )
# Defaults of CNN's constructor parameters, keyed by parameter name (presumably
# extracted from CNN.__init__'s signature; confirm in clinicadl.utils.factories).
# Reused below so that CNNConfig and CNN share the same defaults.
CNN_DEFAULTS = get_defaults_from(CNN)


class CNNConfig(NetworkConfig, _InShapeConfig):
    """
    Configuration class for :py:class:`clinicadl.networks.nn.CNN`.

    It holds and validates the arguments of :py:class:`~clinicadl.networks.nn.CNN`:
    ``mlp_args=None`` is turned into MLP options without hidden layers, and the
    convolution options are checked against the number of spatial dimensions
    deduced from ``in_shape``.
    """

    in_shape: Sequence[PositiveInt]
    num_outputs: PositiveInt
    conv_args: ConvEncoderOptions
    # NOTE(review): pydantic only runs validators on a default value when
    # ``validate_default`` is enabled. Unless NetworkConfig turns it on, an omitted
    # ``mlp_args`` keeps this raw default (CNN's default) instead of passing through
    # ``_handle_none_mlp_args``; confirm against NetworkConfig's model_config.
    mlp_args: MLPOptions = CNN_DEFAULTS["mlp_args"]

    @field_validator("mlp_args", mode="before")
    @classmethod
    def _handle_none_mlp_args(cls, value):
        """
        Allows ``None`` for ``mlp_args``: it is replaced by MLP options with an
        empty ``hidden_dims``. Any other value is passed through unchanged.
        """
        return MLPOptions(hidden_dims=[]) if value is None else value

    @model_validator(mode="after")
    def _check_dim(self):
        """
        Forwards the number of spatial dimensions implied by ``in_shape`` (all
        entries but the first) to ``conv_args._check_args_dim``.
        """
        _, *spatial_size = self.in_shape
        self.conv_args._check_args_dim(len(spatial_size))
        return self

    @classmethod
    def _get_class(cls) -> type[nn.Module]:
        """Returns the network class this config builds, i.e. :py:class:`CNN`."""
        return CNN