from__future__importannotationsfrombisectimportbisect_rightfromloggingimportgetLoggerfromtypingimportAny,Iterable,Sequenceimportnumpyasnpimportpandasaspdfrompydanticimportfield_validatorfromtyping_extensionsimportSelffromclinicadl.utils.dictionary.wordsimportDATASET_IDfromclinicadl.utils.typingimportDataFrameTypefrom..structuresimportSamplefrom.baseimportDatasetfrom.collectionimportCollectionDataset,CollectionDatasetConfigfrom.utilsimportCheckableDataset,OneSampleDatasetlogger=getLogger(__name__)classConcatDatasetConfig(CollectionDatasetConfig):""" Config class for ``ConcatDataset``. """@field_validator("datasets",mode="after")@classmethoddef_check_dataset_id_column(cls,datasets:tuple[Dataset,...])->tuple[Dataset,...]:returnsuper()._check_dataset_id_column(datasets)@classmethoddef_get_class(cls)->type[ConcatDataset]:"""Returns the class associated to this config class."""returnConcatDataset
class ConcatDataset(OneSampleDataset, CollectionDataset, CheckableDataset):
    """
    For assembling multiple :py:class:`~clinicadl.data.datasets.Dataset` (e.g., images coming
    from different BIDS datasets).

    ``ConcatDataset`` concatenates the input datasets, so the length of the new dataset will be
    equal to the sum of the lengths of each individual dataset.

    Parameters
    ----------
    datasets : Iterable[Dataset]
        The ``Datasets`` to concatenate.

    Examples
    --------
    .. code-block:: text

        bids_1
        ├── sub-001
        │   ├── ses-M000
        │   │   └── pet
        │   │       └── sub-001_ses-M000_pet.nii.gz
        │   ...
        ...

        bids_2
        ├── sub-A
        │   ├── ses-M003
        │   │   └── pet
        │   │       └── sub-A_ses-M003_pet.nii.gz
        │   ...
        ...

    .. code-block:: python

        from clinicadl.data.datasets import BidsDataset, ConcatDataset
        from clinicadl.io.bids import BidsFileType

        bids_1 = BidsDataset("bids_1", file_type=BidsFileType(data_type="pet", suffix="pet"))
        bids_2 = BidsDataset("bids_2", file_type=BidsFileType(data_type="pet", suffix="pet"))
        full_dataset = ConcatDataset([bids_1, bids_2])

    .. code-block:: python

        >>> len(bids_1)
        4
        >>> len(bids_2)
        8
        >>> len(full_dataset)
        12
        >>> full_dataset[0].participant_id, full_dataset[0].session_id
        ('sub-001', 'ses-M000')
        >>> full_dataset[4].participant_id, full_dataset[4].session_id
        ('sub-A', 'ses-M003')
    """

    _config_type = ConcatDatasetConfig

    def __init__(
        self,
        datasets: Iterable[Dataset],
    ):
        super().__init__(datasets=datasets)

    @property
    def df(self):
        "The concatenation of the metadata DataFrames of all the underlying datasets."
        return super().df

    def subset(
        self, particpants_sessions: DataFrameType | Iterable[tuple[str, str]]
    ) -> Self:
        """
        Returns a new dataset restricted to some (participant, session) pairs.

        ``subset`` is called on each underlying dataset with the same pairs. Underlying
        datasets whose ``subset`` raises a ``RuntimeError`` (no matching pair) are left out
        of the result.

        Parameters
        ----------
        particpants_sessions : DataFrameType | Iterable[tuple[str, str]]
            The (participant, session) pairs to keep. A one-shot iterator (e.g. a generator)
            is accepted: it is materialized once so that every underlying dataset receives
            all the pairs.

        Returns
        -------
        Self
            A dataset of the same type, built from the non-empty sub-datasets.

        Raises
        ------
        RuntimeError
            If none of the underlying datasets contains any of the requested pairs.
        """
        from collections.abc import Iterator

        # The same object is passed to every sub-dataset: a one-shot iterator would be
        # exhausted by the first one, and the following ones would silently get nothing.
        if isinstance(particpants_sessions, Iterator):
            particpants_sessions = list(particpants_sessions)

        sub_datasets = []
        for dataset in self.datasets:
            try:
                sub_datasets.append(dataset.subset(particpants_sessions))
            except RuntimeError:  # empty dataset
                continue

        if not sub_datasets:
            raise RuntimeError(
                "No (participant, session) pairs are in the dataset. This would lead to an empty dataset!"
            )

        return type(self)(sub_datasets)

    def get_sample_info(self, idx: int, column: str) -> Any:
        """
        Returns the value of ``column`` for the sample at global index ``idx``.

        The lookup is delegated to the underlying dataset that contains the sample.

        Raises
        ------
        KeyError
            If that underlying dataset's ``get_sample_info`` raises a ``KeyError``
            (re-raised with a more explicit message, chained to the original error).
        """
        dataset_idx, idx_in_dataset = self._get_dataset_and_rank(idx)
        try:
            return self.datasets[dataset_idx].get_sample_info(idx_in_dataset, column)
        except KeyError as e:
            raise KeyError(
                f"No column named '{column}' in the metadata DataFrame of the dataset "
                "from which the sample is taken."
            ) from e

    def _get_dataset_and_rank(self, idx: int) -> tuple[int, int]:
        """
        Gets the dataset from which is the sample.

        Returns
        -------
        tuple[int, int]
            The position of the underlying dataset in ``self.datasets``, and the index of
            the sample within that dataset.
        """
        self._check_idx(idx)
        cum_len = np.cumsum([len(dataset) for dataset in self.datasets])

        # In case ``_check_idx`` lets negative indices through: count them from the end
        # of the whole concatenation, not from the end of the first dataset.
        if idx < 0:
            idx = int(cum_len[-1]) + idx

        # bisect_right skips empty datasets (repeated cumulative lengths) correctly.
        dataset_idx = bisect_right(cum_len, idx)
        offset = cum_len[dataset_idx - 1] if dataset_idx > 0 else 0
        idx_in_dataset = int(idx - offset)

        return dataset_idx, idx_in_dataset

    @staticmethod
    def _merge_dfs(datasets: Sequence[Dataset]) -> pd.DataFrame:
        """
        Concatenates the metadata DataFrames of ``datasets``.

        A ``DATASET_ID`` column is added, holding for each row the position of its source
        dataset in ``datasets``. The returned DataFrame has a fresh ``RangeIndex``.
        """
        df: pd.DataFrame = pd.concat(
            [dataset.df for dataset in datasets],
            keys=range(len(datasets)),
            names=[DATASET_ID],
        )
        # Move the outer (dataset) index level to a column, then renumber the rows.
        return df.reset_index(
            drop=False,
            level=DATASET_ID,
        ).reset_index(drop=True)