Source code for clinicadl.train.trainer_state

from typing import Any, Optional

from clinicadl.data.dataloader import DataLoader
from clinicadl.utils.config import ClinicaDLConfig
from clinicadl.utils.enum import TrainerCall, TrainerStage


class TrainerState(ClinicaDLConfig):
    """
    Represents the state of a :py:class:`~clinicadl.train.Trainer`.

    Attributes
    ----------
    called : Optional[TrainerCall]
        The method of the ``Trainer`` that has been called. One of ``"train"``,
        ``"validate"``, ``"test"``, or ``TrainerCall.PREDICT`` (set by
        :py:meth:`reset_prediction`). ``None`` if no method has been called so far.
    stage : Optional[TrainerStage]
        Current action performed by the ``Trainer``. One of ``"training"``,
        ``"evaluation"``, or ``TrainerStage.PRED`` (set by
        :py:meth:`reset_prediction`). ``None`` until a stage has been set.
    should_stop : bool
        Whether the training should be stopped at the end of the current epoch
        during :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    current_train_batch : int
        Index of the current training batch in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    num_train_batches : int
        Total number of training batches in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    current_val_batch : int
        Index of the current validation batch in
        :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>` or
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    num_val_batches : int
        Total number of validation batches in
        :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>` or
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    current_test_batch : int
        Index of the current test batch in
        :py:meth:`Trainer.test <clinicadl.train.Trainer.test>`.
    num_test_batches : int
        Total number of test batches in
        :py:meth:`Trainer.test <clinicadl.train.Trainer.test>`.
    current_pred_batch : int
        Index of the current prediction batch. Reset to ``0`` by
        :py:meth:`reset_prediction`.
    num_pred_batches : int
        Total number of prediction batches, i.e. the length of the dataloader
        passed to :py:meth:`reset_prediction`.
    current_epoch : int
        Index of the current epoch in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    num_epochs : int
        Total number of epochs in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    optim_step : int
        The total number of optimization steps performed so far in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    split_idx : Optional[int]
        Index of the split on which training/validation is currently performed in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>` or
        :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>`.
        Set to ``None`` by :py:meth:`reset_test` and :py:meth:`reset_prediction`.
    """

    called: Optional[TrainerCall] = None
    stage: Optional[TrainerStage] = None
    should_stop: bool = False
    current_train_batch: int = 0
    num_train_batches: int = 0
    current_val_batch: int = 0
    num_val_batches: int = 0
    current_test_batch: int = 0
    num_test_batches: int = 0
    current_pred_batch: int = 0
    num_pred_batches: int = 0
    current_epoch: int = 0
    num_epochs: int = 0
    optim_step: int = 0
    split_idx: Optional[int] = None

    def reset_training(self, split_idx: int, num_epochs: int) -> None:
        """
        To reset the training-related state at the start of a training run.

        Marks ``called`` as ``TrainerCall.TRAIN``, clears ``should_stop`` and
        zeroes the training/validation batch counters, the epoch counter and
        ``optim_step``. ``stage`` and the test/prediction counters are left
        untouched.

        Parameters
        ----------
        split_idx : int
            Index of the split about to be trained.
        num_epochs : int
            Total number of epochs planned for this run.
        """
        self.called = TrainerCall.TRAIN
        self.should_stop = False
        self.current_train_batch = 0
        self.num_train_batches = 0
        self.current_val_batch = 0
        self.num_val_batches = 0
        self.current_epoch = 0
        self.num_epochs = num_epochs
        self.optim_step = 0
        self.split_idx = split_idx

    def reset_epoch(self, train_loader: DataLoader, current_epoch: int) -> None:
        """
        To reset at the beginning of a new epoch.

        Switches ``stage`` to ``TrainerStage.TRAIN``, restarts the training
        batch counter and records the number of batches in ``train_loader``.

        Parameters
        ----------
        train_loader : DataLoader
            The training dataloader; its ``len`` gives ``num_train_batches``.
        current_epoch : int
            Index of the epoch that is starting.
        """
        self.stage = TrainerStage.TRAIN
        self.current_train_batch = 0
        self.num_train_batches = len(train_loader)
        self.current_epoch = current_epoch

    def reset_validation(
        self, split_idx: int, val_loader: DataLoader, in_training: bool = True
    ) -> None:
        """
        To reset the validation state.

        Parameters
        ----------
        split_idx : int
            Index of the split being validated.
        val_loader : DataLoader
            The validation dataloader; its ``len`` gives ``num_val_batches``.
        in_training : bool, default=True
            Whether validation runs inside a training loop. If ``False``,
            ``called`` is set to ``TrainerCall.VALIDATE``; otherwise ``called``
            is left as is.
        """
        # During training, ``called`` must keep reporting TRAIN.
        if not in_training:
            self.called = TrainerCall.VALIDATE
        self.stage = TrainerStage.EVAL
        self.current_val_batch = 0
        self.num_val_batches = len(val_loader)
        self.split_idx = split_idx

    def reset_test(self, dataloader: DataLoader) -> None:
        """
        To reset the test state.

        Parameters
        ----------
        dataloader : DataLoader
            The test dataloader; its ``len`` gives ``num_test_batches``.
        """
        self.stage = TrainerStage.EVAL
        self.called = TrainerCall.TEST
        self.current_test_batch = 0
        self.num_test_batches = len(dataloader)
        # Testing is not tied to a split.
        self.split_idx = None

    def reset_prediction(self, dataloader: DataLoader) -> None:
        """
        To reset the prediction state.

        Parameters
        ----------
        dataloader : DataLoader
            The prediction dataloader; its ``len`` gives ``num_pred_batches``.
        """
        self.stage = TrainerStage.PRED
        self.called = TrainerCall.PREDICT
        self.current_pred_batch = 0
        self.num_pred_batches = len(dataloader)
        # Prediction is not tied to a split.
        self.split_idx = None

    def state_dict(self) -> dict[str, Any]:
        """
        Returns the trainer state as a dict.

        Returns
        -------
        dict[str, Any]
            The output of ``self.to_dict()``.
        """
        return self.to_dict()

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        To reload a trainer state given a ``state_dict``.

        Parameters
        ----------
        state_dict : dict[str, Any]
            The trainer state returned by :py:meth:`state_dict`.
        """
        # NOTE(review): writing into ``__dict__`` bypasses ``__setattr__``, so
        # if ClinicaDLConfig validates on assignment, that validation is skipped
        # here and unknown keys are not rejected — confirm this is intended.
        self.__dict__.update(state_dict)