# Source code for clinicadl.networks.nn.vit

import math
import re
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.layers import Conv
from monai.networks.layers.utils import get_act_layer
from pydantic import NonNegativeFloat, PositiveInt, model_validator
from torch.hub import load_state_dict_from_url
from torchvision.models.vision_transformer import (
    ViT_B_16_Weights,
    ViT_B_32_Weights,
    ViT_L_16_Weights,
    ViT_L_32_Weights,
)

from clinicadl.utils.factories import get_defaults_from

from .layers.utils import ActFunction, ActivationParameters
from .layers.vit import Encoder
from .utils import ensure_tuple
from .utils.config import (
    NetworkConfig,
    _DropoutConfig,
    _InShapeConfig,
)

# Public names exported by this module: the configurable ViT and its four
# fixed-architecture variants defined below.
__all__ = [
    "ViT",
    "ViTB16",
    "ViTB32",
    "ViTL16",
    "ViTL32",
]


class PosEmbedType(str, Enum):
    """Available position embedding types for ViT."""

    # Position embeddings are learned; ViT leaves their creation to the Encoder.
    LEARN = "learnable"
    # Fixed sine/cosine position embeddings, built once and never trained.
    SINCOS = "sincos"


class ViT(nn.Module):
    """
    Vision Transformer, based on `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>`_. Adapted from :torchvision:`torchvision's implementation
    <models/vision_transformer.html>`.

    The user can customize the patch size, the embedding dimension, the number of transformer blocks,
    the number of attention heads, as well as other parameters like the type of position embedding.

    Works with 2D or 3D images (with additional batch and channel dimensions).

    Parameters
    ----------
    in_shape : Sequence[int]
        Dimensions of the input tensor (without batch dimension).
    patch_size : Union[Sequence[int], int]
        Patch size (without batch and channel dimensions). If ``int``, the same
        patch size will be used for all dimensions.
        ``patch_size`` must divide ``in_shape`` in all spatial dimensions.
    num_outputs : Optional[int]
        Number of output variables after the last linear layer.\n
        If ``None``, the patch embeddings after the last transformer block will be returned.
    embedding_dim : int, default=768
        Size of the embedding vectors. Must be divisible by ``num_heads`` as each head will be responsible for
        a part of the embedding vectors. Default to ``768``, as ``ViT-Base`` in the original paper.
    num_layers : int, default=12
        Number of consecutive transformer blocks. Default to ``12``, as ``ViT-Base`` in the original paper.
    num_heads : int, default=12
        Number of heads in the self-attention blocks. Must divide ``embedding_dim``.
        Default to ``12``, as ``ViT-Base`` in the original paper.
    mlp_dim : int, default=3072
        Size of the hidden layer in the MLP part of the transformer block. Default to ``3072``, as ``ViT-Base``
        in the original paper.
    pos_embed_type : Optional[Union[str, PosEmbedType]], default="learnable"
        Type of position embedding. Can be either ``learnable``, ``sincos`` or ``None``:

        - ``learnable``: the position embeddings are parameters that will be learned during the training
          process.
        - ``sincos``: the position embeddings are fixed and determined with sine and cosine formulas
          described in :footcite:t:`Vaswani2023`. Only implemented for 2D and 3D images. With ``sincos``
          position embedding, ``embedding_dim`` must be divisible by ``4`` for 2D images, and by ``6``
          for 3D images.
        - ``None``: no position embeddings are used.

        Default to ``learnable``, as in the original paper.
    output_act : Optional[ActivationParameters], default="tanh"
        A potential activation layer applied to the output of the network, and optionally its arguments.
        Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is
        a dictionary. If ``None``, no activation will be used.\n
        ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``,
        ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``,
        ``"softmax"``, ``"tanh"``}. Please refer to
        :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>`
        to know the arguments for each of them.\n
        Default is ``tanh``, as in the original paper.
    dropout : Optional[float], default=None
        Dropout ratio. If ``None``, no dropout.

    See Also
    --------
    :py:class:`torch.nn.Module`
        To see all the methods of this neural network.

    Examples
    --------
    .. code-block:: python

        >>> ViT(
                in_shape=(3, 60, 64),
                patch_size=4,
                num_outputs=2,
                embedding_dim=32,
                num_layers=2,
                num_heads=4,
                mlp_dim=128,
                output_act="softmax",
            )
        ViT(
            (conv_proj): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4))
            (encoder): Encoder(
                (dropout): Dropout(p=0.0, inplace=False)
                (layers): ModuleList(
                    (0-1): 2 x EncoderBlock(
                        (norm1): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
                        (self_attention): MultiheadAttention(
                            (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
                        )
                        (dropout): Dropout(p=0.0, inplace=False)
                        (norm2): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
                        (mlp): MLPBlock(
                            (0): Linear(in_features=32, out_features=128, bias=True)
                            (1): GELU(approximate='none')
                            (2): Dropout(p=0.0, inplace=False)
                            (3): Linear(in_features=128, out_features=32, bias=True)
                            (4): Dropout(p=0.0, inplace=False)
                        )
                    )
                )
                (norm): LayerNorm((32,), eps=1e-06, elementwise_affine=True)
            )
            (fc): Sequential(
                (out): Linear(in_features=32, out_features=2, bias=True)
                (output_act): Softmax(dim=None)
            )
        )

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        in_shape: Sequence[int],
        patch_size: Union[Sequence[int], int],
        num_outputs: Optional[int],
        embedding_dim: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        pos_embed_type: Optional[Union[str, PosEmbedType]] = PosEmbedType.LEARN,
        output_act: Optional[ActivationParameters] = ActFunction.TANH,
        dropout: Optional[float] = None,
    ) -> None:
        super().__init__()
        # All hyperparameters go through the config class, which validates them;
        # the rest of the constructor only reads the validated values.
        self.config = ViTConfig(
            in_shape=in_shape,
            patch_size=patch_size,
            num_outputs=num_outputs,
            embedding_dim=embedding_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            pos_embed_type=pos_embed_type,
            output_act=output_act,
            dropout=dropout,
        )
        self.in_channels, *self.img_size = self.config.in_shape
        self.spatial_dims = len(self.img_size)
        # A classification head (and a class token) exists only if num_outputs is given.
        self.classification = bool(self.config.num_outputs)
        dropout = self.config.dropout or 0.0

        # Patch embedding: a convolution whose kernel and stride both equal the patch size.
        self.conv_proj = Conv[Conv.CONV, self.spatial_dims](  # pylint: disable=not-callable
            in_channels=self.in_channels,
            out_channels=self.config.embedding_dim,
            kernel_size=self.config.patch_size,
            stride=self.config.patch_size,
        )
        # Number of patches = product over spatial dimensions of (image size // patch size).
        self.seq_length = int(
            np.prod(np.array(self.img_size) // np.array(self.config.patch_size))
        )

        # Add a class token (created once; it lengthens the sequence by one)
        if self.classification:
            self.class_token = nn.Parameter(
                torch.zeros(1, 1, self.config.embedding_dim)
            )
            self.seq_length += 1

        # Must be called after seq_length is final, as the embedding depends on it.
        pos_embedding = self._get_pos_embedding(self.config.pos_embed_type)
        self.encoder = Encoder(
            self.seq_length,
            self.config.num_layers,
            self.config.num_heads,
            self.config.embedding_dim,
            self.config.mlp_dim,
            dropout=dropout,
            attention_dropout=dropout,
            pos_embedding=pos_embedding,
        )

        if self.classification:
            self.fc = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "out",
                            nn.Linear(
                                self.config.embedding_dim, self.config.num_outputs
                            ),
                        )
                    ]
                )
            )
            # Read the activation from the validated config, like every other hyperparameter.
            act = self.config.output_act
            self.fc.output_act = get_act_layer(act) if act else None
        else:
            self.fc = None

        self._init_layers()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Runs the network on a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Input batch, passed as is to the patch-embedding convolution.

        Returns
        -------
        torch.Tensor
            The output of the ``fc`` head applied to the class token if the network
            has a head, otherwise the output of the encoder.
        """
        x = self.conv_proj(x)
        # Merge all the dimensions after the channel one, then swap the last two
        # axes: (n, hidden_dim, n_h, n_w) -> (n, (h * w * d), hidden_dim)
        x = x.flatten(2).transpose(-1, -2)
        n = x.shape[0]

        # Explicit None test rather than the truthiness of the nn.Sequential
        has_head = self.fc is not None

        # Expand the class token to the full batch
        if has_head:
            batch_class_token = self.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        if has_head:
            x = x[:, 0]
            x = self.fc(x)

        return x

    def _get_pos_embedding(
        self, pos_embed_type: Optional[PosEmbedType]
    ) -> Optional[nn.Parameter]:
        """
        Gets position embeddings.

        - ``None``: returns a frozen, all-zero parameter of shape
          ``(1, seq_length, embedding_dim)``.
        - "learnable": returns ``None``, as it will be handled by the encoder module.
        - "sincos": returns a frozen sincos embedding, with a zero embedding
          prepended for the class token when the network has one.
        """
        if pos_embed_type is None:
            pos_embed = nn.Parameter(
                torch.zeros(1, self.seq_length, self.config.embedding_dim)
            )
            pos_embed.requires_grad = False
            return pos_embed

        if pos_embed_type == PosEmbedType.LEARN:
            return None  # will be initialized inside the Encoder

        if pos_embed_type == PosEmbedType.SINCOS:
            # Number of patches along each spatial dimension
            grid_size = [
                in_size // pa_size
                for in_size, pa_size in zip(self.img_size, self.config.patch_size)
            ]
            pos_embed = build_sincos_position_embedding(
                grid_size, self.config.embedding_dim, self.spatial_dims
            )
            if self.classification:
                pos_embed = torch.nn.Parameter(
                    torch.cat(
                        [torch.zeros(1, 1, self.config.embedding_dim), pos_embed],
                        dim=1,
                    )
                )  # add 0 for class token pos embedding
            pos_embed.requires_grad = False
            return pos_embed

        # Any other value falls through to None, as before.
        return None

    def _init_layers(self):
        """
        Initializes some layers, based on torchvision's implementation:
        https://pytorch.org/vision/main/_modules/torchvision/models/vision_transformer.html
        """
        fan_in = self.conv_proj.in_channels * np.prod(self.conv_proj.kernel_size)
        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
        nn.init.zeros_(self.conv_proj.bias)

    def _load_weights(self, url: str) -> None:
        """To load weights from torchvision."""
        pretrained_dict = load_state_dict_from_url(url, progress=True)
        if not self.classification:
            # No class token in this network: drop it and its position embedding
            del pretrained_dict["class_token"]
            pretrained_dict["encoder.pos_embedding"] = pretrained_dict[
                "encoder.pos_embedding"
            ][:, 1:]  # remove class token position embedding

        # The head is task specific and never loaded: detach it so that the
        # state dict to load (stripped of its "heads" entries) matches the
        # network, and always put it back, even if loading fails.
        fc_layers = self.fc
        self.fc = None
        try:
            self.load_state_dict(_state_dict_adapter(pretrained_dict))
        finally:
            self.fc = fc_layers
