"""Config classes for loss functions natively supported in ``ClinicaDL``. Based on:torch:`PyTorch loss functions <nn.html#loss-functions>`."""fromtypingimportAny,List,Optionalimporttorchfrompydanticimport(NonNegativeFloat,PositiveFloat,field_validator,)fromclinicadl.utils.configimportObjectConfigfromclinicadl.utils.docimportadd_suffix_to_docfromclinicadl.utils.factoriesimportget_defaults_fromfrom.enumimportOrder,Reduction__all__=["LossConfig","NLLLossConfig","CrossEntropyLossConfig","BCELossConfig","BCEWithLogitsLossConfig","MultiMarginLossConfig","KLDivLossConfig","HuberLossConfig","SmoothL1LossConfig","L1LossConfig","MSELossConfig",]NLL_TORCH_DEFAULTS=get_defaults_from(torch.nn.NLLLoss)CROSS_ENTROPY_TORCH_DEFAULTS=get_defaults_from(torch.nn.CrossEntropyLoss)BCE_TORCH_DEFAULTS=get_defaults_from(torch.nn.BCELoss)BCE_LOGITS_TORCH_DEFAULTS=get_defaults_from(torch.nn.BCEWithLogitsLoss)MULTI_MARGIN_LOSS_TORCH_DEFAULTS=get_defaults_from(torch.nn.MultiMarginLoss)KL_DIV_LOSS_TORCH_DEFAULTS=get_defaults_from(torch.nn.KLDivLoss)HUBER_LOSS_TORCH_DEFAULTS=get_defaults_from(torch.nn.HuberLoss)SMOOTH_L1_LOSS_TORCH_DEFAULTS=get_defaults_from(torch.nn.SmoothL1Loss)L1_TORCH_DEFAULT=get_defaults_from(torch.nn.L1Loss)MSE_TORCH_DEFAULT=get_defaults_from(torch.nn.MSELoss)classLossConfig(ObjectConfig[torch.nn.Module]):"""Base config class for the loss function."""defget_object(self)->torch.nn.Module:""" Returns the loss function associated to this configuration, parametrized with the parameters passed by the user. Returns ------- torch.nn.Module: The PyTorch loss function. 
"""params=self.to_raw_dict()if"weight"inparamsandparams["weight"]:params["weight"]=torch.Tensor(params["weight"])if"pos_weight"inparamsandparams["pos_weight"]:params["pos_weight"]=torch.Tensor(params["pos_weight"])associated_class=self._get_class()returnassociated_class(**params)@classmethoddef_get_class(cls)->type[torch.nn.Module]:"""Returns the loss function associated to this config class."""returngetattr(torch.nn,cls._get_name())DOC_WEIGHT=("``weight`` must be pass via a ``list`` and not via :py:class:`torch.Tensor`.")
@add_suffix_to_doc(DOC_WEIGHT)
class NLLLossConfig(LossConfig):
    """
    Config class for :py:class:`torch.nn.NLLLoss`.
    """

    # Field defaults mirror those extracted from ``torch.nn.NLLLoss``.
    weight: Optional[List[NonNegativeFloat]] = NLL_TORCH_DEFAULTS["weight"]
    ignore_index: int = NLL_TORCH_DEFAULTS["ignore_index"]
    reduction: Reduction = NLL_TORCH_DEFAULTS["reduction"]

    @field_validator("ignore_index")
    @classmethod
    def validator_ignore_index(cls, v):
        """
        Checks that ``ignore_index`` is either ``-100`` (disabled) or a
        non-negative integer.

        Raises
        ------
        ValueError
            If ``ignore_index`` is a negative integer other than ``-100``.
        """
        # An explicit ``raise`` is used instead of ``assert`` so that the check
        # is still enforced when Python runs with ``-O`` (asserts are removed).
        if isinstance(v, int) and v != -100 and v < 0:
            raise ValueError(
                "ignore_index must be a non-negative int (or -100 when disabled)."
            )
        return v
@add_suffix_to_doc(DOC_WEIGHT)
class CrossEntropyLossConfig(NLLLossConfig):
    """
    Config class for :py:class:`torch.nn.CrossEntropyLoss`.
    """

    # Field defaults mirror those extracted from ``torch.nn.CrossEntropyLoss``.
    weight: Optional[List[NonNegativeFloat]] = CROSS_ENTROPY_TORCH_DEFAULTS["weight"]
    ignore_index: int = CROSS_ENTROPY_TORCH_DEFAULTS["ignore_index"]
    reduction: Reduction = CROSS_ENTROPY_TORCH_DEFAULTS["reduction"]
    label_smoothing: NonNegativeFloat = CROSS_ENTROPY_TORCH_DEFAULTS["label_smoothing"]

    @field_validator("label_smoothing")
    @classmethod
    def validator_label_smoothing(cls, v):
        """
        Checks that ``label_smoothing`` lies in ``[0, 1]``.

        Raises
        ------
        ValueError
            If ``label_smoothing`` is a float outside ``[0, 1]``.
        """
        # An explicit ``raise`` is used instead of ``assert`` so that the check
        # is still enforced when Python runs with ``-O`` (asserts are removed).
        if isinstance(v, float) and not 0 <= v <= 1:
            raise ValueError(
                f"label_smoothing must be between 0 and 1 but it has been set to {v}."
            )
        return v
