clinicadl.networks.nn.ViT
- class clinicadl.networks.nn.ViT(in_shape: Sequence[int], patch_size: Sequence[int] | int, num_outputs: int | None, embedding_dim: int = 768, num_layers: int = 12, num_heads: int = 12, mlp_dim: int = 3072, pos_embed_type: str | PosEmbedType | None = PosEmbedType.LEARN, output_act: ActFunction | tuple[ActFunction, dict[str, Any]] | None = ActFunction.TANH, dropout: float | None = None) → None
Vision Transformer, based on An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
Adapted from torchvision’s implementation.
The user can customize the patch size, the embedding dimension, the number of transformer blocks, the number of attention heads, as well as other parameters like the type of position embedding.
Works with 2D or 3D images (with additional batch and channel dimensions).
- Parameters:
in_shape (Sequence[int]) – Dimensions of the input tensor (without batch dimension).
patch_size (Union[Sequence[int], int]) – Patch size (without batch and channel dimensions). If int, the same patch size will be used for all dimensions. patch_size must divide in_shape in all spatial dimensions.
num_outputs (Optional[int]) – Number of output variables after the last linear layer. If None, the patch embeddings after the last transformer block will be returned.
embedding_dim (int, default=768) – Size of the embedding vectors. Must be divisible by num_heads, as each head is responsible for a part of the embedding vectors. Defaults to 768, as in ViT-Base in the original paper.
num_layers (int, default=12) – Number of consecutive transformer blocks. Defaults to 12, as in ViT-Base in the original paper.
num_heads (int, default=12) – Number of heads in the self-attention blocks. Must divide embedding_dim. Defaults to 12, as in ViT-Base in the original paper.
mlp_dim (int, default=3072) – Size of the hidden layer in the MLP part of the transformer block. Defaults to 3072, as in ViT-Base in the original paper.
pos_embed_type (Optional[Union[str, PosEmbedType]], default="learnable") – Type of position embedding. Can be either learnable, sincos or None:
  - learnable: the position embeddings are parameters that will be learned during the training process.
  - sincos: the position embeddings are fixed and determined with the sine and cosine formulas described in Vaswani et al. [1]. Only implemented for 2D and 3D images. With sincos position embedding, embedding_dim must be divisible by 4 for 2D images, and by 6 for 3D images.
  - None: no position embeddings are used.
  Defaults to learnable, as in the original paper.
output_act (Optional[ActivationParameters], default="tanh") – A potential activation layer applied to the output of the network, and optionally its arguments. Must be passed as activation_name or (activation_name, arguments), where arguments is a dictionary. If None, no activation will be used. activation_name can be any value in {"celu", "elu", "gelu", "leakyrelu", "logsoftmax", "mish", "prelu", "relu", "relu6", "selu", "sigmoid", "softmax", "tanh"}. Please refer to PyTorch activation functions to know the arguments for each of them. Defaults to tanh, as in the original paper.
dropout (Optional[float], default=None) – Dropout ratio. If None, no dropout.
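The divisibility constraints listed above can be checked before building the network. The following is a minimal sketch in plain Python; `check_vit_config` is a hypothetical helper written for illustration, not part of clinicadl, and it assumes that the first entry of in_shape is the channel dimension (as in the example on this page).

```python
from typing import List, Optional, Sequence, Union


def check_vit_config(
    in_shape: Sequence[int],
    patch_size: Union[Sequence[int], int],
    embedding_dim: int = 768,
    num_heads: int = 12,
    pos_embed_type: Optional[str] = "learnable",
) -> List[str]:
    """Hypothetical helper: return the violated constraints (empty list if none)."""
    # Assumption: in_shape is (channels, *spatial_dims), without batch dimension.
    spatial = tuple(in_shape[1:])
    n_dims = len(spatial)
    if isinstance(patch_size, int):
        patch = (patch_size,) * n_dims  # same patch size on every spatial axis
    else:
        patch = tuple(patch_size)

    problems: List[str] = []
    if len(patch) != n_dims:
        problems.append(f"patch_size has {len(patch)} dims, expected {n_dims}")
    else:
        for axis, (size, p) in enumerate(zip(spatial, patch)):
            if size % p != 0:
                problems.append(
                    f"patch_size {p} does not divide in_shape {size} on spatial axis {axis}"
                )
    if embedding_dim % num_heads != 0:
        problems.append(
            f"embedding_dim {embedding_dim} is not divisible by num_heads {num_heads}"
        )
    if pos_embed_type == "sincos":
        # Documented rule: divisible by 4 for 2D images, by 6 for 3D images.
        required = {2: 4, 3: 6}.get(n_dims)
        if required is None:
            problems.append("sincos is only implemented for 2D and 3D images")
        elif embedding_dim % required != 0:
            problems.append(
                f"embedding_dim {embedding_dim} is not divisible by {required} (sincos)"
            )
    return problems


ok = check_vit_config((3, 60, 64), 4, embedding_dim=32, num_heads=4)
bad_patch = check_vit_config((3, 60, 64), 7, embedding_dim=32, num_heads=4)
bad_sincos = check_vit_config(
    (1, 32, 32, 32), 4, embedding_dim=32, num_heads=4, pos_embed_type="sincos"
)
```

Here `ok` is empty, `bad_patch` reports both spatial axes (7 divides neither 60 nor 64), and `bad_sincos` reports that 32 is not divisible by 6.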
See also
torch.nn.Module – To see all the methods of this neural network.
Examples
>>> ViT(
...     in_shape=(3, 60, 64),
...     patch_size=4,
...     num_outputs=2,
...     embedding_dim=32,
...     num_layers=2,
...     num_heads=4,
...     mlp_dim=128,
...     output_act="softmax",
... )
ViT(
  (conv_proj): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-1): 2 x EncoderBlock(
        (norm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (norm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=32, out_features=128, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=128, out_features=32, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
  )
  (fc): Sequential(
    (out): Linear(in_features=32, out_features=2, bias=True)
    (output_act): Softmax(dim=None)
  )
)
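To make the effect of patch_size in the example above concrete, the patch grid can be computed by hand. This is a small sketch in plain Python; `patch_grid` is a hypothetical helper, not a clinicadl function, and the count below does not include any extra token (such as a class token) that the implementation may prepend.

```python
from math import prod
from typing import Sequence, Tuple, Union


def patch_grid(
    in_shape: Sequence[int], patch_size: Union[Sequence[int], int]
) -> Tuple[int, ...]:
    """Hypothetical helper: number of patches along each spatial axis."""
    # Assumption: in_shape is (channels, *spatial_dims), without batch dimension.
    spatial = tuple(in_shape[1:])
    if isinstance(patch_size, int):
        patch = (patch_size,) * len(spatial)
    else:
        patch = tuple(patch_size)
    return tuple(size // p for size, p in zip(spatial, patch))


grid = patch_grid((3, 60, 64), 4)  # same settings as the example above
n_patches = prod(grid)
print(grid, n_patches)  # (15, 16) 240
```

Each of the 240 patches is projected by conv_proj (kernel and stride equal to the patch size) to a vector of size embedding_dim, here 32.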
References
- forward(x: Tensor) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
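The difference between calling the module instance and calling forward directly can be seen with a forward hook. The sketch below uses a small torch.nn.Linear as a stand-in for any torch.nn.Module (a ViT instance would behave the same way in this respect); the hook function `record_shape` is an illustrative name.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for any nn.Module, e.g. a ViT instance
calls = []


def record_shape(module, inputs, output):
    # Forward hooks receive the module, its positional inputs and its output.
    calls.append(tuple(output.shape))


handle = model.register_forward_hook(record_shape)
x = torch.randn(3, 4)

model(x)          # runs forward() and then the registered hook
model.forward(x)  # runs forward() only; the hook is skipped

handle.remove()
print(calls)  # [(3, 2)]
```

Only the first call is recorded, which is why the instance should be called rather than forward.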