class ViTB16(ViT):
    """
    ViT-B/16, from `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>`__.

    The architecture is fixed; only the last fully connected layer is adapted to ``num_outputs``.

    Pretrained weights from ``torchvision`` can be used. The last fully connected layer
    never uses pretrained weights, as it is task specific.

    .. warning:: Only works with **2D images of size (224, 224), with 3 channels**.

    Parameters
    ----------
    num_outputs : Optional[int]
        Number of output variables after the last linear layer. If ``None``,
        the feature map before the last fully connected layer will be returned.
    output_act : Optional[ActivationParameters], default=None
        A potential activation layer applied to the output of the network, and optionally its arguments.
        Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is
        a dictionary. If ``None``, no activation will be used.\n
        ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``,
        ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``,
        ``"softmax"``, ``"tanh"``}. Please refer to
        :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>`
        to know the arguments for each of them.
    pretrained : bool, default=False
        Whether to use pretrained weights. The pretrained weights used are the default ones from
        :py:func:`torchvision.models.vit_b_16`.

    See Also
    --------
    :py:class:`torch.nn.Module`
        To see all the methods of this neural network.
    :py:class:`~clinicadl.networks.nn.ViT`
    """

    def __init__(
        self,
        num_outputs: Optional[int],
        output_act: Optional[ActivationParameters] = None,
        pretrained: bool = False,
    ) -> None:
        # Validate the user-facing arguments before building the network.
        cfg = ViTB16Config(
            num_outputs=num_outputs,
            output_act=output_act,
            pretrained=pretrained,
        )
        super().__init__(
            num_outputs=cfg.num_outputs,
            output_act=cfg.output_act,
            # Fixed ViT-B/16 architecture
            in_shape=(3, 224, 224),
            patch_size=16,
            embedding_dim=768,
            num_layers=12,
            num_heads=12,
            mlp_dim=3072,
        )
        if cfg.pretrained:
            self._load_weights(ViT_B_16_Weights.DEFAULT.url)
