clinicadl.optim.optimizers.config.RMSpropConfig

Config class for torch.optim.RMSprop.

Each parameter of the optimizer can be passed either as a single value or as a dictionary whose keys are parameter groups (names of submodules of the network, e.g. "convolutions.layer0") and whose values are the values to apply to these groups. Such a dictionary must always contain the key "ELSE", which specifies the value for all remaining parameters.
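A minimal sketch of how such a dictionary could be turned into the parameter groups that torch.optim optimizers accept. It works on parameter names only (in PyTorch these would come from network.named_parameters()). The helper build_param_groups and its prefix-matching rule are assumptions for illustration, not clinicadl's actual implementation.

```python
from typing import Dict, List, Union


def build_param_groups(
    param_names: List[str],
    option: Union[float, Dict[str, float]],
    option_name: str = "lr",
) -> List[dict]:
    """Hypothetical helper: split parameter names into groups from a {prefix: value, "ELSE": value} dict."""
    # A plain (non-dict) value applies to every parameter: a single group.
    if not isinstance(option, dict):
        return [{"params": list(param_names), option_name: option}]
    if "ELSE" not in option:
        raise ValueError('the dictionary must contain the key "ELSE"')

    groups = []
    assigned = set()
    for prefix, value in option.items():
        if prefix == "ELSE":
            continue
        # A parameter belongs to a group if its name equals the prefix or lies under it.
        # Parameters already claimed by an earlier key are skipped, so no name appears twice.
        members = [
            n
            for n in param_names
            if n not in assigned and (n == prefix or n.startswith(prefix + "."))
        ]
        assigned.update(members)
        groups.append({"params": members, option_name: value})

    # Every parameter not matched by an explicit key falls into the "ELSE" group.
    rest = [n for n in param_names if n not in assigned]
    groups.append({"params": rest, option_name: option["ELSE"]})
    return groups


# Usage with hypothetical parameter names shaped like those of the CNN in the example.
names = [
    "convolutions.layer0.conv.weight",
    "convolutions.layer0.conv.bias",
    "convolutions.layer1.conv.weight",
    "mlp.output.linear.weight",
]
groups = build_param_groups(names, {"convolutions.layer0": 1e-2, "ELSE": 1e-3})
```

Matching on full dotted prefixes (name == prefix or name starts with prefix + ".") avoids accidentally capturing a sibling module whose name merely starts with the same characters.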

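freeze can be used to freeze some weights of the neural network: it takes the name of a submodule (or a list of names), and the parameters of these submodules are then not updated during training.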
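The selection step behind freeze can be sketched in pure Python as below. The helper select_frozen is hypothetical and assumes that freeze entries are matched as dotted module prefixes, as in the example where "mlp.output" freezes the parameters of network.mlp.output. In PyTorch, the selected parameters would then have requires_grad set to False (see the comment at the end of the block).

```python
from typing import List, Union


def select_frozen(param_names: List[str], freeze: Union[str, List[str], None]) -> List[str]:
    """Hypothetical helper: names of parameters lying under the module(s) listed in `freeze`."""
    if freeze is None:
        return []
    # Accept a single module name or a list of module names.
    prefixes = [freeze] if isinstance(freeze, str) else list(freeze)
    return [
        name
        for name in param_names
        if any(name == p or name.startswith(p + ".") for p in prefixes)
    ]


# In PyTorch, the selected parameters would then be frozen, e.g.:
# frozen = set(select_frozen([n for n, _ in network.named_parameters()], "mlp.output"))
# for name, param in network.named_parameters():
#     if name in frozen:
#         param.requires_grad_(False)
```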

Examples

>>> from clinicadl.networks.nn import CNN
>>> network = CNN(
...     in_shape=(1, 16, 16, 16),
...     num_outputs=1,
...     conv_args={"channels": [2, 4]},
... )
>>> network
CNN(
    (convolutions): ConvEncoder(
        (layer0): Convolution(
            (conv): Conv3d(1, 2, kernel_size=(3, 3, 3), stride=(1, 1, 1))
            (adn): ADN(
                (N): InstanceNorm3d(2, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
                (A): PReLU(num_parameters=1)
            )
        )
        (layer1): Convolution(
            (conv): Conv3d(2, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        )
    )
    (mlp): MLP(
        (flatten): Flatten(start_dim=1, end_dim=-1)
        (output): Sequential(
            (linear): Linear(in_features=6912, out_features=1, bias=True)
        )
    )
)
>>> from clinicadl.optim.optimizers.config import RMSpropConfig
>>> optimizer_config = RMSpropConfig(
...     freeze="mlp.output", lr={"convolutions.layer0": 1e-2, "ELSE": 1e-3}
... )
>>> optimizer = optimizer_config.get_object(network)
>>> len(optimizer.param_groups)  # 'convolutions.layer0' and the rest of the network
2
>>> next(network.mlp.output.parameters()).requires_grad  # 'mlp.output' is frozen
False
parameter lr: Union[PositiveFloat, Dict[str, PositiveFloat]] = 0.01
parameter freeze: Union[str, List[str], None] = None
parameter alpha: Union[NonNegativeFloat, Dict[str, NonNegativeFloat]] = 0.99
parameter eps: Union[NonNegativeFloat, Dict[str, NonNegativeFloat]] = 1e-08
parameter weight_decay: Union[NonNegativeFloat, Dict[str, NonNegativeFloat]] = 0
parameter momentum: Union[NonNegativeFloat, Dict[str, NonNegativeFloat]] = 0
parameter centered: Union[bool, Dict[str, bool]] = False
parameter capturable: Union[bool, Dict[str, bool]] = False
parameter foreach: Union[bool, None, Dict[str, Optional[bool]]] = None
parameter maximize: Union[bool, Dict[str, bool]] = False
parameter differentiable: Union[bool, Dict[str, bool]] = False
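To make the roles of alpha, eps, weight_decay, momentum, centered and maximize concrete, here is a scalar, pure-Python sketch of a single RMSprop update, following the algorithm given in the torch.optim.RMSprop documentation. It leaves out capturable, foreach and differentiable, which (per the PyTorch docs) concern how the update is executed rather than its arithmetic. The names rmsprop_step and RMSpropState are hypothetical; the optimizer built by get_object is presumably a real torch.optim.RMSprop operating on tensors.

```python
import math
from dataclasses import dataclass


@dataclass
class RMSpropState:
    square_avg: float = 0.0    # running average of squared gradients
    grad_avg: float = 0.0      # running average of gradients (only used if centered)
    momentum_buf: float = 0.0  # momentum buffer (only used if momentum > 0)


def rmsprop_step(
    param: float,
    grad: float,
    state: RMSpropState,
    lr: float = 0.01,
    alpha: float = 0.99,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    momentum: float = 0.0,
    centered: bool = False,
    maximize: bool = False,
) -> float:
    """Hypothetical scalar RMSprop update; defaults mirror the config defaults above."""
    if maximize:
        grad = -grad  # ascend instead of descend
    if weight_decay != 0:
        grad = grad + weight_decay * param  # L2 penalty added to the gradient
    state.square_avg = alpha * state.square_avg + (1 - alpha) * grad * grad
    if centered:
        # Normalize by an estimate of the gradient variance instead of its second moment.
        state.grad_avg = alpha * state.grad_avg + (1 - alpha) * grad
        denom = math.sqrt(state.square_avg - state.grad_avg ** 2) + eps
    else:
        denom = math.sqrt(state.square_avg) + eps
    if momentum > 0:
        state.momentum_buf = momentum * state.momentum_buf + grad / denom
        return param - lr * state.momentum_buf
    return param - lr * grad / denom


# One step from param=1.0 with grad=0.5 and default settings moves param to about 0.9.
state = RMSpropState()
new_param = rmsprop_step(1.0, 0.5, state)
```

With alpha close to 1, the running average starts near zero, so the first steps are comparatively large; eps only guards against division by zero.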