clinicadl.networks.nn.VAE
- class clinicadl.networks.nn.VAE(in_shape: Sequence[int], latent_size: int, conv_args: Dict[str, Any], mlp_args: Dict[str, Any] | None = None, out_channels: int | None = None, output_act: ActFunction | tuple[ActFunction, dict[str, Any]] | None = None, unpooling_mode: str | UnpoolingMode = UnpoolingMode.NEAREST) -> None
A Variational AutoEncoder with convolutional and fully connected layers.
The user passes the arguments that build the encoder, i.e. its convolutional and fully connected parts, and the decoder is then built automatically as the symmetrical network.
More precisely, to build the decoder, the order of the encoding layers is reversed, convolutions are replaced by transposed convolutions, and pooling layers are replaced by either upsampling or transposed convolution layers.
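The mirroring rule can be illustrated with a small sketch. It assumes a hypothetical tuple-based description of layers, ("conv", in_channels, out_channels) and ("pool",), which is not ClinicaDL's internal representation; the function mirror_encoder and the labels "upsample" and "unpool_convtranspose" are likewise illustrative only. It reverses the encoder, swaps convolutions for transposed convolutions with swapped channel counts, and replaces pooling by upsampling or a transposed convolution depending on unpooling_mode.

```python
from typing import List, Optional, Sequence

# Hypothetical layer descriptions, used only for this illustration:
#   ("conv", in_channels, out_channels)  -> a convolution
#   ("pool",)                            -> a pooling layer
# This is NOT how ClinicaDL represents layers internally.
Layer = tuple


def mirror_encoder(
    encoder_layers: Sequence[Layer],
    out_channels: Optional[int] = None,
    unpooling_mode: str = "nearest",
) -> List[Layer]:
    """Describe a decoder as the mirror image of an encoder."""
    decoder: List[list] = []
    # Walk the encoder backwards so the decoder undoes it in reverse order.
    for layer in reversed(encoder_layers):
        if layer[0] == "conv":
            _, in_ch, out_ch = layer
            # A convolution in_ch -> out_ch becomes a transposed convolution out_ch -> in_ch.
            decoder.append(["convtranspose", out_ch, in_ch])
        elif layer[0] == "pool":
            # Pooling is undone either by a transposed convolution or by upsampling.
            if unpooling_mode == "convtranspose":
                decoder.append(["unpool_convtranspose"])
            else:
                decoder.append(["upsample", unpooling_mode])
        else:
            raise ValueError(f"unknown layer kind: {layer[0]!r}")
    # Optionally make the final transposed convolution produce out_channels.
    if out_channels is not None:
        for layer in reversed(decoder):
            if layer[0] == "convtranspose":
                layer[2] = out_channels
                break
    return [tuple(layer) for layer in decoder]
```

For instance, mirror_encoder([("conv", 1, 2)], out_channels=2) returns [("convtranspose", 2, 2)], which mirrors the single ConvTranspose2d(2, 2, ...) printed in the decoder of the example in the Examples section.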
A VAE is very similar to an AutoEncoder, except that the last layer of the MLP part is duplicated to infer both the mean and the log variance. In addition, to sample from the latent distribution, the reparameterization trick is performed with reparameterize().
Works with 2D or 3D images (with additional batch and channel dimensions).
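The reparameterization trick rewrites sampling z ~ N(mu, exp(log_var)) as z = mu + sigma * eps with sigma = exp(0.5 * log_var) and eps ~ N(0, 1), so the randomness sits in eps rather than in the learned mean and log variance. The sketch below shows only this formula with the Python standard library on plain lists; it is not ClinicaDL's reparameterize(), which presumably operates on PyTorch tensors (where one would typically write mu + torch.randn_like(std) * std).

```python
import math
import random
from typing import List, Sequence


def reparameterize(mu: Sequence[float], log_var: Sequence[float], rng=random) -> List[float]:
    """Sample z ~ N(mu, exp(log_var)) as z = mu + sigma * eps, with eps ~ N(0, 1)."""
    if len(mu) != len(log_var):
        raise ValueError("mu and log_var must have the same length")
    z = []
    for m, lv in zip(mu, log_var):
        sigma = math.exp(0.5 * lv)  # standard deviation recovered from the log variance
        eps = rng.gauss(0.0, 1.0)   # noise drawn independently of mu and log_var
        z.append(m + sigma * eps)
    return z
```

Predicting the log variance rather than the variance itself keeps sigma positive for any real-valued output of the linear layer.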
Note
Please note that the order of Activation, Dropout and Normalization, defined with the argument adn_ordering in conv_args, is the same for the encoder and the decoder.
- Parameters:
in_shape (Sequence[int]) – Dimensions of the input tensor (without batch dimension).
latent_size (int) – Size of the latent vector.
conv_args (Dict[str, Any]) – The arguments for the convolutional part. The arguments are those accepted by ConvEncoder, except spatial_dims and in_channels, which are specified here via in_shape. So, the only mandatory argument is channels.
mlp_args (Optional[Dict[str, Any]], default=None) – The arguments for the MLP part. The arguments are those accepted by MLP, except num_inputs, which is inferred from the output of the convolutional part, and num_outputs, which is equal to latent_size here. So, the only mandatory argument is hidden_dims. If None, the MLP part will be reduced to a single linear layer. The last linear layer will be duplicated to infer both the mean and the log variance.
out_channels (Optional[int], default=None) – Number of output channels. If None, the output will have the same number of channels as the input.
output_act (Optional[ActivationParameters], default=None) – A potential activation layer applied to the output of the network, and optionally its arguments. Must be passed as activation_name or (activation_name, arguments), where arguments is a dictionary. If None, no activation will be used. activation_name can be any value in {"celu", "elu", "gelu", "leakyrelu", "logsoftmax", "mish", "prelu", "relu", "relu6", "selu", "sigmoid", "softmax", "tanh"}. Please refer to the PyTorch activation functions for the arguments each of them accepts.
unpooling_mode (Union[str, UnpoolingMode], default=UnpoolingMode.NEAREST) – Type of unpooling. Can be any value in {"nearest", "linear", "bilinear", "bicubic", "trilinear", "convtranspose"}:
  - nearest: unpooling is performed by upsampling with the nearest algorithm (see torch.nn.Upsample);
  - linear: unpooling is performed by upsampling with the linear algorithm. Only works with 1D images (excluding the channel dimension);
  - bilinear: unpooling is performed by upsampling with the bilinear algorithm. Only works with 2D images;
  - bicubic: unpooling is performed by upsampling with the bicubic algorithm. Only works with 2D images;
  - trilinear: unpooling is performed by upsampling with the trilinear algorithm. Only works with 3D images;
  - convtranspose: unpooling is performed with a transposed convolution (see torch.nn.ConvTranspose3d), whose parameters (kernel size, stride, etc.) are computed to reverse the pooling operation.
Examples
>>> VAE(
...     in_shape=(1, 16, 16),
...     latent_size=4,
...     conv_args={"channels": [2]},
...     mlp_args={"hidden_dims": [16], "output_act": "relu"},
...     out_channels=2,
...     output_act="sigmoid",
...     unpooling_mode="bilinear",
... )
VAE(
  (encoder): CNN(
    (convolutions): ConvEncoder(
      (layer0): Convolution(
        (conv): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
      )
    )
    (mlp): MLP(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (hidden0): Sequential(
        (linear): Linear(in_features=392, out_features=16, bias=True)
        (adn): ADN(
          (N): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (A): PReLU(num_parameters=1)
        )
      )
      (output): Identity()
    )
  )
  (mu): Sequential(
    (linear): Linear(in_features=16, out_features=4, bias=True)
    (output_act): ReLU()
  )
  (log_var): Sequential(
    (linear): Linear(in_features=16, out_features=4, bias=True)
    (output_act): ReLU()
  )
  (decoder): Generator(
    (mlp): MLP(
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (hidden0): Sequential(
        (linear): Linear(in_features=4, out_features=16, bias=True)
        (adn): ADN(
          (N): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (A): PReLU(num_parameters=1)
        )
      )
      (output): Sequential(
        (linear): Linear(in_features=16, out_features=392, bias=True)
        (output_act): ReLU()
      )
    )
    (reshape): Reshape()
    (convolutions): ConvDecoder(
      (layer0): Convolution(
        (conv): ConvTranspose2d(2, 2, kernel_size=(3, 3), stride=(1, 1))
      )
      (output_act): Sigmoid()
    )
  )
)
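In this example, the value 392 seen in the encoder's first Linear layer (and in the decoder's output Linear layer) is the num_inputs that the mlp_args description says is inferred from the convolutional part. The printed Conv2d has a 3x3 kernel, stride 1 and no padding, so each 16-pixel side shrinks to 14, and 2 channels x 14 x 14 = 392; the decoder's Reshape layer presumably restores that flat vector to a (2, 14, 14) feature map. The sketch below reproduces the arithmetic with PyTorch's convolution output-size formula. The helper names are hypothetical, and it assumes every layer shares the same kernel size, stride and padding, with no pooling.

```python
from math import prod  # Python 3.8+
from typing import Sequence


def conv_output_size(size: int, kernel_size: int = 3, stride: int = 1,
                     padding: int = 0, dilation: int = 1) -> int:
    """Length of one spatial axis after a convolution (PyTorch's formula)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_size(in_shape: Sequence[int], channels: Sequence[int],
                   kernel_size: int = 3, stride: int = 1, padding: int = 0) -> int:
    """Number of features produced by a conv stack followed by Flatten.

    in_shape is (channels, *spatial) without the batch dimension; channels
    lists the output channels of each convolution (hypothetical simplification).
    """
    spatial = list(in_shape[1:])  # drop the channel dimension
    for _ in channels:
        spatial = [conv_output_size(s, kernel_size, stride, padding) for s in spatial]
    return channels[-1] * prod(spatial)
```

With the example's settings, flattened_size((1, 16, 16), [2]) gives 392.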
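The output_act and unpooling_mode descriptions in the parameter list imply a few validity rules: output_act is either a name or a (name, dict) pair drawn from the listed activations, and some unpooling modes only work with a given number of spatial dimensions. The following is a minimal sketch of such checks, assuming string inputs only (it ignores the UnpoolingMode enum) and using hypothetical helpers normalize_output_act and check_unpooling_mode; it is not ClinicaDL's own validation code.

```python
from typing import Any, Dict, Optional, Sequence, Tuple, Union

# Activation names listed in the output_act description.
ACTIVATIONS = {
    "celu", "elu", "gelu", "leakyrelu", "logsoftmax", "mish", "prelu",
    "relu", "relu6", "selu", "sigmoid", "softmax", "tanh",
}

# Spatial dimensionality required by each unpooling mode (None: no restriction stated).
UNPOOLING_DIMS = {
    "nearest": None,
    "linear": 1,
    "bilinear": 2,
    "bicubic": 2,
    "trilinear": 3,
    "convtranspose": None,
}

ActParams = Union[str, Tuple[str, Dict[str, Any]]]


def normalize_output_act(output_act: Optional[ActParams]) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Turn 'name' or ('name', {...}) into a (name, kwargs) pair; None stays None."""
    if output_act is None:
        return None
    if isinstance(output_act, str):
        name, kwargs = output_act, {}
    elif (
        isinstance(output_act, tuple)
        and len(output_act) == 2
        and isinstance(output_act[1], dict)
    ):
        name, kwargs = output_act
    else:
        raise TypeError("output_act must be a name or a (name, dict) tuple")
    if name not in ACTIVATIONS:
        raise ValueError(f"unknown activation: {name!r}")
    return name, dict(kwargs)


def check_unpooling_mode(mode: str, in_shape: Sequence[int]) -> int:
    """Return the number of spatial dims if mode is compatible with in_shape."""
    if mode not in UNPOOLING_DIMS:
        raise ValueError(f"unknown unpooling mode: {mode!r}")
    spatial_dims = len(in_shape) - 1  # in_shape = (channels, *spatial), no batch dim
    required = UNPOOLING_DIMS[mode]
    if required is not None and required != spatial_dims:
        raise ValueError(f"{mode!r} unpooling needs {required}D images, got {spatial_dims}D")
    return spatial_dims
```

Under these rules, the example's unpooling_mode="bilinear" is accepted for in_shape=(1, 16, 16), whereas "trilinear" would be rejected.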