Example #1
0
def create_dataloaders(data_path,
                       batch_size,
                       test_ratio,
                       split=None,
                       only_classes=None,
                       only_one_sample=False,
                       load_on_request=False,
                       add_data=None,
                       p=1,
                       bw=False,
                       color=False):
    """
    Create train/test data loaders from an ImageDataSet.

    If ``split`` is provided it is used as-is; otherwise the data set is
    randomly permuted and partitioned according to ``test_ratio``.

    :param data_path: path of the root directory of the data set, with
        'photo' and 'sketch' as sub-directories
    :param batch_size: batch size for both loaders (forced to 1 when
        ``only_one_sample`` is set)
    :param test_ratio: fraction of samples held out for testing
        (0.1 means 10% test data)
    :param split: optional ``(train_indices, test_indices)`` pair; when
        given, the random split is skipped entirely
    :param only_classes: optional list of folder names to retrieve
        training data from
    :param only_one_sample: load only one sketch and one image; also
        forces ``test_ratio`` to 0.5 and ``batch_size`` to 1
    :param load_on_request: load sketches/images from disk lazily instead
        of up front
    :param add_data: optional root directory of an additional data set
        that is mixed into training via a CompositeDataloader
    :param p: probability of drawing a batch from the primary data set
        when ``add_data`` is used (annealed with rate 0.99)
    :param bw: black-and-white mode (mutually exclusive with ``color``)
    :param color: coloring mode (mutually exclusive with ``bw``)
    :return: train dataloader, test dataloader, train split, test split
    :raises RuntimeError: if both ``bw`` and ``color`` are requested
    """
    if bw and color:
        raise RuntimeError(
            "Can't do both black and white and coloring at the same time.")
    if only_one_sample:
        # With a single sketch/image pair: one sample per half, batch of 1.
        test_ratio = 0.5
        batch_size = 1

    data_set = data.ImageDataSet(root_dir=data_path,
                                 transform=get_transform(),
                                 only_classes=only_classes,
                                 only_one_sample=only_one_sample,
                                 load_on_request=load_on_request,
                                 bw=bw,
                                 color=color)
    if split is None:
        # Random split: the first ceil(N * (1 - test_ratio)) indices of a
        # permutation train, the rest test. Compute the boundary once.
        perm = torch.randperm(len(data_set))
        n_train = math.ceil(len(data_set) * (1 - test_ratio))
        train_split, test_split = perm[:n_train], perm[n_train:]
    else:
        train_split, test_split = split[0], split[1]
    if add_data:
        # Mix an additional data set into training; CompositeDataloader
        # samples from the primary loader with probability p.
        data_set_add = data.ImageDataSet(root_dir=add_data,
                                         transform=get_transform(),
                                         only_classes=only_classes,
                                         only_one_sample=only_one_sample,
                                         load_on_request=load_on_request,
                                         bw=bw,
                                         color=color)
        dataloader_train = data.CompositeDataloader(
            DataLoader(Subset(data_set, train_split),
                       batch_size=batch_size,
                       shuffle=True,
                       num_workers=0,
                       drop_last=True),
            DataLoader(data_set_add,
                       batch_size=batch_size,
                       shuffle=True,
                       num_workers=0,
                       drop_last=True),
            p=p,
            anneal_rate=0.99)
    else:
        dataloader_train = DataLoader(Subset(data_set, train_split),
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0,
                                      drop_last=True)
        # Mirror the CompositeDataloader interface so callers can use
        # .p and .anneal_p uniformly; annealing is a no-op here.
        dataloader_train.p = p
        dataloader_train.anneal_p = lambda *args: None
    dataloader_test = DataLoader(Subset(data_set, test_split),
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=0)

    return dataloader_train, dataloader_test, train_split, test_split