class ViTB32(ViT):
    """
    ViT-B/32, from `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>`__.

    The architecture is fixed; only the last fully connected layer is adapted to ``num_outputs``.

    Pretrained weights from ``torchvision`` can be used. The last fully connected layer
    never uses pretrained weights, as it is task specific.

    .. warning:: Only works with **2D images of size (224, 224), with 3 channels**.

    Parameters
    ----------
    num_outputs : Optional[int]
        Number of output variables after the last linear layer. If ``None``,
        the feature map before the last fully connected layer will be returned.
    output_act : Optional[ActivationParameters], default=None
        A potential activation layer applied to the output of the network, and optionally its arguments.
        Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is
        a dictionary. If ``None``, no activation will be used.\n
        ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``,
        ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``,
        ``"softmax"``, ``"tanh"``}. Please refer to
        :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>`
        to know the arguments for each of them.
    pretrained : bool, default=False
        Whether to use pretrained weights. The pretrained weights used are the default ones from
        :py:func:`torchvision.models.vit_b_32`.

    See Also
    --------
    :py:class:`torch.nn.Module`
        To see all the methods of this neural network.
    :py:class:`~clinicadl.networks.nn.ViT`
    """

    def __init__(
        self,
        num_outputs: Optional[int],
        output_act: Optional[ActivationParameters] = None,
        pretrained: bool = False,
    ) -> None:
        # Validate the user-facing arguments before building the network.
        cfg = ViTB32Config(
            num_outputs=num_outputs,
            output_act=output_act,
            pretrained=pretrained,
        )
        super().__init__(
            num_outputs=cfg.num_outputs,
            output_act=cfg.output_act,
            # Fixed ViT-B/32 architecture
            in_shape=(3, 224, 224),
            patch_size=32,
            embedding_dim=768,
            num_layers=12,
            num_heads=12,
            mlp_dim=3072,
        )
        if cfg.pretrained:
            self._load_weights(ViT_B_32_Weights.DEFAULT.url)
