clinicadl.optim.optimizers.config.AdamConfig


Config class for torch.optim.Adam.

The parameters of the optimizer can also be passed as dictionaries, whose keys are parameter groups (e.g. a submodule name such as "convolutions.layer0") and whose values are the values to apply to these groups. Such a dictionary must always contain the key "ELSE", which specifies the value for the rest of the parameters.
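To make the "ELSE" rule concrete, here is a minimal, hypothetical sketch of how dictionary-valued options could be turned into PyTorch-style parameter groups. The helper `build_param_groups` and the matching rule (a key selects every parameter whose name starts with that submodule name) are assumptions made for illustration, not ClinicaDL's implementation. Parameters are represented by their names only, so the sketch runs without PyTorch.

```python
from typing import Any, Dict, List


def build_param_groups(param_names: List[str], **options: Any) -> List[Dict[str, Any]]:
    """Hypothetical helper: split parameter names into one group per dictionary key."""
    # Gather the group keys used by dictionary-valued options, in order of appearance.
    group_keys: List[str] = []
    for option, value in options.items():
        if isinstance(value, dict):
            if "ELSE" not in value:
                raise ValueError(f"{option!r} is a dictionary without the 'ELSE' key")
            group_keys += [key for key in value if key != "ELSE" and key not in group_keys]

    groups: List[Dict[str, Any]] = []
    remaining = list(param_names)
    for key in group_keys + ["ELSE"]:
        if key == "ELSE":
            selected = remaining  # "ELSE" takes every parameter not matched so far
        else:
            # Assumption: a key matches the parameters of the submodule it names.
            selected = [name for name in remaining if name.startswith(key + ".")]
        remaining = [name for name in remaining if name not in selected]
        if not selected:
            continue  # skip empty groups
        group: Dict[str, Any] = {"params": selected}
        for option, value in options.items():
            # Dictionary options: the group's own value, else the "ELSE" value.
            # Scalar options: the same value for every group.
            group[option] = value.get(key, value["ELSE"]) if isinstance(value, dict) else value
        groups.append(group)
    return groups


names = [
    "convolutions.layer0.conv.weight",
    "convolutions.layer0.conv.bias",
    "convolutions.layer1.conv.weight",
    "mlp.output.linear.weight",
]
for group in build_param_groups(names, lr={"convolutions.layer0": 1e-2, "ELSE": 1e-3}, eps=1e-8):
    print(group["lr"], group["params"])
# 0.01 ['convolutions.layer0.conv.weight', 'convolutions.layer0.conv.bias']
# 0.001 ['convolutions.layer1.conv.weight', 'mlp.output.linear.weight']
```

Scalar options such as `eps` here are simply copied into every group, which mirrors how a non-dictionary value applies to the whole network.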

freeze can be used to freeze some weights of the neural network, i.e. to disable gradient computation for them so that they are not updated during training.
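In plain PyTorch, freezing a weight amounts to setting its `requires_grad` flag to `False`. The sketch below assumes that freeze names are matched against submodule names, as with "mlp.output" in the example; the helper `freeze_parameters` and the toy model are hypothetical and only illustrate the idea.

```python
from typing import List, Union

import torch.nn as nn


def freeze_parameters(module: nn.Module, names: Union[str, List[str]]) -> None:
    """Hypothetical helper: disable gradients for every parameter under the given submodules."""
    if isinstance(names, str):
        names = [names]  # accept a single name or a list, like the freeze field
    for param_name, param in module.named_parameters():
        if any(param_name.startswith(name + ".") for name in names):
            param.requires_grad_(False)


# Toy model with two named submodules
model = nn.Sequential()
model.add_module("features", nn.Linear(8, 4))
model.add_module("head", nn.Linear(4, 1))

freeze_parameters(model, "head")
print({name: p.requires_grad for name, p in model.named_parameters()})
# {'features.weight': True, 'features.bias': True, 'head.weight': False, 'head.bias': False}
```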

Examples

>>> from clinicadl.networks.nn import CNN
>>> network = CNN(
...     in_shape=(1, 16, 16, 16),
...     num_outputs=1,
...     conv_args={"channels": [2, 4]},
... )
>>> network
CNN(
    (convolutions): ConvEncoder(
        (layer0): Convolution(
            (conv): Conv3d(1, 2, kernel_size=(3, 3, 3), stride=(1, 1, 1))
            (adn): ADN(
                (N): InstanceNorm3d(2, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
                (A): PReLU(num_parameters=1)
            )
        )
        (layer1): Convolution(
            (conv): Conv3d(2, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        )
    )
    (mlp): MLP(
        (flatten): Flatten(start_dim=1, end_dim=-1)
        (output): Sequential(
            (linear): Linear(in_features=6912, out_features=1, bias=True)
        )
    )
)
>>> from clinicadl.optim.optimizers.config import AdamConfig
>>> optimizer_config = AdamConfig(
...     freeze="mlp.output", lr={"convolutions.layer0": 1e-2, "ELSE": 1e-3}
... )
>>> optimizer = optimizer_config.get_object(network)
>>> len(optimizer.param_groups)  # 'convolutions.layer0' and the rest of the network
2
>>> next(network.mlp.output.parameters()).requires_grad  # 'mlp.output' is frozen
False

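For comparison, a roughly equivalent setup written directly against `torch.optim.Adam` is sketched below, on a small stand-in model that reuses the submodule names of the CNN above rather than ClinicaDL's `CNN` itself. It assumes that frozen parameters are left out of the optimizer and that the "convolutions.layer0" key selects the parameters of that submodule; ClinicaDL may handle these details differently.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Stand-in model: same submodule names as the CNN above, without normalization/activation layers.
model = nn.Sequential(OrderedDict([
    ("convolutions", nn.Sequential(OrderedDict([
        ("layer0", nn.Conv3d(1, 2, kernel_size=3)),
        ("layer1", nn.Conv3d(2, 4, kernel_size=3)),
    ]))),
    ("mlp", nn.Sequential(OrderedDict([
        ("flatten", nn.Flatten()),
        ("output", nn.Linear(4 * 12 * 12 * 12, 1)),  # 16 -> 14 -> 12 voxels per axis
    ]))),
]))

# freeze="mlp.output": no gradients for the output layer
for param in model.mlp.output.parameters():
    param.requires_grad_(False)

# lr={"convolutions.layer0": 1e-2, "ELSE": 1e-3}: one group for layer0, one for the rest
layer0_params = list(model.convolutions.layer0.parameters())
layer0_ids = {id(p) for p in layer0_params}
other_params = [p for p in model.parameters() if p.requires_grad and id(p) not in layer0_ids]

optimizer = torch.optim.Adam(
    [{"params": layer0_params, "lr": 1e-2}, {"params": other_params}],
    lr=1e-3,  # used by groups that do not set their own lr, i.e. the "ELSE" value
)
print(len(optimizer.param_groups))  # 2
```

The per-group dictionaries passed to `torch.optim.Adam` are what a config object like this one ultimately has to produce; AdamConfig lets the same split be expressed declaratively.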
parameter lr: Union[PositiveFloat, Dict[str, PositiveFloat]] = 0.001
parameter freeze: Union[str, List[str], None] = None
parameter betas: Union[Tuple[NonNegativeFloat, NonNegativeFloat], Dict[str, Tuple[NonNegativeFloat, NonNegativeFloat]]] = (0.9, 0.999)
parameter eps: Union[NonNegativeFloat, Dict[str, NonNegativeFloat]] = 1e-08
parameter weight_decay: Union[NonNegativeFloat, Dict[str, NonNegativeFloat]] = 0
parameter amsgrad: Union[bool, Dict[str, bool]] = False
parameter foreach: Union[bool, None, Dict[str, Optional[bool]]] = None
parameter maximize: Union[bool, Dict[str, bool]] = False
parameter capturable: Union[bool, Dict[str, bool]] = False
parameter differentiable: Union[bool, Dict[str, bool]] = False
parameter fused: Union[bool, None, Dict[str, Optional[bool]]] = None