def create_dataloaders(data_path, batch_size, test_ratio, split=None, only_classes=None, only_one_sample=False, load_on_request=False, add_data=None, p=1, bw=False, color=False): """ Create data loaders from ImageDataSet according parameters. If split is provided, it is used by the data loader. :param data_path: path of root directory of data set, while directories 'photo' and 'sketch' are sub directories :param batch_size: :param test_ratio: 0.1 means 10% test data :param split: optional train/test split :param only_classes: optional list of folder names to retrieve training data from :param only_one_sample: Load only one sketch and one image :param num_workers: number of workers threads for loading sketches and images from drive :return: train and test dataloaders and train and test split """ if bw and color: raise (RuntimeError( "Can't do both black and white and coloring at the same time.")) if only_one_sample: test_ratio = 0.5 batch_size = 1 data_set = data.ImageDataSet(root_dir=data_path, transform=get_transform(), only_classes=only_classes, only_one_sample=only_one_sample, load_on_request=load_on_request, bw=bw, color=color) if split is None: perm = torch.randperm(len(data_set)) train_split, test_split = perm[:math.ceil( len(data_set) * (1 - test_ratio))], perm[math.ceil(len(data_set) * (1 - test_ratio)):] else: train_split, test_split = split[0], split[1] if add_data: data_set_add = data.ImageDataSet(root_dir=add_data, transform=get_transform(), only_classes=only_classes, only_one_sample=only_one_sample, load_on_request=load_on_request, bw=bw, color=color) dataloader_train = data.CompositeDataloader( DataLoader(Subset(data_set, train_split), batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True), DataLoader(data_set_add, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True), p=p, anneal_rate=0.99) else: dataloader_train = DataLoader(Subset(data_set, train_split), batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True) dataloader_train.p = p dataloader_train.anneal_p = lambda *args: None dataloader_test = DataLoader(Subset(data_set, test_split), batch_size=batch_size, shuffle=False, num_workers=0) return dataloader_train, dataloader_test, train_split, test_split