class ViTL16(ViT):
    """
    ViT-L/16, from `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>`__.

    Only the last fully connected layer will be changed to match ``num_outputs``.

    The user can use the pretrained models from ``torchvision``. Note that the last fully connected layer
    will not use pretrained weights, as it is task specific.

    .. warning:: Only works with **2D images of size (224, 224), with 3 channels**.

    Parameters
    ----------
    num_outputs : Optional[int]
        Number of output variables after the last linear layer. If ``None``,
        the feature map before the last fully connected layer will be returned.
    output_act : Optional[ActivationParameters], default=None
        A potential activation layer applied to the output of the network, and optionally its arguments.
        Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is
        a dictionary. If ``None``, no activation will be used.\n
        ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``,
        ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``,
        ``"softmax"``, ``"tanh"``}. Please refer to
        :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>`
        to know the arguments for each of them.
    pretrained : bool, default=False
        Whether to use pretrained weights. The pretrained weights used are the default ones from
        :py:func:`torchvision.models.vit_l_16`.

    See Also
    --------
    :py:class:`torch.nn.Module`
        To see all the methods of this neural network.
    :py:class:`~clinicadl.networks.nn.ViT`
    """

    def __init__(
        self,
        num_outputs: Optional[int],
        output_act: Optional[ActivationParameters] = None,
        pretrained: bool = False,
    ) -> None:
        # Validate the user-facing arguments before building the network.
        config = ViTL16Config(
            num_outputs=num_outputs, output_act=output_act, pretrained=pretrained
        )
        super().__init__(
            in_shape=(3, 224, 224),
            patch_size=16,
            num_outputs=config.num_outputs,
            embedding_dim=1024,
            mlp_dim=4096,
            num_heads=16,
            num_layers=24,
            # Forward the validated value, not the raw argument, so that this
            # class behaves like ViTB16, ViTB32 and ViTL32.
            output_act=config.output_act,
        )
        if config.pretrained:
            self._load_weights(ViT_L_16_Weights.DEFAULT.url)
