# Source code for clinicadl.callbacks.implemented.training_checkpoint

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

import pandas as pd
import torch
from pydantic import NonNegativeInt

from clinicadl.utils.config import ObjectConfig
from clinicadl.utils.dictionary.suffixes import PT
from clinicadl.utils.dictionary.utils import TSV_SEP
from clinicadl.utils.enum import TrainerCall
from clinicadl.utils.names import camel_to_snake
from clinicadl.utils.objects import HasConfig

from ..base import Callback

if TYPE_CHECKING:
    from clinicadl.io.maps import Maps
    from clinicadl.io.maps.training.splits.tmp import EpochTmpDir
    from clinicadl.metrics import MetricsHandler
    from clinicadl.models import Model
    from clinicadl.train import TrainerState

    from ..handler import CallbacksHandler

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class TrainingCheckpointCallbackConfig(ObjectConfig["TrainingCheckpointCallback"]):
    """
    Config class for ``TrainingCheckpointCallback``.

    It stores the two arguments accepted by
    ``TrainingCheckpointCallback.__init__``.
    """

    # Interval (in epochs) between two checkpoints; must be a non-negative integer.
    every_n_epochs: NonNegativeInt
    # Whether checkpointing is activated.
    enabled: bool

    @classmethod
    def _get_class(cls) -> type[TrainingCheckpointCallback]:
        """Returns the class associated with this config."""
        return TrainingCheckpointCallback


