clinicadl.networks.nn.Generator
- class clinicadl.networks.nn.Generator(latent_size: int, start_shape: Sequence[int], conv_args: Dict[str, Any], mlp_args: Dict[str, Any] | None = None) → None
A generator with first fully-connected layers and then convolutional layers.
This network is a simple aggregation of an MLP and a ConvDecoder.
It works with 2D or 3D images (with additional batch and channel dimensions).
- Parameters:
latent_size (int) – Size of the latent vector.
start_shape (Sequence[int]) –
Initial shape of the image, i.e. the shape at the beginning of the convolutional part (without the batch dimension, but including the channel dimension).
Thus, start_shape also determines the dimension of the output of the generator (the exact shape depends on the convolutional part and can be accessed via the attribute output_shape).
conv_args (Dict[str, Any]) –
The arguments for the convolutional part. The arguments are those accepted by ConvDecoder, except spatial_dims and in_channels, which are specified here via start_shape. So, the only mandatory argument is channels.
mlp_args (Optional[Dict[str, Any]], default=None) –
The arguments for the MLP part. The arguments are those accepted by MLP, except num_inputs, which is equal here to latent_size, and num_outputs, which is inferred here from start_shape. So, the only mandatory argument is hidden_dims.
If None, the MLP part will be reduced to a single linear layer.
- Attributes:
output_shape (int) – The shape of the output image, computed from start_shape.
- Raises:
ValueError – If conv_args doesn’t contain the key channels.
ValueError – If mlp_args is not None and doesn’t contain the key hidden_dims.
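The two error conditions above can be sketched as a small standalone check. This is only an illustration of the documented behaviour, not the library's code: the helper name `check_generator_args` and the error messages are hypothetical.

```python
from typing import Any, Dict, Optional


def check_generator_args(
    conv_args: Dict[str, Any],
    mlp_args: Optional[Dict[str, Any]] = None,
) -> None:
    """Hypothetical helper mirroring the two documented ValueError conditions."""
    # 'channels' is the only mandatory key for the convolutional part.
    if "channels" not in conv_args:
        raise ValueError("conv_args must contain the key 'channels'.")
    # mlp_args may be None (single linear layer); otherwise 'hidden_dims' is mandatory.
    if mlp_args is not None and "hidden_dims" not in mlp_args:
        raise ValueError("mlp_args must contain the key 'hidden_dims' when it is not None.")


# Passes silently: both mandatory keys are present.
check_generator_args({"channels": [4, 2]}, {"hidden_dims": [16]})
# Also passes: mlp_args=None is allowed.
check_generator_args({"channels": [4, 2]})
```

Running the same check with `conv_args={"norm": None}` or with `mlp_args={"act": "elu"}` would raise a `ValueError` in this sketch.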
Examples
>>> Generator(
...     latent_size=8,
...     start_shape=(8, 2, 2),
...     conv_args={"channels": [4, 2], "norm": None, "act": None},
...     mlp_args={"hidden_dims": [16], "act": "elu", "norm": None},
... )
Generator(
  (mlp): MLP(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (hidden0): Sequential(
      (linear): Linear(in_features=8, out_features=16, bias=True)
      (adn): ADN(
        (A): ELU(alpha=1.0)
      )
    )
    (output): Linear(in_features=16, out_features=32, bias=True)
  )
  (reshape): Reshape()
  (convolutions): ConvDecoder(
    (layer0): Convolution(
      (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
    )
    (layer1): Convolution(
      (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1))
    )
  )
)
>>> Generator(
...     latent_size=8,
...     start_shape=(8, 2, 2),
...     conv_args={"channels": [4, 2], "norm": None, "act": None, "output_act": "relu"},
... )
Generator(
  (mlp): MLP(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (output): Linear(in_features=8, out_features=32, bias=True)
  )
  (reshape): Reshape()
  (convolutions): ConvDecoder(
    (layer0): Convolution(
      (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
    )
    (layer1): Convolution(
      (conv): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(1, 1))
    )
    (output_act): ReLU()
  )
)
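The shape arithmetic behind these examples can be checked by hand. The sketch below is plain Python and does not call clinicadl. It assumes that the MLP output size is the product of the entries of start_shape (consistent with out_features=32 for start_shape=(8, 2, 2)), and that every convolutional layer is a transposed convolution with kernel size 3, stride 1, no padding and no output padding, as printed above. The helper names are hypothetical, and other conv_args settings would give different sizes.

```python
from math import prod
from typing import Sequence, Tuple


def mlp_num_outputs(start_shape: Sequence[int]) -> int:
    """Features the MLP must emit so they can be reshaped to start_shape."""
    return prod(start_shape)


def conv_transpose_size(size: int, kernel_size: int = 3, stride: int = 1, padding: int = 0) -> int:
    """Output size of one transposed convolution along one spatial axis.

    Standard formula with dilation 1 and no output padding.
    """
    return (size - 1) * stride - 2 * padding + kernel_size


def sketch_output_shape(start_shape: Sequence[int], channels: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical re-computation of the output shape for the layers printed above."""
    spatial = list(start_shape[1:])  # drop the channel dimension
    for _ in channels:  # one transposed convolution per entry in channels
        spatial = [conv_transpose_size(s) for s in spatial]
    return (channels[-1], *spatial)


print(mlp_num_outputs((8, 2, 2)))              # 32
print(sketch_output_shape((8, 2, 2), [4, 2]))  # (2, 6, 6)
```

Under these assumptions each layer grows every spatial axis by 2 (2 → 4 → 6), while the channel count follows channels (8 → 4 → 2). For a real model, the output_shape attribute remains the reference value.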