class ViTL32(ViT):
    """
    ViT-L/32, from `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>`__.

    The architecture is fixed; only the last fully connected layer is adapted to ``num_outputs``.

    Pretrained weights from ``torchvision`` can be used. The last fully connected layer
    never uses pretrained weights, as it is task specific.

    .. warning:: Only works with **2D images of size (224, 224), with 3 channels**.

    Parameters
    ----------
    num_outputs : Optional[int]
        Number of output variables after the last linear layer. If ``None``,
        the feature map before the last fully connected layer will be returned.
    output_act : Optional[ActivationParameters], default=None
        A potential activation layer applied to the output of the network, and optionally its arguments.
        Must be passed as ``activation_name`` or ``(activation_name, arguments)``, where ``arguments`` is
        a dictionary. If ``None``, no activation will be used.\n
        ``activation_name`` can be any value in {``"celu"``, ``"elu"``, ``"gelu"``, ``"leakyrelu"``,
        ``"logsoftmax"``, ``"mish"``, ``"prelu"``, ``"relu"``, ``"relu6"``, ``"selu"``, ``"sigmoid"``,
        ``"softmax"``, ``"tanh"``}. Please refer to
        :torch:`PyTorch activation functions <nn.html#non-linear-activations-weighted-sum-nonlinearity>`
        to know the arguments for each of them.
    pretrained : bool, default=False
        Whether to use pretrained weights. The pretrained weights used are the default ones from
        :py:func:`torchvision.models.vit_l_32`.

    See Also
    --------
    :py:class:`torch.nn.Module`
        To see all the methods of this neural network.
    :py:class:`~clinicadl.networks.nn.ViT`
    """

    def __init__(
        self,
        num_outputs: Optional[int],
        output_act: Optional[ActivationParameters] = None,
        pretrained: bool = False,
    ) -> None:
        # Validate the user-facing arguments before building the network.
        cfg = ViTL32Config(
            num_outputs=num_outputs,
            output_act=output_act,
            pretrained=pretrained,
        )
        super().__init__(
            num_outputs=cfg.num_outputs,
            output_act=cfg.output_act,
            # Fixed ViT-L/32 architecture
            in_shape=(3, 224, 224),
            patch_size=32,
            embedding_dim=1024,
            num_layers=24,
            num_heads=16,
            mlp_dim=4096,
        )
        if cfg.pretrained:
            self._load_weights(ViT_L_32_Weights.DEFAULT.url)
