from __future__ import annotations
from copy import copy
from logging import getLogger
from typing import TYPE_CHECKING, Any, Mapping, Optional
import pandas as pd
import torch
from pydantic import Field, field_validator, model_validator
from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau
from typing_extensions import Self
from clinicadl.metrics.handler import MetricsHandler
from clinicadl.optim.lr_schedulers.config import (
LRSchedulerConfig,
LRSchedulerType,
)
from clinicadl.optim.lr_schedulers.factory import get_lr_scheduler_from_dict
from clinicadl.optim.lr_schedulers.types import LRSchedulerOrConfig
from clinicadl.utils.config import ObjectConfig, ObjectOrConfig
from clinicadl.utils.dictionary.suffixes import TSV
from clinicadl.utils.dictionary.words import BATCH, EPOCH, NAME, OPTIMIZER
from clinicadl.utils.objects import HasConfig
from ..base import Callback
if TYPE_CHECKING:
from clinicadl.io.maps import Maps
from clinicadl.train import TrainerState
logger = getLogger(__name__)

# Keys of the dictionary returned by ``LRSchedulerCallback.state_dict``.
SCHEDULERS, LRS = "schedulers", "lrs"
class LRSchedulerCallbackConfig(ObjectConfig["LRSchedulerCallback"]):
    """
    Config class for ``LRSchedulerCallback``.

    Attributes
    ----------
    scheduler : ObjectOrConfig[LRScheduler, LRSchedulerConfig]
        The scheduler, either a raw :py:class:`torch.optim.lr_scheduler.LRScheduler`
        or an ``LRSchedulerConfig``, wrapped in an ``ObjectOrConfig``.
    optimizer_name : str
        Name of the optimizer whose learning rate is scheduled.
    scheduler_type : Optional[LRSchedulerType]
        When the scheduler steps. Must be given for a raw ``LRScheduler``;
        overwritten with the config's own type for an ``LRSchedulerConfig``.
    metric_name : Optional[str]
        Validation metric to monitor. Required for metric-based schedulers.
    """

    scheduler: ObjectOrConfig[LRScheduler, LRSchedulerConfig] = Field(
        json_schema_extra={
            "reader": ObjectOrConfig.build_reader(get_lr_scheduler_from_dict)
        }
    )
    optimizer_name: str
    scheduler_type: Optional[LRSchedulerType]
    metric_name: Optional[str]

    @field_validator("scheduler", mode="before")
    @classmethod
    def _handle_any_value(cls, v: Any) -> ObjectOrConfig:
        """Wraps the incoming value into an ``ObjectOrConfig`` before validation."""
        wrapped = ObjectOrConfig.from_value(v)
        return wrapped

    @model_validator(mode="after")
    def _validate_scheduler_type(self) -> Self:
        """
        Makes ``scheduler_type`` consistent with ``scheduler`` and checks ``metric_name``.

        Raises
        ------
        ValueError
            If a raw ``LRScheduler`` is given without ``scheduler_type``, or if the
            scheduler is metric-based and ``metric_name`` is missing.
        """
        inner = self.scheduler.value
        is_raw_scheduler = isinstance(inner, LRScheduler)

        if is_raw_scheduler and not self.scheduler_type:
            raise ValueError(
                "If you pass directly your own LRScheduler, you must specify the type of scheduler via 'scheduler_type'."
            )
        elif isinstance(inner, LRSchedulerConfig):
            # The config knows its own type; written through __dict__ so the field
            # is set directly, without going through pydantic's attribute setter.
            self.__dict__["scheduler_type"] = inner.scheduler_type()

        needs_metric = self.scheduler_type == LRSchedulerType.METRIC
        if needs_metric and not self.metric_name:
            raise ValueError(
                f"If scheduler_type='{LRSchedulerType.METRIC.value}', you must "
                "pass the name of the validation metric via 'metric_name'."
            )
        return self

    @classmethod
    def _get_class(cls):
        """Returns the callback class associated with this config."""
        target = LRSchedulerCallback
        return target
