def get_train_val_split(data: Union[Dataset, ConcatDataset],
                        **kwargs) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation dataloaders from a dataset.

    The training loader wraps ``data`` itself, switched to training mode via
    ``is_training(True)``. The validation loader wraps a shallow copy
    (``copy.copy``) switched via ``is_training(False)``. Because the copy is
    shallow, mutable attributes are shared between the two sides.

    * Plain dataset: the training loader shuffles and the validation loader
      does not.
    * ``ConcatDataset``: every non-empty sub-dataset receives the same total
      sampling probability, regardless of its size. Both loaders draw
      ``len(...)`` samples *with replacement* through a
      ``WeightedRandomSampler``, so a validation pass may repeat some items
      and skip others. Empty sub-datasets contribute no weights.

    Args:
        data: A dataset exposing an ``is_training(bool)`` method, or a
            ``ConcatDataset`` whose sub-datasets all expose it.
        kwargs: Passed as-is to ``torch.utils.data.DataLoader``. For a plain
            dataset, ``shuffle`` is overridden (True for train, False for
            validation). For a ``ConcatDataset``, ``shuffle`` is ignored,
            because ``DataLoader`` forbids combining ``shuffle=True`` with a
            sampler.

    Returns:
        Tuple[DataLoader, DataLoader]: train and validation data loader
        respectively.
    """
    if not isinstance(data, ConcatDataset):
        data.is_training(True)
        val_data = copy.copy(data)
        val_data.is_training(False)
        return (
            DataLoader(data, **{**kwargs, "shuffle": True}),
            DataLoader(val_data, **{**kwargs, "shuffle": False}),
        )

    # The weighted sampler already randomises order. Forwarding shuffle=True
    # alongside a sampler would make DataLoader raise ValueError.
    loader_kwargs = {k: v for k, v in kwargs.items() if k != "shuffle"}

    val_datasets = []
    train_weights, val_weights = [], []
    for train_ds in data.datasets:
        # Mirror the plain-dataset branch: put the training side into
        # training mode before taking the validation copy.
        train_ds.is_training(True)
        val_ds = copy.copy(train_ds)
        val_ds.is_training(False)
        val_datasets.append(val_ds)

        # Each sub-dataset gets total weight 1, split evenly over its items.
        # Lengths are read per side in case is_training() changes them.
        # Empty sub-datasets are skipped; 1.0 / 0 would raise here.
        n_train, n_val = len(train_ds), len(val_ds)
        if n_train:
            train_weights += [1.0 / n_train] * n_train
        if n_val:
            val_weights += [1.0 / n_val] * n_val

    # NOTE(review): ConcatDataset caches cumulative_sizes at construction.
    # If is_training() changes sub-dataset lengths, `data` may be stale.
    # Confirm it does not.
    val_data = ConcatDataset(val_datasets)
    train_weights = np.array(train_weights) / sum(train_weights)
    val_weights = np.array(val_weights) / sum(val_weights)

    return (
        DataLoader(
            data,
            sampler=WeightedRandomSampler(
                weights=train_weights,
                num_samples=len(train_weights),
                replacement=True,
            ),
            **loader_kwargs,
        ),
        DataLoader(
            val_data,
            sampler=WeightedRandomSampler(
                weights=val_weights,
                num_samples=len(val_weights),
                replacement=True,
            ),
            **loader_kwargs,
        ),
    )