clinicadl.networks.nn.Generator

class clinicadl.networks.nn.Generator(latent_size: int, start_shape: Sequence[int], conv_args: Dict[str, Any], mlp_args: Dict[str, Any] | None = None) → None

A generator made of fully-connected layers followed by convolutional layers.

This network is a simple aggregation of an MLP and a ConvDecoder.

Works with 2D or 3D images (with additional batch and channel dimensions).

Parameters:
  • latent_size (int) – Size of the latent vector.

  • start_shape (Sequence[int]) –

    Initial shape of the image, i.e. the shape at the beginning of the convolutional part (without batch dimension, but including the channel dimension).

    Thus, start_shape also determines the dimension of the output of the generator (the exact shape depends on the convolutional part and can be accessed via the attribute output_shape).

  • conv_args (Dict[str, Any]) – The arguments for the convolutional part. These are the arguments accepted by ConvDecoder, except spatial_dims and in_channels, which are inferred here from start_shape. The only mandatory argument is therefore channels.

  • mlp_args (Optional[Dict[str, Any]], default=None) –

    The arguments for the MLP part. These are the arguments accepted by MLP, except num_inputs, which is set here to latent_size, and num_outputs, which is inferred here from start_shape. The only mandatory argument is therefore hidden_dims.

    If None, the MLP part will be reduced to a single linear layer.
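As a rough illustration of how start_shape drives both parts, the sketch below derives the quantities mentioned in the parameter descriptions. The helper split_start_shape is hypothetical and not part of clinicadl; it assumes that the MLP output is simply the flattened start_shape, which is consistent with out_features=32 for start_shape=(8, 2, 2) in the Examples.

```python
from math import prod
from typing import Dict, Sequence


def split_start_shape(start_shape: Sequence[int]) -> Dict[str, int]:
    """Hypothetical helper (not a clinicadl function).

    Splits a start_shape given as (channels, *spatial) into the values
    that the convolutional and MLP parts would need.
    """
    in_channels, *spatial = start_shape
    return {
        "in_channels": in_channels,        # first entry is the channel dimension
        "spatial_dims": len(spatial),      # 2 for 2D images, 3 for 3D images
        "num_outputs": prod(start_shape),  # size of the flattened MLP output
    }


info = split_start_shape((8, 2, 2))
print(info)  # {'in_channels': 8, 'spatial_dims': 2, 'num_outputs': 32}
```

With a 3D start_shape such as (1, 4, 4, 4), the same helper would report spatial_dims equal to 3.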

Attributes:

output_shape (int) – The shape of the output image, computed from start_shape.
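The following sketch shows one way the output shape could be worked out by hand. It is not the library's implementation: it assumes every layer is a transposed convolution with kernel size 3, stride 1 and no padding (the configuration printed in the Examples), and it applies the standard output-size formula for transposed convolutions. The function names are hypothetical.

```python
from typing import Sequence, Tuple


def conv_transpose_out(
    size: int,
    kernel_size: int = 3,
    stride: int = 1,
    padding: int = 0,
    output_padding: int = 0,
    dilation: int = 1,
) -> int:
    """Output size of one transposed convolution along one spatial axis."""
    return (
        (size - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )


def sketch_output_shape(
    start_shape: Sequence[int], channels: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical helper: shape (without batch dimension) after all layers."""
    spatial = list(start_shape[1:])
    for _ in channels:
        # Each layer grows every spatial axis according to the formula above.
        spatial = [conv_transpose_out(s) for s in spatial]
    # The last entry of channels gives the number of output channels.
    return (channels[-1], *spatial)


shape = sketch_output_shape((8, 2, 2), [4, 2])
print(shape)  # (2, 6, 6)
```

Under these assumptions each layer adds 2 to every spatial axis, so a 2×2 start grows to 4×4 and then 6×6. Other kernel sizes, strides or paddings passed via conv_args would change the result, which is why the attribute output_shape is the reliable source.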

Raises:
  • ValueError – If conv_args doesn’t contain the key channels.

  • ValueError – If mlp_args is not None and doesn’t contain the key hidden_dims.
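The two conditions above can be sketched as a small standalone check. The function check_generator_args and its error messages are hypothetical and only mirror the documented behaviour; the actual messages raised by Generator may differ.

```python
from typing import Any, Dict, Optional


def check_generator_args(
    conv_args: Dict[str, Any],
    mlp_args: Optional[Dict[str, Any]] = None,
) -> None:
    """Hypothetical check mirroring the documented ValueError conditions."""
    if "channels" not in conv_args:
        raise ValueError("conv_args must contain the key 'channels'.")
    # mlp_args is optional; the key is only required when a dict is passed.
    if mlp_args is not None and "hidden_dims" not in mlp_args:
        raise ValueError("mlp_args must contain the key 'hidden_dims'.")


# Valid: mlp_args omitted, conv_args has 'channels'.
check_generator_args({"channels": [4, 2]})

# Invalid: 'channels' is missing from conv_args.
try:
    check_generator_args({"norm": None})
except ValueError as err:
    print(err)  # conv_args must contain the key 'channels'.
```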

See also

torch.nn.Module

To see all the methods of this neural network.

ConvDecoder, MLP

Examples

>>> Generator(
        latent_size=8,
        start_shape=(8, 2, 2),
        conv_args={"channels": [4, 2], "norm": None, "act": None},
        mlp_args={"hidden_dims": [16], "act": "elu", "norm": None},
    )
Generator(
    (mlp): MLP(
        (flatten): Flatten(start_dim=1, end_dim=-1)
        (hidden0): Sequential(
            (linear): Linear(in_features=8, out_features=16, bias=True)
            (adn): ADN(
                (A): ELU(alpha=1.0)
            )
        )
        (output): Linear(in_features=16, out_features=32, bias=True)
    )
    (reshape): Reshape()
    (convolutions): ConvDecoder(
        (layer0): Convolution(
            (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
        )
        (layer1): Convolution(
            (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1))
        )
    )
)

>>> Generator(
        latent_size=8,
        start_shape=(8, 2, 2),
        conv_args={"channels": [4, 2], "norm": None, "act": None, "output_act": "relu"},
    )
Generator(
    (mlp): MLP(
        (flatten): Flatten(start_dim=1, end_dim=-1)
        (output): Linear(in_features=8, out_features=32, bias=True)
    )
    (reshape): Reshape()
    (convolutions): ConvDecoder(
        (layer0): Convolution(
            (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
        )
        (layer1): Convolution(
            (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1))
        )
        (output_act): ReLU()
    )
)