Example #1
import copy
from typing import Tuple, Union

import numpy as np
from torch.utils.data import (ConcatDataset, DataLoader, Dataset,
                              WeightedRandomSampler)


def get_train_val_split(data: Union[Dataset, ConcatDataset],
                        **kwargs) -> Tuple[DataLoader, DataLoader]:
    """Creates validation and train dataloader from the Data_set object.

    Args:
        data (Data_Set): Object of the class Data_Set.
    kwargs:
        These arguments are passed as is to the pytorch DataLoader.

    Returns:
        Tuple[DataLoader, DataLoader]: train and validation data loader respectively.
    """

    if isinstance(data, ConcatDataset):
        # Build a validation copy of each sub-dataset and give every sample an
        # inverse-length weight, so each sub-dataset is sampled with equal
        # probability regardless of its size.
        val_datasets = []
        train_weights, val_weights = [], []
        for dataset in data.datasets:
            train_weights += [1.0 / len(dataset)] * len(dataset)
            val_datasets.append(copy.copy(dataset))
            val_datasets[-1].is_training(False)
            val_weights += [1.0 / len(val_datasets[-1])] * len(val_datasets[-1])
        val_data = ConcatDataset(val_datasets)
        # Normalize the weights into probability distributions.
        train_weights = np.array(train_weights) / sum(train_weights)
        val_weights = np.array(val_weights) / sum(val_weights)
        # Both loaders sample with replacement according to these weights.
        return (
            DataLoader(data,
                       sampler=WeightedRandomSampler(
                           weights=train_weights,
                           num_samples=len(train_weights),
                           replacement=True,
                       ),
                       **kwargs),
            DataLoader(val_data,
                       sampler=WeightedRandomSampler(
                           weights=val_weights,
                           num_samples=len(val_weights),
                           replacement=True),
                       **kwargs),
        )
    else:
        # Single dataset: train on the original and validate on a shallow
        # copy switched to evaluation mode.
        data.is_training(True)
        val_data = copy.copy(data)
        val_data.is_training(False)
        return (
            DataLoader(data, **{**kwargs, "shuffle": True}),
            DataLoader(val_data, **{**kwargs, "shuffle": False}),
        )