clinicadl.callbacks.LRSchedulerCallback

class clinicadl.callbacks.LRSchedulerCallback(scheduler: LRScheduler | LRSchedulerConfig, optimizer_name: str = 'optimizer', scheduler_type: str | LRSchedulerType | None = None, metric_name: str | None = None)

Callback that applies a learning rate scheduler to adjust the learning rate during optimization.

Parameters:
  • scheduler (Union[torch.optim.lr_scheduler.LRScheduler, LRSchedulerConfig]) – The learning rate scheduler, passed either as a raw torch.optim.lr_scheduler.LRScheduler or via an LRSchedulerConfig.

  • optimizer_name (str, default="optimizer") – The name of the optimizer whose learning rate is to be scheduled. It must be the name of one of the optimizers returned by Model.build_optimizers.

  • scheduler_type (Optional[str | LRSchedulerType], default=None) –

    The type of LR scheduler, among:

    • "epoch-based": learning rate is updated at the end of the epoch (e.g. LinearLR);

    • "metric-based": learning rate is updated after evaluation according to a validation metric (e.g. ReduceLROnPlateau);

    • "step-based": learning rate is updated after each optimization step (e.g. OneCycleLR).

    Mandatory if a raw LRScheduler is passed to scheduler. It is ignored if an LRSchedulerConfig is passed.

  • metric_name (Optional[str], default=None) – The name of the metric to monitor when scheduler_type="metric-based".
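
The three scheduler types differ in where in the training loop the scheduler's step() is called. The following is a minimal sketch of that timing, not ClinicaDL's implementation: RecordingScheduler and run_toy_training are hypothetical stand-ins, and the loop structure (optimization steps inside epochs, one validation per epoch) is assumed for illustration only.

```python
class RecordingScheduler:
    """Hypothetical stand-in for a scheduler: records each call to step()."""

    def __init__(self):
        self.calls = []

    def step(self, metric=None):
        self.calls.append(metric)


def run_toy_training(scheduler, scheduler_type, n_epochs=2, n_batches=3):
    """Toy loop showing where step() would be called for each scheduler type."""
    for epoch in range(n_epochs):
        for _ in range(n_batches):
            # ... an optimization step would happen here ...
            if scheduler_type == "step-based":
                scheduler.step()
        # ... validation would happen here; fake a decreasing metric ...
        val_metric = 1.0 / (epoch + 1)
        if scheduler_type == "metric-based":
            scheduler.step(val_metric)
        if scheduler_type == "epoch-based":
            scheduler.step()
    return scheduler.calls


step_calls = run_toy_training(RecordingScheduler(), "step-based")
metric_calls = run_toy_training(RecordingScheduler(), "metric-based")
epoch_calls = run_toy_training(RecordingScheduler(), "epoch-based")

print(len(step_calls))   # one call per optimization step
print(metric_calls)      # one call per validation, carrying the metric value
print(len(epoch_calls))  # one call per epoch
```

Only the metric-based case passes a value to step(), which is why metric_name is needed for that type.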

Examples

from clinicadl.callbacks import LRSchedulerCallback
from clinicadl.metrics.config import MSEMetricConfig, LossMetricConfig
from clinicadl.optim.lr_schedulers.config import ReduceLROnPlateauConfig
from clinicadl.models import SupervisedModel

model = SupervisedModel(...)    # there is only one optimizer named 'optimizer'

trainer = Trainer(
    metrics={"loss": LossMetricConfig(), "mse": MSEMetricConfig()},
    callbacks=[
        LRSchedulerCallback(
            scheduler=ReduceLROnPlateauConfig(mode="min"),
            metric_name="mse",
            optimizer_name="optimizer",
        )
    ],
    ...
)

on_train_start(*, optimizers: dict[str, Optimizer], metrics: MetricsHandler, **kwargs) -> None

Called once at the beginning of Trainer.train if resume=False.

If training is being resumed, on_resume() is called instead.

Parameters:

on_resume(*, optimizers: dict[str, Optimizer], **kwargs) -> None

Called once when Trainer.train is resuming a training.

More precisely, this method is called just before the checkpoints are loaded.

Parameters:

on_optimization_step_end(state: TrainerState, **kwargs) -> None

Called every time Model.optimization_step has just been called in Trainer.train.

Parameters:

on_validation_end(*, state: TrainerState, metrics: MetricsHandler, **kwargs) -> None

Called at the end of every validation loop in Trainer.train.

Not to be confused with on_validate_end().

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • metrics (MetricsHandler) – The validation metrics computed.
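
For a metric-based scheduler, the validation hook is where the monitored metric reaches the scheduler. The sketch below illustrates the idea under loud assumptions: ToyReduceOnPlateau is a simplified plateau rule for mode="min" (not torch's ReduceLROnPlateau), on_validation_end_sketch is a hypothetical hook body, and the metrics are represented as a plain dict rather than a MetricsHandler.

```python
class ToyReduceOnPlateau:
    """Simplified, hypothetical plateau rule (mode="min"); not torch's implementation."""

    def __init__(self, lr, factor=0.1, patience=1):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_evals = 0

    def step(self, metric):
        if metric < self.best:
            # Improvement: remember it and reset the counter.
            self.best = metric
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        if self.bad_evals > self.patience:
            # No improvement for more than `patience` evaluations: shrink the LR.
            self.lr *= self.factor
            self.bad_evals = 0


def on_validation_end_sketch(scheduler, metrics, metric_name):
    """Hypothetical hook body: read the monitored metric and pass it to the scheduler."""
    if metric_name not in metrics:
        raise KeyError(f"Unknown metric: {metric_name!r}")
    scheduler.step(metrics[metric_name])


scheduler = ToyReduceOnPlateau(lr=1.0, factor=0.5, patience=1)
for mse in [0.9, 0.8, 0.85, 0.82, 0.7]:
    on_validation_end_sketch(scheduler, {"loss": mse, "mse": mse}, metric_name="mse")

print(scheduler.lr)  # halved once, after two evaluations without improvement
```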

on_epoch_end(state: TrainerState, **kwargs) -> None

Called at the end of an epoch in Trainer.train.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

on_train_end(*, maps: Maps, state: TrainerState, **kwargs) -> None

Called once at the end of Trainer.train.

Parameters:
  • model (Model) – The model associated to the Trainer.

  • maps (Maps) – The MAPS associated to the Trainer.

  • state (TrainerState) – The current state of the Trainer.

state_dict() -> Mapping[str, Any]

Returns a checkpoint of the current state of the callback.

Returns:

Mapping[str, Any] – The current state in a dict.

load_state_dict(state_dict: Mapping[str, Any]) -> None

Sets the callback to a given state.

Parameters:

state_dict (Mapping[str, Any]) – The desired state of the Callback, as returned by state_dict().
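
The state_dict() / load_state_dict() pair follows the usual checkpoint round-trip pattern. The sketch below shows that pattern with hypothetical stand-ins (ToyScheduler, ToyCallback); the keys stored in the dict are assumptions for illustration and may not match what LRSchedulerCallback actually saves.

```python
from typing import Any, Mapping


class ToyScheduler:
    """Hypothetical scheduler whose only state is a step counter."""

    def __init__(self):
        self.last_epoch = 0

    def step(self):
        self.last_epoch += 1

    def state_dict(self) -> Mapping[str, Any]:
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        self.last_epoch = state_dict["last_epoch"]


class ToyCallback:
    """Hypothetical callback that checkpoints the scheduler it wraps."""

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def state_dict(self) -> Mapping[str, Any]:
        return {"scheduler": self.scheduler.state_dict()}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        self.scheduler.load_state_dict(state_dict["scheduler"])


# Advance the original callback, then take a checkpoint.
original = ToyCallback(ToyScheduler())
for _ in range(3):
    original.scheduler.step()
checkpoint = original.state_dict()

# A fresh callback restored from the checkpoint resumes from the same state.
restored = ToyCallback(ToyScheduler())
restored.load_state_dict(checkpoint)
print(restored.scheduler.last_epoch)
```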