class TrainingCheckpointCallback(Callback, HasConfig[TrainingCheckpointCallbackConfig]):
    """
    To save checkpoints of a training phase.

    The user can then resume a training from the last saved checkpoint when calling
    :py:meth:`Trainer.resume <clinicadl.train.Trainer.resume>`.

    The checkpoints will be **deleted when the training is completed**. To save
    permanently checkpoints of your neural network, use instead
    :py:class:`~clinicadl.callbacks.ModelCheckpointCallback`.

    Parameters
    ----------
    every_n_epochs : int, default=10
        Interval (in epochs) for saving checkpoints. With ``0``, no periodic
        checkpoint is saved.
    enabled : bool, default=True
        Whether to activate checkpointing.
    """

    _config_type = TrainingCheckpointCallbackConfig

    def __init__(self, every_n_epochs: int = 10, enabled: bool = True):
        self.config = self._config_type(every_n_epochs=every_n_epochs, enabled=enabled)
        # Training objects whose state is checkpointed; they are set in
        # ``on_train_start`` (also called by ``on_resume``).
        self._metrics = None
        self._callbacks = None
        self._optimizers = None
        self._scaler = None

    def on_train_start(
        self,
        *,
        metrics: MetricsHandler,
        callbacks: CallbacksHandler,
        optimizers: dict[str, torch.optim.Optimizer],
        grad_scaler: torch.amp.GradScaler,
        **kwargs,
    ) -> None:
        """
        Keeps references to the objects that will be saved in ``on_epoch_end``.
        """
        self._metrics = metrics
        self._callbacks = callbacks
        self._optimizers = optimizers
        self._scaler = grad_scaler

    def on_exception(self, *, maps: Maps, state: TrainerState, **kwargs) -> None:
        """
        Logs the epoch of the last saved checkpoint, if any.

        Nothing is logged if the trainer was not called for training, if
        checkpointing is disabled, or if no checkpoint is found.
        """
        if state.called != TrainerCall.TRAIN or not self.config.enabled:
            return

        try:
            last_saved_epoch = self._get_last_saved_epoch(maps, state.split_idx)
        except FileNotFoundError:
            # no checkpoint saved yet: nothing to report
            return

        logger.error("Last checkpoint at the end of epoch %d", last_saved_epoch)

    def on_resume(
        self,
        *,
        model: Model,
        maps: Maps,
        state: TrainerState,
        metrics: MetricsHandler,
        callbacks: CallbacksHandler,
        optimizers: dict[str, torch.optim.Optimizer],
        grad_scaler: torch.amp.GradScaler,
        **kwargs,
    ) -> None:
        """
        Restores the training objects from the last saved checkpoint.

        Raises
        ------
        FileNotFoundError
            If no checkpoint is found for the current split.
        """
        self.on_train_start(
            metrics=metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            grad_scaler=grad_scaler,
        )

        tmp_dir = maps.training.splits[state.split_idx].tmp
        last_saved_epoch = self._get_last_saved_epoch(maps, split_idx=state.split_idx)
        logger.info("Loading checkpoints from epoch %d", last_saved_epoch)
        chkpt_dir = tmp_dir.epochs[last_saved_epoch]

        state.load_state_dict(maps.open_file(chkpt_dir.state_json))
        model.load_state_dict(maps.open_file(chkpt_dir.model_pt))

        opt_state_dicts = maps.open_file(chkpt_dir.optimizer_pt)
        for opt_name, opt in optimizers.items():
            opt.load_state_dict(opt_state_dicts[opt_name])

        grad_scaler.load_state_dict(maps.open_file(chkpt_dir.scaler_pt))

        # metrics that are not columns of the saved file are removed from the
        # handler before loading
        computed_metrics = pd.read_csv(
            chkpt_dir.validation_metrics.aggregated_tsv, sep=TSV_SEP
        ).columns
        metrics.remove_metrics(set(metrics.metrics.keys()).difference(computed_metrics))
        metrics.load(
            chkpt_dir.validation_metrics.aggregated_tsv,
            details_path=chkpt_dir.validation_metrics.details_tsv,
        )

        self._load_callbacks(callbacks, chkpt_dir=chkpt_dir, maps=maps)

    def on_epoch_end(self, *, model: Model, maps: Maps, state: TrainerState) -> None:
        """
        Saves a checkpoint every ``every_n_epochs`` epochs, and removes the
        checkpoints of the other epochs.
        """
        every_n_epochs = self.config.every_n_epochs
        # ``every_n_epochs`` is only validated as non-negative: 0 is handled
        # explicitly (no periodic checkpoint) to avoid a modulo by zero.
        if (
            not self.config.enabled
            or every_n_epochs == 0
            or state.current_epoch % every_n_epochs != 0
        ):
            return

        tmp_dir = maps.training.splits[state.split_idx].tmp
        tmp_dir.create_epoch(state.current_epoch, overwrite=True)
        chkpt_dir = tmp_dir.epochs[state.current_epoch]

        state.to_json(chkpt_dir.state_json)
        maps.save_file(model.state_dict(), chkpt_dir.model_pt)
        maps.save_file(
            {opt_name: opt.state_dict() for opt_name, opt in self._optimizers.items()},
            chkpt_dir.optimizer_pt,
        )
        maps.save_file(self._scaler.state_dict(), chkpt_dir.scaler_pt)
        self._metrics.save(
            chkpt_dir.validation_metrics.aggregated_tsv,
            details_path=chkpt_dir.validation_metrics.details_tsv,
        )
        self._save_callbacks(chkpt_dir, maps=maps)

        # done last, so that the previous checkpoint is only removed once the
        # new one has been written
        tmp_dir.clear(except_epoch=state.current_epoch)

        logger.debug("Training checkpoint saved after epoch %d", state.current_epoch)

    def on_train_end(
        self,
        *,
        maps: Maps,
        state: TrainerState,
        **kwargs,
    ) -> None:
        """
        Clears the checkpoint directory of the current split, if it exists.
        """
        tmp_dir = maps.training.splits[state.split_idx].tmp
        if tmp_dir.path.exists():
            tmp_dir.clear()

    def _save_callbacks(self, chkpt_dir: EpochTmpDir, maps: Maps) -> None:
        """
        Saves the callback checkpoints.
        """
        all_callbacks = list(self._callbacks.all_callbacks)
        file_names = self._get_callback_file_names(all_callbacks)
        for callback, file_name in zip(all_callbacks, file_names):
            maps.save_file(callback.state_dict(), chkpt_dir.callbacks / file_name)

    @classmethod
    def _load_callbacks(
        cls, callbacks: CallbacksHandler, chkpt_dir: EpochTmpDir, maps: Maps
    ) -> None:
        """
        Loads the callback checkpoints.
        """
        all_callbacks = list(callbacks.all_callbacks)
        file_names = cls._get_callback_file_names(all_callbacks)
        for callback, file_name in zip(all_callbacks, file_names):
            callback.load_state_dict(maps.open_file(chkpt_dir.callbacks / file_name))

    @staticmethod
    def _get_callback_file_names(callbacks: list[Callback]) -> list[Path]:
        """
        Gives a snake case file name for each callback.

        When several callbacks share the same class name, the first one keeps
        the bare name and the following ones get the suffixes ``_1``, ``_2``, etc.
        The names are returned in the same order as ``callbacks``.
        """
        seen: dict[str, int] = {}
        file_names: list[Path] = []
        for callback in callbacks:
            name = camel_to_snake(type(callback).__name__)
            rank = seen.get(name, 0)
            seen[name] = rank + 1
            stem = f"{name}_{rank}" if rank else name
            file_names.append(Path(stem).with_suffix(PT))

        # a real list (as annotated), not a one-shot iterator
        return file_names

    @staticmethod
    def _get_last_saved_epoch(maps: Maps, split_idx: int) -> int:
        """
        Gets the last epoch saved in the checkpoint directory.

        Raises
        ------
        FileNotFoundError
            If the checkpoint directory contains no epoch.
        """
        tmp_dir = maps.training.splits[split_idx].tmp
        tmp_dir.read()

        # Convert before comparing, so that epochs are always ordered
        # numerically (as strings, "9" would be sorted after "10").
        saved_epochs = [int(epoch) for epoch in tmp_dir.epochs_list]
        if not saved_epochs:
            raise FileNotFoundError(
                f"No training checkpoint found in {str(tmp_dir.path)}"
            )

        return max(saved_epochs)