# Per-network defaults, indexed by argument name in the config classes below
# (e.g. VIT_DEFAULTS["embedding_dim"]).
# NOTE(review): get_defaults_from is defined in clinicadl.utils.factories;
# presumably it collects the default values of each constructor's arguments —
# confirm against that helper.
VIT_DEFAULTS = get_defaults_from(ViT)
VIT_B_16_DEFAULTS = get_defaults_from(ViTB16)
VIT_B_32_DEFAULTS = get_defaults_from(ViTB32)
VIT_L_16_DEFAULTS = get_defaults_from(ViTL16)
VIT_L_32_DEFAULTS = get_defaults_from(ViTL32)
class ViTConfig(
    NetworkConfig,
    _InShapeConfig,
    _DropoutConfig,
):
    """
    Config class for :py:class:`clinicadl.networks.nn.ViT`.

    On top of the per-field type constraints, a model validator checks that
    the patch size divides the image size, that ``embedding_dim`` is divisible
    by ``num_heads``, and that ``embedding_dim`` is compatible with a
    ``sincos`` position embedding.
    """

    in_shape: Sequence[PositiveInt]
    patch_size: Union[Sequence[PositiveInt], PositiveInt]
    num_outputs: Optional[PositiveInt]
    embedding_dim: PositiveInt = VIT_DEFAULTS["embedding_dim"]
    num_layers: PositiveInt = VIT_DEFAULTS["num_layers"]
    num_heads: PositiveInt = VIT_DEFAULTS["num_heads"]
    mlp_dim: PositiveInt = VIT_DEFAULTS["mlp_dim"]
    pos_embed_type: Optional[PosEmbedType] = VIT_DEFAULTS["pos_embed_type"]
    output_act: Optional[ActivationParameters] = VIT_DEFAULTS["output_act"]
    dropout: Optional[NonNegativeFloat] = VIT_DEFAULTS["dropout"]

    @model_validator(mode="after")
    def make_checks(self):
        """
        Normalizes ``patch_size`` and runs the cross-field consistency checks.

        Raises
        ------
        ValueError
            If one of the checks fails.
        """
        # The first element of in_shape is the channel dimension.
        _, *img_size = self.in_shape
        n_spatial_dims = len(img_size)
        # Written through __dict__, i.e. without going through attribute assignment.
        # NOTE(review): ensure_tuple presumably expands an int to one value
        # per spatial dimension — confirm against .utils.
        self.__dict__["patch_size"] = ensure_tuple(
            self.patch_size, dim=n_spatial_dims, name="patch_size"
        )
        self._check_patch_size(self.patch_size, img_size)
        self._check_embedding_dim(self.embedding_dim, self.num_heads)
        self._check_pos_embedding(
            self.pos_embed_type, n_spatial_dims, self.embedding_dim
        )

        return self

    @staticmethod
    def _check_pos_embedding(
        pos_embed_type: Optional[PosEmbedType],
        spatial_dims: int,
        embedding_dim: int,
    ) -> None:
        """
        Checks the type of positional embedding.

        Only ``sincos`` has constraints: it requires 2D or 3D images, and an
        ``embedding_dim`` divisible by 4 (2D) or 6 (3D).

        Raises
        ------
        ValueError
            If a ``sincos`` constraint is not met.
        """
        if pos_embed_type != PosEmbedType.SINCOS:
            return

        if spatial_dims not in (2, 3):
            raise ValueError(
                f"{spatial_dims}D sincos position embedding not implemented"
            )
        if spatial_dims == 2 and embedding_dim % 4:
            raise ValueError(
                f"embedding_dim must be divisible by 4 for 2D sincos position embedding. Got embedding_dim={embedding_dim}"
            )
        if spatial_dims == 3 and embedding_dim % 6:
            raise ValueError(
                f"embedding_dim must be divisible by 6 for 3D sincos position embedding. Got embedding_dim={embedding_dim}"
            )

    @staticmethod
    def _check_embedding_dim(embedding_dim: int, num_heads: int) -> None:
        """
        Checks consistency between embedding dimension and number of heads.

        Raises
        ------
        ValueError
            If ``embedding_dim`` is not a multiple of ``num_heads``.
        """
        if embedding_dim % num_heads != 0:
            raise ValueError(
                f"embedding_dim should be divisible by num_heads. Got embedding_dim={embedding_dim} "
                f" and num_heads={num_heads}"
            )

    @staticmethod
    def _check_patch_size(
        patch_size: Tuple[int, ...], img_size: Tuple[int, ...]
    ) -> None:
        """
        Checks consistency between image size and patch size.

        Raises
        ------
        ValueError
            If, in any dimension, the image size is not a multiple of the patch size.
        """
        for i, p in zip(img_size, patch_size):
            if i % p != 0:
                raise ValueError(
                    f"img_size should be divisible by patch_size. Got img_size={img_size} "
                    f" and patch_size={patch_size}"
                )

    @classmethod
    def _get_class(cls) -> type[nn.Module]:
        """Returns the network associated to this config class."""
        return ViT
