from typing import Any, Dict, Optional, Sequence
import numpy as np
import torch.nn as nn
from monai.networks.layers.simplelayers import Reshape
from pydantic import PositiveInt, field_validator, model_validator
from clinicadl.utils.factories import get_defaults_from
from .conv_decoder import ConvDecoder, ConvDecoderOptions
from .mlp import MLP, MLPOptions
from .utils.config import NetworkConfig
class Generator(nn.Sequential):
    """
    A generator with first fully-connected layers and then convolutional layers.

    This network is a simple aggregation of a :py:class:`~clinicadl.networks.nn.MLP`
    and a :py:class:`~clinicadl.networks.nn.ConvDecoder`.

    Works with 1D, 2D or 3D images (with additional batch and channel dimensions).

    Parameters
    ----------
    latent_size : int
        Size of the latent vector.
    start_shape : Sequence[int]
        Initial shape of the image, i.e. the shape at the
        beginning of the convolutional part (without batch dimension, but including the channel dimension).\n
        Thus, ``start_shape`` also determines the dimension of the output of the generator (the exact
        shape depends on the convolutional part and can be accessed via the attribute
        ``output_shape``).
    conv_args : Dict[str, Any]
        The arguments for the convolutional part. The arguments are those accepted by
        :py:class:`~clinicadl.networks.nn.ConvDecoder`, except ``spatial_dims`` and ``in_channels``
        that are specified here via ``start_shape``. So, the only **mandatory argument is** ``channels``.
    mlp_args : Optional[Dict[str, Any]], default=None
        The arguments for the MLP part. The arguments are those accepted by
        :py:class:`~clinicadl.networks.nn.MLP`, except ``num_inputs`` that is equal here to
        ``latent_size``, and ``num_outputs`` that is inferred here from ``start_shape``.
        So, the only **mandatory argument is** ``hidden_dims``.\n
        If ``None``, the MLP part will be reduced to a single linear layer.

    Attributes
    ----------
    output_shape : Tuple[int, ...]
        The shape of the output image (without batch dimension, but including the
        channel dimension), computed from ``start_shape`` and the convolutional part.
    config : GeneratorConfig
        The validated configuration built from the constructor arguments.

    Raises
    ------
    ValueError
        If ``conv_args`` doesn't contain the key ``channels``.
    ValueError
        If ``mlp_args`` is not ``None`` and doesn't contain the key ``hidden_dims``.
    ValueError
        If ``start_shape`` is not of length 2, 3 or 4 (1D, 2D or 3D image, plus the
        channel dimension).

    See Also
    --------
    :py:class:`torch.nn.Module`
        To see all the methods of this neural network.
    :py:class:`~clinicadl.networks.nn.ConvDecoder`
    :py:class:`~clinicadl.networks.nn.MLP`

    Examples
    --------
    .. code-block:: python

        >>> Generator(
                latent_size=8,
                start_shape=(8, 2, 2),
                conv_args={"channels": [4, 2], "norm": None, "act": None},
                mlp_args={"hidden_dims": [16], "act": "elu", "norm": None},
            )
        Generator(
          (mlp): MLP(
            (flatten): Flatten(start_dim=1, end_dim=-1)
            (hidden0): Sequential(
              (linear): Linear(in_features=8, out_features=16, bias=True)
              (adn): ADN(
                (A): ELU(alpha=1.0)
              )
            )
            (output): Linear(in_features=16, out_features=32, bias=True)
          )
          (reshape): Reshape()
          (convolutions): ConvDecoder(
            (layer0): Convolution(
              (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
            )
            (layer1): Convolution(
              (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1))
            )
          )
        )

        >>> Generator(
                latent_size=8,
                start_shape=(8, 2, 2),
                conv_args={"channels": [4, 2], "norm": None, "act": None, "output_act": "relu"},
            )
        Generator(
          (mlp): MLP(
            (flatten): Flatten(start_dim=1, end_dim=-1)
            (output): Linear(in_features=8, out_features=32, bias=True)
          )
          (reshape): Reshape()
          (convolutions): ConvDecoder(
            (layer0): Convolution(
              (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
            )
            (layer1): Convolution(
              (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1))
            )
            (output_act): ReLU()
          )
        )
    """

    def __init__(
        self,
        latent_size: int,
        start_shape: Sequence[int],
        conv_args: Dict[str, Any],
        mlp_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        # Validate all arguments up front (a None ``mlp_args`` is turned into an
        # MLP without hidden layers by the config's validator).
        self.config = GeneratorConfig(
            latent_size=latent_size,
            start_shape=start_shape,
            conv_args=conv_args,
            mlp_args=mlp_args,
        )

        # The MLP maps the latent vector to as many features as there are elements
        # in ``start_shape``, so that its output can be reshaped into an image.
        flatten_shape = int(np.prod(self.config.start_shape))
        self.mlp = MLP(
            num_inputs=self.config.latent_size,
            num_outputs=flatten_shape,
            **self.config.mlp_args.to_raw_dict(),
        )
        self.reshape = Reshape(*self.config.start_shape)

        # The first element of ``start_shape`` is the channel dimension; the rest
        # is the spatial size fed to the convolutional part.
        inter_channels, *inter_size = self.config.start_shape
        self.convolutions = ConvDecoder(
            in_channels=inter_channels,
            spatial_dims=len(inter_size),
            _input_size=inter_size,
            **self.config.conv_args.to_raw_dict(),
        )

        # With no convolutional layer, the channel count is left unchanged.
        n_channels = (
            self.config.conv_args.channels[-1]
            if len(self.config.conv_args.channels) > 0
            else self.config.start_shape[0]
        )
        self.output_shape = (n_channels, *self.convolutions._final_size)
# Defaults of ``Generator``'s constructor, as gathered by the project helper
# ``get_defaults_from``. It is indexed by parameter name (``GENERATOR_DEFAULTS["mlp_args"]``
# is the default of ``GeneratorConfig.mlp_args``), so the config class and the network
# share a single source of default values.
# NOTE(review): presumably maps each parameter name to its default value — confirm in
# ``clinicadl.utils.factories.get_defaults_from``.
GENERATOR_DEFAULTS = get_defaults_from(Generator)
class GeneratorConfig(NetworkConfig):
    """
    Config class for :py:class:`clinicadl.networks.nn.Generator`.

    Validates the arguments of the generator: ``start_shape`` must describe a 1D, 2D
    or 3D image (including the channel dimension), ``mlp_args`` may be ``None`` (turned
    into an MLP without hidden layers), and ``conv_args`` must be consistent with the
    number of spatial dimensions derived from ``start_shape``.
    """

    latent_size: PositiveInt
    start_shape: Sequence[PositiveInt]
    conv_args: ConvDecoderOptions
    # NOTE(review): this default comes from ``Generator``'s signature (``None`` there).
    # Pydantic does not run validators on defaults unless ``validate_default`` is enabled,
    # so confirm ``NetworkConfig`` enables it; otherwise this field can remain ``None``
    # when ``mlp_args`` is omitted.
    mlp_args: MLPOptions = GENERATOR_DEFAULTS["mlp_args"]

    @field_validator("start_shape", mode="after")
    @classmethod
    def _start_shape_validator(cls, v):
        """Checks that 'start_shape' corresponds to 1D, 2D or 3D images."""
        # Raise explicitly instead of using ``assert``: asserts are stripped under
        # ``python -O``, which would silently disable this check. Pydantic converts
        # the ValueError into a ValidationError.
        if not 2 <= len(v) <= 4:
            raise ValueError(
                "'start_shape' must be of length 2 (1D), 3 (2D image) or 4 (3D image). "
                f"Don't forget the channel dimension. Got: {v}."
            )
        return v

    @field_validator("mlp_args", mode="before")
    @classmethod
    def _handle_none_mlp_args(cls, v):
        """
        To accept None value for 'mlp_args'.

        ``None`` is replaced by an MLP config with no hidden layer, i.e. a single
        linear layer from the latent vector to the flattened ``start_shape``.
        """
        if v is None:
            return MLPOptions(hidden_dims=[])
        return v

    @model_validator(mode="after")
    def _check_dim(self):
        """
        Checks ``conv_args`` against the number of spatial dimensions, which is the
        length of ``start_shape`` minus the channel dimension.
        """
        _, *inter_size = self.start_shape
        spatial_dims = len(inter_size)
        self.conv_args._check_args_dim(spatial_dims)
        return self

    @classmethod
    def _get_class(cls) -> type[nn.Module]:
        """Returns the network associated to this config class."""
        return Generator