@add_suffix_to_doc(DOC_WEIGHT)
class BCELossConfig(LossConfig):
    """
    Config class for :py:class:`torch.nn.BCELoss`.
    """

    # Field defaults mirror those extracted from ``torch.nn.BCELoss``.
    weight: Optional[List[NonNegativeFloat]] = BCE_TORCH_DEFAULTS["weight"]
    reduction: Reduction = BCE_TORCH_DEFAULTS["reduction"]

    @field_validator("weight")
    @classmethod
    def validator_weight(cls, v):
        """
        Rejects any ``weight`` other than ``None``.

        Raises
        ------
        ValueError
            If ``weight`` is not ``None``.
        """
        if v is not None:
            # Name the loss of the class actually being validated: this
            # validator is also inherited by subclasses, so a hard-coded loss
            # name would be wrong for at least one of them.
            raise ValueError(
                f"'weight' with {cls._get_name()} is not supported by ClinicaDL currently. "
                "Please leave it to None."
            )
        return v
@add_suffix_to_doc(DOC_WEIGHT)
class BCEWithLogitsLossConfig(BCELossConfig):
    """
    Config class for :py:class:`torch.nn.BCEWithLogitsLoss`.
    """

    # Field defaults mirror those extracted from ``torch.nn.BCEWithLogitsLoss``.
    weight: Optional[List[NonNegativeFloat]] = BCE_LOGITS_TORCH_DEFAULTS["weight"]
    reduction: Reduction = BCE_LOGITS_TORCH_DEFAULTS["reduction"]
    # Typed as a list of anything because it may be nested; its content is
    # checked by ``validator_pos_weight``.
    pos_weight: Optional[List[Any]] = BCE_LOGITS_TORCH_DEFAULTS["pos_weight"]

    @field_validator("pos_weight")
    @classmethod
    def validator_pos_weight(cls, v):
        """Checks that a list ``pos_weight`` only holds non-negative numbers."""
        if isinstance(v, list) and not cls._recursive_float_check(v):
            raise ValueError(
                f"elements in pos_weight must be non-negative float, got: {v}"
            )
        return v

    @classmethod
    def _recursive_float_check(cls, item):
        """
        Returns ``True`` if ``item`` is a non-negative ``int``/``float``, or a
        (possibly nested) list whose leaves all are.
        """
        # Leaf: must be a number and must not be negative.
        if not isinstance(item, list):
            return isinstance(item, (float, int)) and item >= 0

        # List: every element has to pass the check, at any depth.
        for element in item:
            if not cls._recursive_float_check(element):
                return False
        return True
@add_suffix_to_doc(DOC_WEIGHT)
class MultiMarginLossConfig(LossConfig):
    """
    Config class for :py:class:`torch.nn.MultiMarginLoss`.

    Every field defaults to the value extracted from the PyTorch class.
    """

    # ``p`` is restricted to the members of ``Order``.
    p: Order = MULTI_MARGIN_LOSS_TORCH_DEFAULTS["p"]
    margin: float = MULTI_MARGIN_LOSS_TORCH_DEFAULTS["margin"]
    # Optional list of non-negative floats.
    weight: Optional[List[NonNegativeFloat]] = MULTI_MARGIN_LOSS_TORCH_DEFAULTS[
        "weight"
    ]
    reduction: Reduction = MULTI_MARGIN_LOSS_TORCH_DEFAULTS["reduction"]
class KLDivLossConfig(LossConfig):
    """
    Config class for :py:class:`torch.nn.KLDivLoss`.

    Every field defaults to the value extracted from the PyTorch class.
    """

    reduction: Reduction = KL_DIV_LOSS_TORCH_DEFAULTS["reduction"]
    log_target: bool = KL_DIV_LOSS_TORCH_DEFAULTS["log_target"]
class HuberLossConfig(LossConfig):
    """
    Config class for :py:class:`torch.nn.HuberLoss`.

    Every field defaults to the value extracted from the PyTorch class.
    """

    reduction: Reduction = HUBER_LOSS_TORCH_DEFAULTS["reduction"]
    # Must be strictly positive.
    delta: PositiveFloat = HUBER_LOSS_TORCH_DEFAULTS["delta"]
class SmoothL1LossConfig(LossConfig):
    """
    Config class for :py:class:`torch.nn.SmoothL1Loss`.

    Every field defaults to the value extracted from the PyTorch class.
    """

    reduction: Reduction = SMOOTH_L1_LOSS_TORCH_DEFAULTS["reduction"]
    # Must be non-negative (zero is accepted).
    beta: NonNegativeFloat = SMOOTH_L1_LOSS_TORCH_DEFAULTS["beta"]
class L1LossConfig(LossConfig):
    """
    Config class for :py:class:`torch.nn.L1Loss`.

    The only field, ``reduction``, defaults to the value extracted from the
    PyTorch class.
    """

    reduction: Reduction = L1_TORCH_DEFAULT["reduction"]
class MSELossConfig(LossConfig):
    """
    Config class for :py:class:`torch.nn.MSELoss`.

    The only field, ``reduction``, defaults to the value extracted from the
    PyTorch class.
    """

    reduction: Reduction = MSE_TORCH_DEFAULT["reduction"]