# Source code for clinicadl.callbacks.implemented.lr_scheduler

from __future__ import annotations

from copy import copy
from logging import getLogger
from typing import TYPE_CHECKING, Any, Mapping, Optional

import pandas as pd
import torch
from pydantic import Field, field_validator, model_validator
from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from typing_extensions import Self

from clinicadl.metrics.handler import MetricsHandler
from clinicadl.optim.lr_schedulers.config import (
    LRSchedulerConfig,
    LRSchedulerType,
)
from clinicadl.optim.lr_schedulers.factory import get_lr_scheduler_from_dict
from clinicadl.optim.lr_schedulers.types import LRSchedulerOrConfig
from clinicadl.utils.config import ObjectConfig, ObjectOrConfig
from clinicadl.utils.dictionary.suffixes import TSV
from clinicadl.utils.dictionary.words import BATCH, EPOCH, NAME, OPTIMIZER
from clinicadl.utils.objects import HasConfig

from ..base import Callback

if TYPE_CHECKING:
    from clinicadl.io.maps import Maps
    from clinicadl.train import TrainerState

logger = getLogger(name=__name__)

# Top-level keys of the mapping returned by ``LRSchedulerCallback.state_dict``.
SCHEDULERS = "schedulers"  # state of the wrapped LR scheduler
LRS = "lrs"  # learning rates recorded so far, keyed by (epoch, batch)


class LRSchedulerCallbackConfig(ObjectConfig["LRSchedulerCallback"]):
    """Configuration object backing :py:class:`LRSchedulerCallback`."""

    scheduler: ObjectOrConfig[LRScheduler, LRSchedulerConfig] = Field(
        json_schema_extra={
            "reader": ObjectOrConfig.build_reader(get_lr_scheduler_from_dict)
        }
    )
    optimizer_name: str
    scheduler_type: Optional[LRSchedulerType]
    metric_name: Optional[str]

    @field_validator("scheduler", mode="before")
    @classmethod
    def _handle_any_value(cls, value: Any) -> ObjectOrConfig:
        """Wraps whatever was passed as ``scheduler`` into an ``ObjectOrConfig``."""
        wrapped = ObjectOrConfig.from_value(value)
        return wrapped

    @model_validator(mode="after")
    def _validate_scheduler_type(self) -> Self:
        """
        Resolves and checks ``scheduler_type``.

        - With a raw ``LRScheduler``, ``scheduler_type`` must be given explicitly.
        - With an ``LRSchedulerConfig``, ``scheduler_type`` is taken from the config,
          overriding any user value.
        - A metric-based scheduler additionally requires ``metric_name``.

        Raises
        ------
        ValueError
            If one of the requirements above is not met.
        """
        inner = self.scheduler.value

        if isinstance(inner, LRScheduler):
            if not self.scheduler_type:
                raise ValueError(
                    "If you pass directly your own LRScheduler, you must specify the type of scheduler via 'scheduler_type'."
                )
        elif isinstance(inner, LRSchedulerConfig):
            # written straight into __dict__, bypassing pydantic's attribute-assignment hooks
            self.__dict__["scheduler_type"] = inner.scheduler_type()

        is_metric_based = self.scheduler_type == LRSchedulerType.METRIC
        if is_metric_based and not self.metric_name:
            raise ValueError(
                f"If scheduler_type='{LRSchedulerType.METRIC.value}', you must "
                "pass the name of the validation metric via 'metric_name'."
            )

        return self

    @classmethod
    def _get_class(cls):
        """Returns the class this config builds."""
        return LRSchedulerCallback


