class TrainerState(ClinicaDLConfig):
    """
    Represents the state of a :py:class:`~clinicadl.train.Trainer`.

    Attributes
    ----------
    called : Optional[TrainerCall]
        The method of the ``Trainer`` that has been called. One of
        ``TrainerCall.TRAIN`` (set by :py:meth:`reset_training`),
        ``TrainerCall.VALIDATE`` (set by :py:meth:`reset_validation` when
        called outside training), ``TrainerCall.TEST`` (set by
        :py:meth:`reset_test`) or ``TrainerCall.PREDICT`` (set by
        :py:meth:`reset_prediction`). ``None`` if no method has been called
        so far.
    stage : Optional[TrainerStage]
        Current action performed by the ``Trainer``: ``TrainerStage.TRAIN``
        while iterating over training batches, ``TrainerStage.EVAL`` during
        validation or testing, ``TrainerStage.PRED`` during prediction.
        ``None`` before any stage has started.
    should_stop : bool
        Whether the training should be stopped at the end of the current epoch
        during :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    current_train_batch : int
        Index of the current training batch in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    num_train_batches : int
        Total number of training batches in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    current_val_batch : int
        Index of the current validation batch in
        :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>` or
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    num_val_batches : int
        Total number of validation batches in
        :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>` or
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    current_test_batch : int
        Index of the current test batch in
        :py:meth:`Trainer.test <clinicadl.train.Trainer.test>`.
    num_test_batches : int
        Total number of test batches in
        :py:meth:`Trainer.test <clinicadl.train.Trainer.test>`.
    current_pred_batch : int
        Index of the current prediction batch (reset by
        :py:meth:`reset_prediction`).
    num_pred_batches : int
        Total number of prediction batches, i.e. the length of the dataloader
        passed to :py:meth:`reset_prediction`.
    current_epoch : int
        Index of the current epoch in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    num_epochs : int
        Total number of epochs in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    optim_step : int
        The total number of optimization steps performed so far in
        :py:meth:`Trainer.train <clinicadl.train.Trainer.train>`.
    split_idx : Optional[int]
        Index of the split on which training/validation is currently performed
        in :py:meth:`Trainer.train <clinicadl.train.Trainer.train>` or
        :py:meth:`Trainer.validate <clinicadl.train.Trainer.validate>`.
        Reset to ``None`` by :py:meth:`reset_test` and
        :py:meth:`reset_prediction`.
    """

    called: Optional[TrainerCall] = None
    stage: Optional[TrainerStage] = None
    should_stop: bool = False
    current_train_batch: int = 0
    num_train_batches: int = 0
    current_val_batch: int = 0
    num_val_batches: int = 0
    current_test_batch: int = 0
    num_test_batches: int = 0
    current_pred_batch: int = 0
    num_pred_batches: int = 0
    current_epoch: int = 0
    num_epochs: int = 0
    optim_step: int = 0
    split_idx: Optional[int] = None

    def reset_training(self, split_idx: int, num_epochs: int) -> None:
        """
        To reset the training-related state at the start of a training run.

        Marks the call as ``TrainerCall.TRAIN``, clears ``should_stop``, and
        zeroes the training/validation batch counters, the epoch index and the
        optimization step counter. ``stage`` and the test/prediction counters
        are left untouched.

        Parameters
        ----------
        split_idx : int
            Index of the split that is about to be trained on.
        num_epochs : int
            Total number of epochs planned for this training run.
        """
        self.called = TrainerCall.TRAIN
        self.should_stop = False
        self.current_train_batch = 0
        self.num_train_batches = 0
        self.current_val_batch = 0
        self.num_val_batches = 0
        self.current_epoch = 0
        self.num_epochs = num_epochs
        self.optim_step = 0
        self.split_idx = split_idx

    def reset_epoch(self, train_loader: DataLoader, current_epoch: int) -> None:
        """
        To reset at the beginning of a new epoch.

        Switches ``stage`` to ``TrainerStage.TRAIN`` and restarts the training
        batch counter.

        Parameters
        ----------
        train_loader : DataLoader
            Training dataloader; its ``len`` gives ``num_train_batches``.
        current_epoch : int
            Index of the epoch that is starting.
        """
        self.stage = TrainerStage.TRAIN
        self.current_train_batch = 0
        self.num_train_batches = len(train_loader)
        self.current_epoch = current_epoch

    def reset_validation(
        self, split_idx: int, val_loader: DataLoader, in_training: bool = True
    ) -> None:
        """
        To reset the validation state.

        Parameters
        ----------
        split_idx : int
            Index of the split being validated.
        val_loader : DataLoader
            Validation dataloader; its ``len`` gives ``num_val_batches``.
        in_training : bool, default=True
            Whether validation runs inside a training loop. If ``False``,
            ``called`` is set to ``TrainerCall.VALIDATE``; otherwise ``called``
            is kept as is (e.g. ``TrainerCall.TRAIN``).
        """
        if not in_training:
            self.called = TrainerCall.VALIDATE
        self.stage = TrainerStage.EVAL
        self.current_val_batch = 0
        self.num_val_batches = len(val_loader)
        self.split_idx = split_idx

    def reset_test(self, dataloader: DataLoader) -> None:
        """
        To reset the test state.

        Parameters
        ----------
        dataloader : DataLoader
            Test dataloader; its ``len`` gives ``num_test_batches``.
        """
        self.stage = TrainerStage.EVAL
        self.called = TrainerCall.TEST
        self.current_test_batch = 0
        self.num_test_batches = len(dataloader)
        # Testing is not tied to a cross-validation split.
        self.split_idx = None

    def reset_prediction(self, dataloader: DataLoader) -> None:
        """
        To reset the prediction state.

        Parameters
        ----------
        dataloader : DataLoader
            Prediction dataloader; its ``len`` gives ``num_pred_batches``.
        """
        self.stage = TrainerStage.PRED
        self.called = TrainerCall.PREDICT
        self.current_pred_batch = 0
        self.num_pred_batches = len(dataloader)
        # Prediction is not tied to a cross-validation split.
        self.split_idx = None

    def state_dict(self) -> dict[str, Any]:
        """
        Returns the trainer state as a dict.

        Returns
        -------
        dict[str, Any]
            The output of ``self.to_dict()``.
        """
        return self.to_dict()

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        To reload a trainer state given a ``state_dict``.

        Parameters
        ----------
        state_dict : dict[str, Any]
            The trainer state returned by :py:meth:`state_dict`.
        """
        # NOTE(review): writing into ``__dict__`` directly bypasses any
        # assignment validation ClinicaDLConfig may perform, and stores every
        # key of ``state_dict`` even if it is not a declared field. Confirm
        # this is intended (e.g. for loading older checkpoints).
        self.__dict__.update(state_dict)