class ViTB16Config(NetworkConfig):
    """
    Config class for :py:class:`clinicadl.networks.nn.ViTB16`.
    """

    # No default: must always be provided (``None`` is a valid value).
    num_outputs: Optional[PositiveInt]
    # Defaults are read from VIT_B_16_DEFAULTS.
    output_act: Optional[ActivationParameters] = VIT_B_16_DEFAULTS["output_act"]
    pretrained: bool = VIT_B_16_DEFAULTS["pretrained"]

    @classmethod
    def _get_class(cls) -> type[nn.Module]:
        """Gives the network class configured by this config: ``ViTB16``."""
        return ViTB16
class ViTB32Config(NetworkConfig):
    """
    Config class for :py:class:`clinicadl.networks.nn.ViTB32`.
    """

    # No default: must always be provided (``None`` is a valid value).
    num_outputs: Optional[PositiveInt]
    # Defaults are read from VIT_B_32_DEFAULTS.
    output_act: Optional[ActivationParameters] = VIT_B_32_DEFAULTS["output_act"]
    pretrained: bool = VIT_B_32_DEFAULTS["pretrained"]

    @classmethod
    def _get_class(cls) -> type[nn.Module]:
        """Gives the network class configured by this config: ``ViTB32``."""
        return ViTB32
class ViTL16Config(NetworkConfig):
    """
    Config class for :py:class:`clinicadl.networks.nn.ViTL16`.
    """

    # No default: must always be provided (``None`` is a valid value).
    num_outputs: Optional[PositiveInt]
    # Defaults are read from VIT_L_16_DEFAULTS.
    output_act: Optional[ActivationParameters] = VIT_L_16_DEFAULTS["output_act"]
    pretrained: bool = VIT_L_16_DEFAULTS["pretrained"]

    @classmethod
    def _get_class(cls) -> type[nn.Module]:
        """Gives the network class configured by this config: ``ViTL16``."""
        return ViTL16
class ViTL32Config(NetworkConfig):
    """
    Config class for :py:class:`clinicadl.networks.nn.ViTL32`.
    """

    # No default: must always be provided (``None`` is a valid value).
    num_outputs: Optional[PositiveInt]
    # Defaults are read from VIT_L_32_DEFAULTS.
    output_act: Optional[ActivationParameters] = VIT_L_32_DEFAULTS["output_act"]
    pretrained: bool = VIT_L_32_DEFAULTS["pretrained"]

    @classmethod
    def _get_class(cls) -> type[nn.Module]:
        """Gives the network class configured by this config: ``ViTL32``."""
        return ViTL32
def _state_dict_adapter(state_dict: Mapping[str, Any]) -> Mapping[str, Any]: """ A mapping between torchvision's layer names and ours. """ state_dict = {k: v for k, v in state_dict.items() if "heads" not in k} mappings = [ ("ln_", "norm"), ("ln", "norm"), (r"encoder_layer_(\d+)", r"\1"), ] for key in list(state_dict.keys()): new_key = key for transform in mappings: new_key = re.sub(transform[0], transform[1], new_key) state_dict[new_key] = state_dict.pop(key) return state_dict