class LRSchedulerCallback(Callback, HasConfig[LRSchedulerCallbackConfig]):
    """
    Learning Rate Scheduler to adjust the learning rate during optimization.

    Parameters
    ----------
    scheduler : Union[torch.optim.lr_scheduler.LRScheduler, LRSchedulerConfig]
        The learning rate scheduler passed as a raw :py:class:`torch.optim.lr_scheduler.LRScheduler` or via
        a :py:class:`LRSchedulerConfig <clinicadl.optim.lr_schedulers.config>`.
    optimizer_name : str, default="optimizer"
        The optimizer whose learning rate has to be scheduled. It must be the name of one of the optimizers
        returned by :py:meth:`Model.build_optimizers <clinicadl.models.Model.build_optimizers>`.
    scheduler_type : Optional[str | LRSchedulerType], default=None
        The type of LR scheduler, among:

        - ``"epoch-based"``: learning rate is updated at the end of the epoch (e.g. :py:class:`~torch.optim.lr_scheduler.LinearLR`);
        - ``"metric-based"``: learning rate is updated after evaluation according
          to a validation metric (e.g. :py:class:`~torch.optim.lr_scheduler.ReduceLROnPlateau`);
        - ``"step-based"``: learning rate is updated after each optimization step
          (e.g. :py:class:`~torch.optim.lr_scheduler.OneCycleLR`).

        **Mandatory if a raw LRScheduler is passed** to ``scheduler``. It will be ignored if a
        :py:class:`LRSchedulerConfig <clinicadl.optim.lr_schedulers.config>` is passed.
    metric_name : Optional[str], default=None
        If ``scheduler_type="metric-based"``, it is the name of the validation metric to monitor
        (mandatory in that case).

    Notes
    -----
    The learning rates are recorded each time the scheduler steps. At the end of training, they
    are saved in a TSV file named after ``optimizer_name``, in the ``learning_rates`` log folder
    of the current split, with one row per (epoch, batch).

    Examples
    --------
    .. code-block::

        from clinicadl.callbacks import LRSchedulerCallback
        from clinicadl.metrics.config import MSEMetricConfig, LossMetricConfig
        from clinicadl.optim.lr_schedulers.config import ReduceLROnPlateauConfig
        from clinicadl.models import SupervisedModel

        model = SupervisedModel(...) # there is only one optimizer named 'optimizer'

        trainer = Trainer(
            metrics={"loss": LossMetricConfig(), "mse": MSEMetricConfig()},
            callbacks=[
                LRSchedulerCallback(
                    scheduler=ReduceLROnPlateauConfig(mode="min"),
                    metric_name="mse",
                    optimizer_name="optimizer",
                )
            ],
            ...
        )
    """

    _config_type = LRSchedulerCallbackConfig

    def __init__(
        self,
        scheduler: LRSchedulerOrConfig,
        optimizer_name: str = OPTIMIZER,
        scheduler_type: Optional[str | LRSchedulerType] = None,
        metric_name: Optional[str] = None,
    ):
        self.config: LRSchedulerCallbackConfig = self._config_type(
            scheduler=scheduler,
            optimizer_name=optimizer_name,
            scheduler_type=scheduler_type,
            metric_name=metric_name,
        )
        self.scheduler_config: Optional[LRSchedulerConfig] = None
        self.scheduler: Optional[LRScheduler] = None
        # State of a user-provided scheduler at construction time; it is restored
        # at every training start so that each training begins from the same state.
        self._initial_state: Optional[dict] = None
        # Learning rates currently in use (as returned by ``get_last_lr``).
        self._current_lrs: Optional[list[float]] = None
        # (epoch, batch) -> learning rates that were in use before the scheduler
        # stepped at that (epoch, batch).
        self._lrs: dict[tuple[int, int], list[float]] = {}

        scheduler = self.config.scheduler.value
        if isinstance(scheduler, LRScheduler):
            self.scheduler = scheduler
            self._initial_state = self.scheduler.state_dict()
        elif isinstance(scheduler, LRSchedulerConfig):
            self.scheduler_config = scheduler

    # pylint: disable=arguments-differ, unused-argument
    def on_train_start(
        self,
        *,
        optimizers: dict[str, torch.optim.Optimizer],
        metrics: MetricsHandler,
        **kwargs,
    ) -> None:
        """
        Builds (config) or resets (raw scheduler) the scheduler, then validates the metric setup.

        Raises
        ------
        KeyError
            If ``optimizer_name`` is not a key of ``optimizers``.
        ValueError
            If a raw scheduler was passed and it is not attached to that optimizer.
        """
        optimizer = self._get_optimizer(optimizers)

        if self.scheduler_config is not None:
            self.scheduler = self.scheduler_config.get_object(optimizer)
        else:
            if optimizer is not self.scheduler.optimizer:
                raise ValueError(
                    f"The optimizer associated to the LR scheduler {type(self.scheduler).__name__} is not the same as "
                    f"'{self.config.optimizer_name}' (returned by 'build_optimizers' method of your clinicadl.model.Model)."
                )
            self.scheduler.load_state_dict(self._initial_state)

        self._validate_metric_scheduler(metrics)
        self._lrs = {}
        self._current_lrs = self.scheduler.get_last_lr()

    def on_resume(
        self, *, optimizers: dict[str, torch.optim.Optimizer], **kwargs
    ) -> None:
        """
        Rebuilds the scheduler on the targeted optimizer when it was passed as a config.

        A raw scheduler is left untouched here. NOTE(review): its state is presumably
        restored afterwards through ``load_state_dict``; confirm with the Trainer.

        Raises
        ------
        KeyError
            If ``optimizer_name`` is not a key of ``optimizers``.
        """
        optimizer = self._get_optimizer(optimizers)
        if self.scheduler_config is not None:
            self.scheduler = self.scheduler_config.get_object(optimizer)

    def on_optimization_step_end(self, state: TrainerState, **kwargs) -> None:
        """Steps a step-based scheduler after each optimization step."""
        if self.config.scheduler_type == LRSchedulerType.STEP:
            self._scheduler_step(state=state)

    def on_validation_end(
        self, *, state: TrainerState, metrics: MetricsHandler, **kwargs
    ) -> None:
        """Steps a metric-based scheduler with the monitored metric value of the current epoch."""
        if self.config.scheduler_type == LRSchedulerType.METRIC:
            val_metric = metrics.get_metric_value(
                metric=self.config.metric_name,
                epoch=state.current_epoch,
            )
            self._scheduler_step(val_metric, state=state)

    def on_epoch_end(self, state: TrainerState, **kwargs) -> None:
        """Steps an epoch-based scheduler at the end of each epoch."""
        if self.config.scheduler_type == LRSchedulerType.EPOCH:
            self._scheduler_step(state=state)

    def on_train_end(
        self,
        *,
        maps: Maps,
        state: TrainerState,
        **kwargs,
    ) -> None:
        """Saves the history of learning rates, one row per (epoch, batch), in a TSV file."""
        self._record_current_lrs(state)

        # Columns are positional (0, 1, ...); column 0 is named "lr", which is the
        # only column when the optimizer has a single parameter group.
        df = pd.DataFrame.from_dict(self._lrs, orient="index").rename(
            columns={0: "lr"}
        )
        df.index = pd.MultiIndex.from_tuples(df.index)
        df = df.reindex(
            pd.MultiIndex.from_product(
                [
                    range(1, state.current_epoch + 1),
                    range(1, state.num_train_batches + 1),
                ],
                names=[EPOCH, BATCH],
            ),
        )
        if param_groups := self._get_param_groups(self.scheduler.optimizer):
            df.columns = param_groups
        # A value recorded at (epoch, batch) was in use for every step since the
        # previous record, hence the backward fill.
        df = df.bfill()

        lr_dir = maps.training.splits[state.split_idx].logs.learning_rates
        lr_dir.mkdir(exist_ok=True, parents=True)
        maps.save_file(
            df.reset_index(),
            path=(lr_dir / self.config.optimizer_name).with_suffix(TSV),
            overwrite=True,
        )

    def state_dict(self) -> Mapping[str, Any]:
        """Returns the scheduler's state and a copy of the recorded learning rates."""
        return {SCHEDULERS: self.scheduler.state_dict(), LRS: copy(self._lrs)}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """Restores the scheduler's state and the recorded learning rates."""
        self.scheduler.load_state_dict(state_dict[SCHEDULERS])
        # Copy so that later scheduler steps do not mutate the caller's checkpoint.
        self._lrs = copy(state_dict[LRS])
        self._current_lrs = self.scheduler.get_last_lr()

    def _get_optimizer(
        self, optimizers: dict[str, torch.optim.Optimizer]
    ) -> torch.optim.Optimizer:
        """Returns the optimizer named ``optimizer_name``, with an informative error if missing."""
        try:
            return optimizers[self.config.optimizer_name]
        except KeyError as exc:
            raise KeyError(
                f"In {type(self).__name__}, optimizer_name='{self.config.optimizer_name}' but there is no such optimizer (built with 'build_optimizers' method of your clinicadl.model.Model). "
                f"Optimizers are: {list(optimizers.keys())}"
            ) from exc

    def _scheduler_step(self, *args, state: TrainerState) -> None:
        """Steps the scheduler (forwarding ``args``, e.g. a metric value) and records the previous LRs."""
        self.scheduler.step(*args)
        self._record_current_lrs(state)
        self._current_lrs = self.scheduler.get_last_lr()

    def _record_current_lrs(self, state: TrainerState) -> None:
        """Stores the learning rates in use under the current (epoch, batch) key."""
        self._lrs[(state.current_epoch, state.current_train_batch)] = self._current_lrs

    def _validate_metric_scheduler(self, metrics: MetricsHandler) -> None:
        """
        Checks consistency of a metric-based scheduler.

        For a metric-based scheduler, the metric name is checked through
        ``metrics.check_metric_name``. For ``ReduceLROnPlateau``, a warning is logged
        if its ``mode`` differs from the optimum of the monitored metric.
        """
        if self.config.scheduler_type != LRSchedulerType.METRIC:
            return

        metrics.check_metric_name(self.config.metric_name)
        optimum = metrics.metrics[self.config.metric_name].optimum
        # ReduceLROnPlateau.mode is a plain str ('min'/'max'), so compare it with the
        # optimum's value when the optimum is an Enum member.
        optimum_value = getattr(optimum, "value", optimum)

        if (
            isinstance(self.scheduler, ReduceLROnPlateau)
            and self.scheduler.mode != optimum_value
        ):
            logger.warning(
                "Found mode='%s' in ReduceLROnPlateau, but found optimum='%s' in '%s'. "
                "This may be an error.",
                self.scheduler.mode,
                optimum_value,
                self.config.metric_name,
            )

    @staticmethod
    def _get_param_groups(optimizer: torch.optim.Optimizer) -> list[str]:
        """
        Retrieves the names of all parameter groups.

        Returns an empty list if any group has no name, or if there are fewer than
        two groups.
        """
        groups = optimizer.param_groups
        if not all(NAME in group for group in groups):
            return []
        names = [group[NAME] for group in groups]
        return names if len(names) > 1 else []