clinicadl.networks.nn.ViT

class clinicadl.networks.nn.ViT(in_shape: Sequence[int], patch_size: Sequence[int] | int, num_outputs: int | None, embedding_dim: int = 768, num_layers: int = 12, num_heads: int = 12, mlp_dim: int = 3072, pos_embed_type: str | PosEmbedType | None = PosEmbedType.LEARN, output_act: ActFunction | tuple[ActFunction, dict[str, Any]] | None = ActFunction.TANH, dropout: float | None = None) -> None

Vision Transformer, based on "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".

Adapted from torchvision’s implementation.

The user can customize the patch size, the embedding dimension, the number of transformer blocks, the number of attention heads, as well as other parameters like the type of position embedding.

Works with 2D or 3D images (with additional batch and channel dimensions).
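As a small illustration of the shape convention, the sketch below shows how in_shape (channels first, no batch dimension) relates to the shape of a batch passed to the network. `batched_shape` is a hypothetical helper written for this example, not part of clinicadl.

```python
from typing import Sequence, Tuple


def batched_shape(in_shape: Sequence[int], batch_size: int) -> Tuple[int, ...]:
    """Hypothetical helper: shape of a batch, given in_shape = (channels, *spatial)."""
    return (batch_size, *in_shape)


# 2D images: (batch, channels, height, width)
print(batched_shape((3, 60, 64), 8))      # (8, 3, 60, 64)
# 3D images: (batch, channels, depth, height, width)
print(batched_shape((1, 32, 32, 32), 8))  # (8, 1, 32, 32, 32)
```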

Parameters:
  • in_shape (Sequence[int]) – Dimensions of the input tensor (without batch dimension).

  • patch_size (Union[Sequence[int], int]) – Patch size (without batch and channel dimensions). If int, the same patch size will be used for all dimensions. patch_size must divide in_shape in all spatial dimensions.

  • num_outputs (Optional[int]) –

    Number of output variables after the last linear layer.

    If None, the patch embeddings after the last transformer block will be returned.

  • embedding_dim (int, default=768) – Size of the embedding vectors. Must be divisible by num_heads, as each head is responsible for a part of the embedding vectors. Defaults to 768, as in ViT-Base in the original paper.

  • num_layers (int, default=12) – Number of consecutive transformer blocks. Defaults to 12, as in ViT-Base in the original paper.

  • num_heads (int, default=12) – Number of heads in the self-attention blocks. Must divide embedding_dim. Defaults to 12, as in ViT-Base in the original paper.

  • mlp_dim (int, default=3072) – Size of the hidden layer in the MLP part of the transformer block. Defaults to 3072, as in ViT-Base in the original paper.

  • pos_embed_type (Optional[Union[str, PosEmbedType]], default="learnable") –

    Type of position embedding. Can be either learnable, sincos or None:

    • learnable: the position embeddings are parameters that will be learned during the training process.

    • sincos: the position embeddings are fixed and computed with the sine and cosine formulas described in Vaswani et al. [1]. Only implemented for 2D and 3D images. With sincos position embedding, embedding_dim must be divisible by 4 for 2D images, and by 6 for 3D images.

    • None: no position embeddings are used.

    Defaults to learnable, as in the original paper.

  • output_act (Optional[ActivationParameters], default="tanh") –

    An optional activation layer applied to the output of the network, possibly with its arguments. Must be passed as activation_name or (activation_name, arguments), where arguments is a dictionary. If None, no activation is used.

    activation_name can be any value in {"celu", "elu", "gelu", "leakyrelu", "logsoftmax", "mish", "prelu", "relu", "relu6", "selu", "sigmoid", "softmax", "tanh"}. Please refer to the PyTorch activation functions for the arguments each of them accepts.

    Default is tanh, as in the original paper.

  • dropout (Optional[float], default=None) – Dropout ratio. If None, no dropout.
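The divisibility constraints stated in the parameter descriptions can be checked before building the network. The following is a minimal sketch in plain Python; `check_vit_config` is a hypothetical helper, not part of clinicadl, and it only encodes the rules listed above, assuming in_shape is given as (channels, *spatial).

```python
from typing import Optional, Sequence, Union


def check_vit_config(
    in_shape: Sequence[int],
    patch_size: Union[Sequence[int], int],
    embedding_dim: int = 768,
    num_heads: int = 12,
    pos_embed_type: Optional[str] = "learnable",
) -> None:
    """Hypothetical helper: raise ValueError if a documented constraint is violated."""
    spatial = tuple(in_shape[1:])  # assumption: in_shape = (channels, *spatial)
    if len(spatial) not in (2, 3):
        raise ValueError("expected a 2D or 3D image with a channel dimension")

    # An int patch size is used for every spatial dimension.
    if isinstance(patch_size, int):
        patch = (patch_size,) * len(spatial)
    else:
        patch = tuple(patch_size)
    if len(patch) != len(spatial):
        raise ValueError("patch_size must have one entry per spatial dimension")

    # patch_size must divide in_shape in all spatial dimensions.
    for size, p in zip(spatial, patch):
        if size % p != 0:
            raise ValueError(f"patch size {p} does not divide dimension {size}")

    # Each attention head handles an equal slice of the embedding vector.
    if embedding_dim % num_heads != 0:
        raise ValueError("embedding_dim must be divisible by num_heads")

    # sincos: divisible by 4 for 2D images and by 6 for 3D images.
    if pos_embed_type == "sincos" and embedding_dim % (2 * len(spatial)) != 0:
        raise ValueError(
            f"with sincos, embedding_dim must be divisible by {2 * len(spatial)}"
        )


# Passes silently: 4 divides 60 and 64, and 4 heads divide 32.
check_vit_config((3, 60, 64), patch_size=4, embedding_dim=32, num_heads=4)
```

Running such a check first gives a readable error message for an invalid combination, whatever error the network constructor itself would raise.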

See also

torch.nn.Module

To see all the methods of this neural network.

Examples

>>> from clinicadl.networks.nn import ViT
>>> ViT(
...     in_shape=(3, 60, 64),
...     patch_size=4,
...     num_outputs=2,
...     embedding_dim=32,
...     num_layers=2,
...     num_heads=4,
...     mlp_dim=128,
...     output_act="softmax",
... )
ViT(
    (conv_proj): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4))
    (encoder): Encoder(
        (dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
            (0-1): 2 x EncoderBlock(
                (norm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
                (self_attention): MultiheadAttention(
                    (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
                )
                (dropout): Dropout(p=0.0, inplace=False)
                (norm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
                (mlp): MLPBlock(
                    (0): Linear(in_features=32, out_features=128, bias=True)
                    (1): GELU(approximate='none')
                    (2): Dropout(p=0.0, inplace=False)
                    (3): Linear(in_features=128, out_features=32, bias=True)
                    (4): Dropout(p=0.0, inplace=False)
                )
            )
        )
        (norm): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
    )
    (fc): Sequential(
        (out): Linear(in_features=32, out_features=2, bias=True)
        (output_act): Softmax(dim=None)
    )
)
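
The conv_proj layer in the output above has equal kernel size and stride, so it cuts the image into non-overlapping patches. As a sketch, the patch grid for this example can be worked out with the standard convolution output-size formula (no padding, no dilation). `patch_grid` is an illustrative name, and any extra token the network may add to the sequence is not counted here.

```python
from math import prod
from typing import Sequence, Tuple


def patch_grid(spatial_shape: Sequence[int], patch_size: int) -> Tuple[int, ...]:
    """Illustrative helper: patches per spatial dimension."""
    # floor((size - kernel) / stride) + 1, with kernel == stride == patch_size
    return tuple((size - patch_size) // patch_size + 1 for size in spatial_shape)


grid = patch_grid((60, 64), 4)  # spatial part of in_shape=(3, 60, 64)
num_patches = prod(grid)
print(grid, num_patches)  # (15, 16) 240
```

Each of these patches is projected to a vector of size embedding_dim (32 in the example).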

References

[1] Vaswani et al., Attention Is All You Need, 2017.

forward(x: Tensor) -> Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
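
The note above can be illustrated without PyTorch. The sketch below uses a toy class, `TinyModule`, which is hypothetical and far simpler than torch.nn.Module. It only mimics the idea that calling the instance runs the registered hooks around forward, while calling forward directly bypasses them.

```python
class TinyModule:
    """Toy stand-in for a module with forward hooks (not torch.nn.Module)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        return 2 * x

    def __call__(self, x):
        # Run the computation, then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


calls = []
m = TinyModule()
m.register_forward_hook(lambda module, inputs, output: calls.append(output))

m(3)          # hooks run: calls == [6]
m.forward(3)  # hooks skipped: calls is still [6]
```

With the real network, this means writing model(x) rather than model.forward(x).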