class LRSchedulerCallback(Callback, HasConfig[LRSchedulerCallbackConfig]):
    """
    Learning Rate Scheduler to adjust the learning rate during optimization.

    Parameters
    ----------
    scheduler : Union[torch.optim.lr_scheduler.LRScheduler, LRSchedulerConfig]
        The learning rate scheduler passed as a raw
        :py:class:`torch.optim.lr_scheduler.LRScheduler` or via a
        :py:class:`LRSchedulerConfig <clinicadl.optim.lr_schedulers.config>`.
    optimizer_name : str, default="optimizer"
        The optimizer whose learning rate has to be scheduled. It must be the name of one
        of the optimizers returned by
        :py:meth:`Model.build_optimizers <clinicadl.models.Model.build_optimizers>`.
    scheduler_type : Optional[str | LRSchedulerType], default=None
        The type of LR scheduler, among:

        - ``"epoch-based"``: learning rate is updated at the end of the epoch
          (e.g. :py:class:`~torch.optim.lr_scheduler.LinearLR`);
        - ``"metric-based"``: learning rate is updated after evaluation according to a
          validation metric (e.g. :py:class:`~torch.optim.lr_scheduler.ReduceLROnPlateau`);
        - ``"step-based"``: learning rate is updated after each optimization step
          (e.g. :py:class:`~torch.optim.lr_scheduler.OneCycleLR`).

        **Mandatory if a raw LRScheduler is passed** to ``scheduler``. It will be ignored
        if a :py:class:`LRSchedulerConfig <clinicadl.optim.lr_schedulers.config>` is passed.
    metric_name : Optional[str], default=None
        If ``scheduler_type="metric-based"``, it is the name of the metric to monitor.

    Examples
    --------
    .. code-block::

        from clinicadl.callbacks import LRSchedulerCallback
        from clinicadl.metrics.config import MSEMetricConfig, LossMetricConfig
        from clinicadl.optim.lr_schedulers.config import ReduceLROnPlateauConfig
        from clinicadl.models import SupervisedModel

        model = SupervisedModel(...)  # there is only one optimizer named 'optimizer'

        trainer = Trainer(
            metrics={"loss": LossMetricConfig(), "mse": MSEMetricConfig()},
            callbacks=[
                LRSchedulerCallback(
                    scheduler=ReduceLROnPlateauConfig(mode="min"),
                    metric_name="mse",
                    optimizer_name="optimizer",
                )
            ],
            ...
        )
    """

    _config_type = LRSchedulerCallbackConfig

    def __init__(
        self,
        scheduler: LRSchedulerOrConfig,
        optimizer_name: str = OPTIMIZER,
        scheduler_type: Optional[str | LRSchedulerType] = None,
        metric_name: Optional[str] = None,
    ):
        self.config: LRSchedulerCallbackConfig = self._config_type(
            scheduler=scheduler,
            optimizer_name=optimizer_name,
            scheduler_type=scheduler_type,
            metric_name=metric_name,
        )

        # At most one of the two is set below: a raw scheduler is stored directly, while
        # a config is only turned into a scheduler in ``on_train_start``/``on_resume``,
        # once the optimizer is available.
        self.scheduler_config: Optional[LRSchedulerConfig] = None
        self.scheduler: Optional[LRScheduler] = None
        # state of a raw scheduler at construction, restored at every train start
        self._initial_state: Optional[dict] = None
        # LRs currently in use (one per param group)
        self._current_lrs: Optional[list[float]] = None
        # LRs recorded right before each scheduler step, keyed by (epoch, batch)
        self._lrs: dict[tuple[int, int], list[float]] = {}

        scheduler = self.config.scheduler.value
        if isinstance(scheduler, LRScheduler):
            self.scheduler = scheduler
            self._initial_state = self.scheduler.state_dict()
        elif isinstance(scheduler, LRSchedulerConfig):
            self.scheduler_config = scheduler

    # pylint: disable=arguments-differ, unused-argument
    def on_train_start(
        self,
        *,
        optimizers: dict[str, torch.optim.Optimizer],
        metrics: MetricsHandler,
        **kwargs,
    ) -> None:
        """
        Builds (or resets) the scheduler and checks its consistency.

        Raises
        ------
        KeyError
            If ``optimizer_name`` is not among ``optimizers``.
        ValueError
            If a raw scheduler was passed, and it is not attached to that optimizer.
        """
        optimizer = self._get_optimizer(optimizers)

        if self.scheduler_config is not None:
            self.scheduler = self.scheduler_config.get_object(optimizer)
        else:
            if optimizer is not self.scheduler.optimizer:
                raise ValueError(
                    f"The optimizer associated to the LR scheduler {type(self.scheduler).__name__} is not the same as "
                    f"'{self.config.optimizer_name}' (returned by 'build_optimizers' method of your clinicadl.model.Model)."
                )
            # start every training from the scheduler state given by the user
            self.scheduler.load_state_dict(self._initial_state)

        self._validate_metric_scheduler(metrics)
        self._lrs = {}
        self._current_lrs = self.scheduler.get_last_lr()

    def on_resume(
        self, *, optimizers: dict[str, torch.optim.Optimizer], **kwargs
    ) -> None:
        """
        Rebuilds the scheduler from its config on the (new) optimizer when resuming.

        Raises
        ------
        KeyError
            If ``optimizer_name`` is not among ``optimizers``.
        """
        optimizer = self._get_optimizer(optimizers)
        if self.scheduler_config is not None:
            self.scheduler = self.scheduler_config.get_object(optimizer)

    def on_optimization_step_end(self, state: TrainerState, **kwargs) -> None:
        """Steps a step-based scheduler."""
        if self.config.scheduler_type == LRSchedulerType.STEP:
            self._scheduler_step(state=state)

    def on_validation_end(
        self, *, state: TrainerState, metrics: MetricsHandler, **kwargs
    ) -> None:
        """Steps a metric-based scheduler with the monitored validation metric."""
        if self.config.scheduler_type == LRSchedulerType.METRIC:
            val_metric = metrics.get_metric_value(
                metric=self.config.metric_name,
                epoch=state.current_epoch,
            )
            self._scheduler_step(val_metric, state=state)

    def on_epoch_end(self, state: TrainerState, **kwargs) -> None:
        """Steps an epoch-based scheduler."""
        if self.config.scheduler_type == LRSchedulerType.EPOCH:
            self._scheduler_step(state=state)

    def on_train_end(
        self,
        *,
        maps: Maps,
        state: TrainerState,
        **kwargs,
    ) -> None:
        """
        Saves the learning rate of every (epoch, batch) of training in a TSV file.

        The file is written in the ``learning_rates`` log folder of the current split,
        and is named after ``optimizer_name``.
        """
        # Record the LRs in use since the last scheduler step. If a step was already
        # recorded at this exact (epoch, batch) -- e.g. an epoch-based step at the end of
        # the last epoch -- keep that entry: it holds the LR actually used, whereas
        # ``_current_lrs`` is then the post-step LR that no batch was trained with.
        self._lrs.setdefault(
            (state.current_epoch, state.current_train_batch), self._current_lrs
        )

        # one column per param group (0, 1, ...); the first one is renamed "lr"
        df = pd.DataFrame.from_dict(self._lrs, orient="index").rename(
            columns={0: "lr"}
        )
        df.index = pd.MultiIndex.from_tuples(df.index)
        # one row per (epoch, batch), for epochs 1..current_epoch
        df = df.reindex(
            pd.MultiIndex.from_product(
                [
                    range(1, state.current_epoch + 1),
                    range(1, state.num_train_batches + 1),
                ],
                names=[EPOCH, BATCH],
            ),
        )
        if param_groups := self._get_param_groups(self.scheduler.optimizer):
            df.columns = param_groups
        # each recorded value is the LR used up to its (epoch, batch): fill backward
        df = df.bfill()

        lr_dir = maps.training.splits[state.split_idx].logs.learning_rates
        lr_dir.mkdir(exist_ok=True, parents=True)
        maps.save_file(
            df.reset_index(),
            path=(lr_dir / self.config.optimizer_name).with_suffix(TSV),
            overwrite=True,
        )

    def state_dict(self) -> Mapping[str, Any]:
        """Returns the scheduler state and the learning rates recorded so far."""
        return {SCHEDULERS: self.scheduler.state_dict(), LRS: copy(self._lrs)}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """Restores the scheduler state and the recorded learning rates."""
        self.scheduler.load_state_dict(state_dict[SCHEDULERS])
        # copy, so that later recordings do not mutate the caller's mapping
        self._lrs = dict(state_dict[LRS])
        self._current_lrs = self.scheduler.get_last_lr()

    def _get_optimizer(
        self, optimizers: dict[str, torch.optim.Optimizer]
    ) -> torch.optim.Optimizer:
        """Returns the optimizer named ``optimizer_name``, with an explicit error if absent."""
        try:
            return optimizers[self.config.optimizer_name]
        except KeyError as exc:
            raise KeyError(
                f"In {type(self).__name__}, optimizer_name='{self.config.optimizer_name}' but there is no such optimizer (built with 'build_optimizers' method of your clinicadl.model.Model). "
                f"Optimizers are: {list(optimizers.keys())}"
            ) from exc

    def _scheduler_step(self, *args, state: TrainerState) -> None:
        """Steps the scheduler, recording the LRs that were in use before the step."""
        self.scheduler.step(*args)
        self._lrs[(state.current_epoch, state.current_train_batch)] = self._current_lrs
        self._current_lrs = self.scheduler.get_last_lr()

    def _validate_metric_scheduler(self, metrics: MetricsHandler) -> None:
        """
        Checks consistency if the scheduler is metric-based.

        The monitored metric must exist. For a ``ReduceLROnPlateau``, a warning is logged
        when its ``mode`` differs from the optimum of the monitored metric.
        """
        if self.config.scheduler_type != LRSchedulerType.METRIC:
            return

        metrics.check_metric_name(self.config.metric_name)
        opt = metrics.metrics[self.config.metric_name].optimum
        # ``mode`` is a plain string, so compare it with the optimum's value rather
        # than with the optimum object itself (which is logged via ``.value``).
        opt_value = getattr(opt, "value", opt)
        if isinstance(self.scheduler, ReduceLROnPlateau) and self.scheduler.mode != opt_value:
            logger.warning(
                "Found mode='%s' in ReduceLROnPlateau, but found optimum='%s' in '%s'. "
                "This may be an error.",
                self.scheduler.mode,
                opt_value,
                self.config.metric_name,
            )

    @staticmethod
    def _get_param_groups(optimizer: torch.optim.Optimizer) -> list[str]:
        """
        Returns the names of the optimizer's param groups.

        Returns an empty list if any group has no name, or if there is only one group.
        """
        names = []
        for group in optimizer.param_groups:
            try:
                names.append(group[NAME])
            except KeyError:
                return []
        if len(names) == 1:
            return []
        return names