clinicadl.callbacks.LRSchedulerCallback
- class clinicadl.callbacks.LRSchedulerCallback(scheduler: LRScheduler | LRSchedulerConfig, optimizer_name: str = 'optimizer', scheduler_type: str | LRSchedulerType | None = None, metric_name: str | None = None)
Callback that uses a learning rate scheduler to adjust the learning rate of an optimizer during training.
- Parameters:
scheduler (Union[torch.optim.lr_scheduler.LRScheduler, LRSchedulerConfig]) – The learning rate scheduler, passed either as a raw torch.optim.lr_scheduler.LRScheduler or via an LRSchedulerConfig.
optimizer_name (str, default="optimizer") – The name of the optimizer whose learning rate has to be scheduled. It must be the name of one of the optimizers returned by Model.build_optimizers.
scheduler_type (Optional[str | LRSchedulerType], default=None) – The type of LR scheduler, among:
  "epoch-based": the learning rate is updated at the end of each epoch (e.g. LinearLR);
  "metric-based": the learning rate is updated after each validation, according to a validation metric (e.g. ReduceLROnPlateau);
  "step-based": the learning rate is updated after each optimization step (e.g. OneCycleLR).
Mandatory if a raw LRScheduler is passed to scheduler. It is ignored if an LRSchedulerConfig is passed.
metric_name (Optional[str], default=None) – When the scheduler is metric-based, the name of the metric to monitor.
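The rules for scheduler_type and metric_name can be summarized as a small resolution function. The sketch below is only an illustration of those rules, not clinicadl's implementation: SchedulerType and HypotheticalSchedulerConfig are hypothetical stand-ins for LRSchedulerType and LRSchedulerConfig, and the check that metric_name is provided for metric-based schedulers is an assumption.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union


class SchedulerType(str, Enum):
    # Hypothetical stand-in for clinicadl's LRSchedulerType.
    EPOCH = "epoch-based"
    METRIC = "metric-based"
    STEP = "step-based"


@dataclass
class HypotheticalSchedulerConfig:
    # Hypothetical stand-in for an LRSchedulerConfig: a config knows its own type.
    scheduler_type: SchedulerType


def resolve_scheduler_type(
    scheduler: object,
    scheduler_type: Optional[Union[str, SchedulerType]] = None,
    metric_name: Optional[str] = None,
) -> SchedulerType:
    """Return the effective scheduler type, following the documented rules."""
    if isinstance(scheduler, HypotheticalSchedulerConfig):
        # A config carries its own type, so a user-provided value is ignored.
        resolved = scheduler.scheduler_type
    elif scheduler_type is None:
        # A raw scheduler gives no hint about when it should be stepped.
        raise ValueError("scheduler_type is mandatory when passing a raw LRScheduler.")
    else:
        # Accepts either an enum member or its string value.
        resolved = SchedulerType(scheduler_type)
    # Assumption: a metric-based scheduler cannot work without a metric to monitor.
    if resolved is SchedulerType.METRIC and metric_name is None:
        raise ValueError("metric_name is required for a metric-based scheduler.")
    return resolved
```

For example, resolve_scheduler_type(object(), "step-based") returns SchedulerType.STEP, while calling it with a raw scheduler and no scheduler_type raises a ValueError.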
Examples
from clinicadl.callbacks import LRSchedulerCallback
from clinicadl.metrics.config import MSEMetricConfig, LossMetricConfig
from clinicadl.optim.lr_schedulers.config import ReduceLROnPlateauConfig
from clinicadl.models import SupervisedModel

model = SupervisedModel(...)  # there is only one optimizer named 'optimizer'

trainer = Trainer(
    metrics={"loss": LossMetricConfig(), "mse": MSEMetricConfig()},
    callbacks=[
        LRSchedulerCallback(
            scheduler=ReduceLROnPlateauConfig(mode="min"),
            metric_name="mse",
            optimizer_name="optimizer",
        )
    ],
    ...
)
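The three scheduler types differ in which training hook triggers the scheduler update. The following sketch illustrates that dispatch with a stub scheduler that only records its step() calls; SchedulerDispatchSketch, RecordingScheduler and simulate are hypothetical names, and running exactly one validation per epoch is an assumption made for the simulation.

```python
from typing import Dict, List, Optional


class RecordingScheduler:
    """Stub standing in for a torch LR scheduler: it only records step() calls."""

    def __init__(self) -> None:
        self.calls: List[Optional[float]] = []

    def step(self, metric: Optional[float] = None) -> None:
        self.calls.append(metric)


class SchedulerDispatchSketch:
    """Hypothetical sketch of which training hook calls scheduler.step()."""

    def __init__(self, scheduler, scheduler_type: str, metric_name: Optional[str] = None):
        self.scheduler = scheduler
        self.scheduler_type = scheduler_type
        self.metric_name = metric_name

    def on_optimization_step_end(self) -> None:
        if self.scheduler_type == "step-based":
            self.scheduler.step()

    def on_validation_end(self, metrics: Dict[str, float]) -> None:
        if self.scheduler_type == "metric-based":
            # ReduceLROnPlateau-like schedulers receive the monitored value.
            self.scheduler.step(metrics[self.metric_name])

    def on_epoch_end(self) -> None:
        if self.scheduler_type == "epoch-based":
            self.scheduler.step()


def simulate(scheduler_type: str, epochs: int = 3, steps_per_epoch: int = 4) -> int:
    """Run a fake training loop and return how many times step() was called."""
    sched = RecordingScheduler()
    cb = SchedulerDispatchSketch(sched, scheduler_type, metric_name="mse")
    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            cb.on_optimization_step_end()
        # Assumption: one validation loop per epoch, with a fake decreasing MSE.
        cb.on_validation_end({"mse": 1.0 / (epoch + 1)})
        cb.on_epoch_end()
    return len(sched.calls)
```

With 3 epochs of 4 optimization steps each, a step-based scheduler is stepped 12 times, while epoch-based and metric-based schedulers are stepped 3 times each.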
- on_train_start(*, optimizers: dict[str, Optimizer], metrics: MetricsHandler, **kwargs) -> None
Called once at the beginning of Trainer.train if resume=False. If a training is being resumed, on_resume() is called instead.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current optimizers (torch.optim.Optimizer instances), as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
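Since on_train_start receives the dictionary of optimizers and the metrics handler, it is a natural place to pick the optimizer named by optimizer_name and to check that the monitored metric is actually computed. The sketch below assumes this is what happens; select_optimizer and check_monitored_metric are hypothetical helpers, and because the MetricsHandler API is not documented here, metrics are represented as a plain iterable of names.

```python
from typing import Dict, Iterable


def select_optimizer(optimizers: Dict[str, object], optimizer_name: str) -> object:
    """Pick the optimizer whose LR will be scheduled, with an explicit error message."""
    try:
        return optimizers[optimizer_name]
    except KeyError:
        available = ", ".join(sorted(optimizers))
        raise ValueError(
            f"No optimizer named {optimizer_name!r}; available optimizers: {available}."
        ) from None


def check_monitored_metric(metric_names: Iterable[str], metric_name: str) -> None:
    """Fail early if the metric to monitor is not among the computed metrics."""
    names = set(metric_names)
    if metric_name not in names:
        raise ValueError(f"Metric {metric_name!r} is not computed; got {sorted(names)}.")
```

Failing at the start of training, rather than at the first validation, avoids wasting a full epoch on a misconfigured callback.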
- on_resume(*, optimizers: dict[str, Optimizer], **kwargs) -> None
Called once when Trainer.train is resuming a training. More precisely, this method is called just before the checkpoints are loaded.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
split (Split) – The clinicadl.split.Split on which training is performed.
optimizers (dict[str, torch.optim.Optimizer]) – The current optimizers (torch.optim.Optimizer instances), as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
optimization (OptimizationConfig) – The optimization specifications of the training phase.
metrics (MetricsHandler) – The validation metrics to compute.
callbacks (CallbacksHandler) – The callbacks passed to the Trainer.
computational (ComputationalConfig) – The clinicadl.train.ComputationalConfig defining the computational specifications of the training phase.
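One plausible reason for calling on_resume before the checkpoints are loaded (an assumption, not stated above) is that the scheduler object must already exist for its saved state to be restored into it. The sketch mimics the state_dict / load_state_dict protocol of torch schedulers with a hypothetical StubScheduler using exponential decay; the checkpoint key "lr_scheduler" is also hypothetical.

```python
from typing import Any, Dict


class StubScheduler:
    """Minimal stand-in for a torch LR scheduler's state_dict protocol."""

    def __init__(self, base_lr: float, gamma: float) -> None:
        self.base_lr = base_lr
        self.gamma = gamma
        self.last_epoch = 0

    def step(self) -> None:
        self.last_epoch += 1

    def current_lr(self) -> float:
        # Exponential decay, for illustration only.
        return self.base_lr * self.gamma ** self.last_epoch

    def state_dict(self) -> Dict[str, Any]:
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.last_epoch = state["last_epoch"]


# First run: 5 scheduler steps, then a checkpoint is saved.
first = StubScheduler(base_lr=0.1, gamma=0.5)
for _ in range(5):
    first.step()
checkpoint = {"lr_scheduler": first.state_dict()}

# Resumed run: the scheduler is rebuilt first (the on_resume step),
# then its state is restored from the checkpoint.
resumed = StubScheduler(base_lr=0.1, gamma=0.5)
resumed.load_state_dict(checkpoint["lr_scheduler"])
```

After loading, the resumed scheduler continues from the same position in the schedule instead of restarting from the initial learning rate.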
- on_optimization_step_end(state: TrainerState, **kwargs) -> None
Called every time Model.optimization_step has just been called in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
optimizers (dict[str, torch.optim.Optimizer]) – The current optimizers (torch.optim.Optimizer instances), as returned by Model.backward_step.
grad_scaler (torch.amp.GradScaler) – The torch.amp.GradScaler used to scale gradients.
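A step-based scheduler is defined over the total number of optimization steps, so its length is the number of epochs times the number of steps per epoch. The function below is a simplified warmup-then-decay schedule with linear ramps, written only to illustrate per-step updates; it is not torch's OneCycleLR (which, among other differences, anneals with a cosine by default), and triangular_lr and its parameters are hypothetical.

```python
def triangular_lr(
    step: int,
    total_steps: int,
    max_lr: float,
    warmup_frac: float = 0.3,
    start_lr: float = 0.0,
) -> float:
    """LR at a given optimization step for a simplified warmup-then-decay schedule."""
    if not 0 <= step < total_steps:
        raise ValueError("step must be in [0, total_steps)")
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        # Linear warmup from start_lr towards max_lr.
        return start_lr + (max_lr - start_lr) * step / warmup_steps
    # Linear decay from max_lr back to start_lr, reached at the last step.
    decay_steps = total_steps - warmup_steps
    return max_lr - (max_lr - start_lr) * (step - warmup_steps) / max(1, decay_steps - 1)


# 3 epochs of 4 optimization steps: the schedule spans 12 scheduler updates.
total = 3 * 4
lrs = [triangular_lr(s, total, max_lr=0.01) for s in range(total)]
```

Here the learning rate starts at 0, peaks at 0.01 at step 3 and returns to 0 at step 11; an epoch-based scheduler would only see 3 updates over the same run.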
- on_validation_end(*, state: TrainerState, metrics: MetricsHandler, **kwargs) -> None
Called at the end of every validation loop in Trainer.train. Not to be confused with on_validate_end().
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
metrics (MetricsHandler) – The computed validation metrics.
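For a metric-based scheduler such as ReduceLROnPlateau, the value of the monitored metric is what drives the update at the end of validation. The class below is a simplified sketch of plateau logic for mode="min": it omits the threshold and cooldown options of the real torch scheduler, and PlateauSketch is a hypothetical name.

```python
class PlateauSketch:
    """Simplified ReduceLROnPlateau-style logic for mode="min" (no threshold, no cooldown)."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 2, min_lr: float = 0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")
        self.num_bad = 0  # consecutive validations without improvement

    def step(self, metric: float) -> float:
        if metric < self.best:
            # Improvement: remember the new best value and reset the counter.
            self.best = metric
            self.num_bad = 0
        else:
            self.num_bad += 1
        if self.num_bad > self.patience:
            # Plateau detected: reduce the LR, never below min_lr.
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad = 0
        return self.lr


# Validation MSE over 5 epochs, monitored as in the example above (metric_name="mse").
plateau = PlateauSketch(lr=1.0, factor=0.5, patience=1)
lrs = [plateau.step(mse) for mse in [1.0, 0.9, 0.95, 0.95, 0.8]]
```

The learning rate stays at 1.0 until two consecutive validations fail to improve on 0.9, after which it is halved to 0.5.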
- on_epoch_end(state: TrainerState, **kwargs) -> None
Called at the end of an epoch in Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.
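For an epoch-based scheduler such as LinearLR, each call at the end of an epoch moves the learning rate one step along a fixed schedule. The function below sketches the linear interpolation of the multiplicative factor that LinearLR applies to the base learning rate; the default values are meant to mirror torch's LinearLR defaults, and linear_lr itself is a hypothetical helper.

```python
def linear_lr(
    base_lr: float,
    epoch: int,
    start_factor: float = 1.0 / 3,
    end_factor: float = 1.0,
    total_iters: int = 5,
) -> float:
    """LR after `epoch` epoch-level scheduler steps, LinearLR-style."""
    # The factor moves linearly from start_factor to end_factor, then stays constant.
    progress = min(epoch, total_iters) / total_iters
    return base_lr * (start_factor + (end_factor - start_factor) * progress)


lrs = [linear_lr(0.3, epoch) for epoch in range(8)]
```

With a base learning rate of 0.3, the schedule starts at 0.1, reaches 0.3 after 5 epochs and remains there afterwards.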
- on_train_end(*, maps: Maps, state: TrainerState, **kwargs) -> None
Called once at the end of Trainer.train.
- Parameters:
model (Model) – The model associated with the Trainer.
state (TrainerState) – The current state of the Trainer.