def get_train_val_dataloader(args):
    """Build train/val DataLoaders over the concatenation of ``args.name`` datasets.

    Supports the self-supervised auxiliary tasks selected by ``args.type``
    ('jigsaw' or 'rotate').  Each dataset is optionally truncated to
    ``args.limit`` training samples.

    Returns:
        (train_loader, val_loader) tuple of DataLoaders.

    Raises:
        ValueError: if ``args.type`` is not a supported task (the original
        code silently fell through to a NameError on ``train_dataset``).
    """
    dataset_list = args.name
    assert isinstance(dataset_list, list)
    datasets = []
    val_datasets = []
    limit = args.limit
    for dname in dataset_list:
        if args.type == 'jigsaw':
            img_transformer, tile_transformer = get_jig_train_transformers(args)
            train_dataset = JigsawDataset(dname, split='train', val_size=args.val_size,
                                          img_transformer=img_transformer,
                                          tile_transformer=tile_transformer,
                                          jig_classes=args.aux_classes,
                                          bias_whole_image=args.bias_whole_image)
            val_dataset = JigsawTestDataset(dname, split='val', val_size=args.val_size,
                                            img_transformer=get_val_transformer(args),
                                            jig_classes=args.aux_classes)
        elif args.type == 'rotate':
            img_transformer = get_rot_train_transformers(args)
            train_dataset = RotateDataset(dname, split='train', val_size=args.val_size,
                                          img_transformer=img_transformer,
                                          rot_classes=args.aux_classes,
                                          bias_whole_image=args.bias_whole_image)
            val_dataset = RotateTestDataset(dname, split='val', val_size=args.val_size,
                                            img_transformer=get_val_transformer(args),
                                            rot_classes=args.aux_classes)
        else:
            raise ValueError("Unsupported auxiliary task type: %r" % args.type)
        if limit:
            train_dataset = Subset(train_dataset, limit)
        datasets.append(train_dataset)
        val_datasets.append(val_dataset)
    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                                         num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                             shuffle=False, num_workers=4,
                                             pin_memory=True, drop_last=False)
    return loader, val_loader
def get_train_dataloader(args, patches):
    """Create train/validation loaders for the source domains.

    Each domain listed in ``args.source`` is split into train/val according
    to ``args.val_size``; all per-domain datasets are then concatenated.

    Returns:
        (train_loader, val_loader) tuple of DataLoaders.
    """
    source_names = args.source
    assert isinstance(source_names, list)
    train_parts, val_parts = [], []
    img_tr, tile_tr = get_train_transformers(args)
    limit = args.limit_source
    for dname in source_names:
        split_file = join(dirname(__file__), 'txt_lists', '%s_train.txt' % dname)
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(split_file,
                                                                               args.val_size)
        train_ds = JigsawDataset(name_train, labels_train, patches=patches,
                                 img_transformer=img_tr, tile_transformer=tile_tr,
                                 jig_classes=args.jigsaw_n_classes,
                                 bias_whole_image=args.bias_whole_image)
        if limit:
            train_ds = Subset(train_ds, limit)
        train_parts.append(train_ds)
        # validation images are the portion held out from the train split
        val_parts.append(JigsawTestDataset(name_val, labels_val,
                                           img_transformer=get_val_transformer(args),
                                           patches=patches,
                                           jig_classes=args.jigsaw_n_classes))
    train_loader = torch.utils.data.DataLoader(ConcatDataset(train_parts),
                                               batch_size=args.batch_size, shuffle=True,
                                               num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(ConcatDataset(val_parts),
                                             batch_size=args.batch_size, shuffle=False,
                                             num_workers=4, pin_memory=True, drop_last=False)
    return train_loader, val_loader
def get_digital_train_dataloader(args, dname):
    """Train/val loaders for a digits benchmark (currently only MNIST).

    The dataset is randomly split into train and val.  When
    ``args.limit_source`` caps the training size, the validation set is
    trimmed so that the val/train ratio still matches ``args.val_size``.
    """
    images, labels = None, []
    if dname == 'mnist':
        images, labels = load_mnist_dataset(split='train', flip_p=args.flip_p)
    digit_dataset = AdvDataset(images, labels)
    total = len(digit_dataset)
    if args.limit_source and total > args.limit_source:
        train_size = args.limit_source
    else:
        train_size = int(total * (1 - args.val_size))
    val_size = total - train_size
    train_set, val_set = torch.utils.data.random_split(digit_dataset, [train_size, val_size])
    # keep validation proportional to the (possibly capped) training size
    max_val = int(train_size * args.val_size / (1 - args.val_size))
    if val_size > max_val:
        val_set = Subset(val_set, max_val)
    train_set = ConcatDataset([train_set])
    val_set = ConcatDataset([val_set])
    loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                                         num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False,
                                             num_workers=4, pin_memory=True, drop_last=False)
    return loader, val_loader
def get_train_dataloader(args, patches):
    """Build train/val loaders for Nexperia-style sources (JigsawNewDataset).

    Domains listed in ``new_nex_datasets`` are read from the shared
    April-2021 dump; all others come from the local correct_txt_lists/kfold
    layout.
    """
    source_names = args.source
    assert isinstance(source_names, list)
    train_parts, val_parts = [], []
    img_tr, tile_tr = get_train_transformers(args)
    limit = args.limit_source
    for dname in source_names:
        if dname in new_nex_datasets:
            index_root = data_root = '/import/home/share/from_Nexperia_April2021/%s' % dname
        else:
            index_root = join(dirname(__file__), 'correct_txt_lists')
            data_root = join(dirname(__file__), 'kfold')
        name_train, labels_train = _dataset_info(join(index_root, "%s_train.txt" % dname))
        name_val, labels_val = _dataset_info(join(index_root, "%s_val.txt" % dname))
        train_ds = JigsawNewDataset(data_root, name_train, labels_train, patches=patches,
                                    img_transformer=img_tr, tile_transformer=tile_tr,
                                    jig_classes=30, bias_whole_image=args.bias_whole_image)
        if limit:
            train_ds = Subset(train_ds, limit)
        train_parts.append(train_ds)
        val_parts.append(JigsawTestNewDataset(data_root, name_val, labels_val,
                                              img_transformer=get_val_transformer(args),
                                              patches=patches, jig_classes=30))
    loader = torch.utils.data.DataLoader(ConcatDataset(train_parts), batch_size=args.batch_size,
                                         shuffle=True, num_workers=4, pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(ConcatDataset(val_parts), batch_size=args.batch_size,
                                             shuffle=False, num_workers=4, pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
def load_datasets(self, args, gpu, n_gpus):
    """Load every child dataset, share vocabularies, and build the
    train/val/test DataLoaders over their concatenation.

    Records dataset sizes and the mean label length; on rank 0, prints one
    sample batch for inspection.
    """
    for (framework, language), child in self.child_datasets.items():
        child.load_dataset(args, gpu, n_gpus, framework, language)
    self.share_chars()
    self.share_vocabs(args)

    def ordered_split(attr):
        # children ordered by framework id so the concatenation is stable
        return [getattr(self.child_datasets[self.id_to_framework[i]], attr)
                for i in range(len(self.child_datasets))]

    self.train = torch.utils.data.DataLoader(ConcatDataset(ordered_split("train")),
                                             batch_size=args.batch_size, shuffle=True,
                                             num_workers=args.workers, collate_fn=Collate(),
                                             pin_memory=True, drop_last=True)
    self.train_size = len(self.train.dataset)
    self.mean_label_length = sum(child.node_count
                                 for child in self.child_datasets.values()) / self.train_size

    self.val = torch.utils.data.DataLoader(ConcatDataset(ordered_split("val")),
                                           batch_size=1, shuffle=False,
                                           num_workers=args.workers, collate_fn=Collate(),
                                           pin_memory=True)
    self.val_size = len(self.val.dataset)

    self.test = torch.utils.data.DataLoader(ConcatDataset(ordered_split("test")),
                                            batch_size=1, shuffle=False,
                                            num_workers=args.workers, collate_fn=Collate(),
                                            pin_memory=True)
    self.test_size = len(self.test.dataset)

    if gpu == 0:
        batch = next(iter(self.train))
        print(f"\nBatch content: {Batch.to_str(batch)}\n")
        print(flush=True)
def get_train_dataloader(args):
    """Train/val loaders for JiGen-style sources (jigsaw / rotation / odd-one-out)."""
    source_names = args.source
    assert isinstance(source_names, list)
    train_parts, val_parts = [], []
    train_tr = get_train_transformers(args)
    val_tr = get_val_transformer(args)
    for dname in source_names:
        split_file = join(dirname(__file__), 'txt_lists', dname + '.txt')
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(split_file,
                                                                               args.val_size)
        # training batches are cropped/shuffled according to args.betaJigen
        train_parts.append(Dataset(name_train, labels_train, args.path_dataset,
                                   img_transformer=train_tr, betaJigen=args.betaJigen,
                                   rotation=args.rotation, oddOneOut=args.oddOneOut))
        val_parts.append(TestDataset(name_val, labels_val, args.path_dataset,
                                     img_transformer=val_tr, betaJigen=args.betaJigen,
                                     rotation=args.rotation, oddOneOut=args.oddOneOut))
    loader = torch.utils.data.DataLoader(ConcatDataset(train_parts), batch_size=args.batch_size,
                                         shuffle=True, num_workers=4, pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(ConcatDataset(val_parts), batch_size=args.batch_size,
                                             shuffle=False, num_workers=4, pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
def get_train_dataloader(args, patches):
    """Train/val loaders built from the pre-computed kfold txt splits."""
    source_names = args.source
    assert isinstance(source_names, list)
    train_parts, val_parts = [], []
    img_tr, tile_tr = get_train_transformers(args)
    limit = args.limit_source
    for dname in source_names:
        # splits come from fixed kfold lists rather than a random val split
        name_train, labels_train = _dataset_info(
            join(dirname(__file__), 'correct_txt_lists', '%s_train_kfold.txt' % dname))
        name_val, labels_val = _dataset_info(
            join(dirname(__file__), 'correct_txt_lists', '%s_crossval_kfold.txt' % dname))
        train_ds = JigsawNewDataset(name_train, labels_train, patches=patches,
                                    img_transformer=img_tr, tile_transformer=tile_tr,
                                    jig_classes=30, bias_whole_image=args.bias_whole_image)
        if limit:
            train_ds = Subset(train_ds, limit)
        train_parts.append(train_ds)
        val_parts.append(JigsawTestNewDataset(name_val, labels_val,
                                              img_transformer=get_val_transformer(args),
                                              patches=patches, jig_classes=30))
    loader = torch.utils.data.DataLoader(ConcatDataset(train_parts), batch_size=args.batch_size,
                                         shuffle=True, num_workers=4, pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(ConcatDataset(val_parts), batch_size=args.batch_size,
                                             shuffle=False, num_workers=4, pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
def get_val_dataloader(args, patches=False):
    """Loader over the target domain's test list (stylized or vanilla variant)."""
    if args.stylized:
        list_path = join(dirname(__file__), 'txt_lists', 'Stylized' + args.dataset,
                         "{}_target".format(args.target), '%s_test.txt' % args.target)
    else:
        list_path = join(dirname(__file__), 'txt_lists', 'Vanilla' + args.dataset,
                         '%s_test.txt' % args.target)
    names, labels = _dataset_info(list_path)
    img_tr = get_val_transformer(args)
    val_dataset = JigsawTestDataset(names, labels, patches=patches, img_transformer=img_tr,
                                    jig_classes=args.jigsaw_n_classes)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    return torch.utils.data.DataLoader(ConcatDataset([val_dataset]), batch_size=args.batch_size,
                                       shuffle=False, num_workers=4, pin_memory=True,
                                       drop_last=False)
def get_val_dataloader(args, patches=False):
    """Validation loader for the target Nexperia domain.

    Data paths depend on whether the target is in ``new_nex_datasets``
    (shared April-2021 dump) or local (correct_txt_lists/kfold).
    """
    dname = args.target
    if dname in new_nex_datasets:
        index_root = data_root = '/import/home/share/from_Nexperia_April2021/%s' % dname
    else:
        index_root = join(dirname(__file__), 'correct_txt_lists')
        data_root = join(dirname(__file__), 'kfold')
    names, labels = _dataset_info(join(index_root, "%s_val.txt" % dname))
    img_tr = get_nex_val_transformer(args)
    val_dataset = JigsawTestNewDataset(data_root, names, labels, patches=patches,
                                       img_transformer=img_tr, jig_classes=30)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    # Fix: this validation loader used shuffle=True, inconsistent with every
    # other validation/test loader in this module; evaluation order should
    # be deterministic.
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                         num_workers=4, pin_memory=True, drop_last=False)
    return loader
def get_jigsaw_dataloader(args):
    """Loader over the raw target list for the jigsaw task (domain adaptation only)."""
    names, labels = _dataset_info(join(dirname(__file__), 'txt_lists', args.target + '.txt'))
    train_tr = get_train_transformers(args)
    jig_dataset = Dataset(names, labels, args.path_dataset, img_transformer=train_tr,
                          beta_scrambled=args.beta_scrambled, beta_rotated=args.beta_rotated,
                          beta_odd=args.beta_odd, rotation=False, odd=False)
    return torch.utils.data.DataLoader(ConcatDataset([jig_dataset]), batch_size=args.batch_size,
                                       shuffle=False, num_workers=4, pin_memory=True,
                                       drop_last=False)
def get_tgt_dataloader(args, patches=False):
    """Loader over the *entire* target domain (train + val + test splits)."""
    img_tr = get_nex_val_transformer(args)
    dname = args.target
    if dname in new_nex_datasets:
        index_root = data_root = '/import/home/share/from_Nexperia_April2021/%s' % dname
    else:
        index_root = join(dirname(__file__), 'correct_txt_lists')
        data_root = join(dirname(__file__), 'kfold')
    # optionally switch to the downsampled index files
    suffix = '_down' if args.downsample_target else ''
    name_train, labels_train = _dataset_info(join(index_root, "%s_train%s.txt" % (dname, suffix)))
    name_val, labels_val = _dataset_info(join(index_root, "%s_val%s.txt" % (dname, suffix)))
    name_test, labels_test = _dataset_info(join(index_root, "%s_test%s.txt" % (dname, suffix)))
    parts = [
        JigsawTestNewDataset(data_root, name_train, labels_train, patches=patches,
                             img_transformer=img_tr, jig_classes=30),
        JigsawTestNewDataset(data_root, name_val, labels_val, patches=patches,
                             img_transformer=img_tr, jig_classes=30),
        JigsawTestNewDataset(data_root, name_test, labels_test, patches=patches,
                             img_transformer=img_tr, jig_classes=30),
    ]
    return torch.utils.data.DataLoader(ConcatDataset(parts), batch_size=args.batch_size,
                                       shuffle=True, num_workers=4, pin_memory=True,
                                       drop_last=False)
def get_digital_target_dataloader(args, dname):
    """Test loader for a digits benchmark target domain.

    Args:
        args: config with ``limit_target`` and ``batch_size``.
        dname: one of 'mnist', 'svhn', 'usps', 'mnist_m', 'syn'.

    Raises:
        ValueError: for an unknown dataset name (the original silently
        built AdvDataset(None, []) and failed later).
    """
    loaders = {
        'mnist': load_mnist_dataset,
        'svhn': load_svhn_dataset,
        'usps': load_usps_dataset,
        'mnist_m': load_mnist_m_dataset,
        'syn': load_syn_dataset,
    }
    if dname not in loaders:
        raise ValueError("Unknown digits dataset: %r" % dname)
    images, labels = loaders[dname](split='test')
    dataset_digit = AdvDataset(images, labels)
    if args.limit_target and len(dataset_digit) > args.limit_target:
        dataset_digit = Subset(dataset_digit, args.limit_target)
    dataset = ConcatDataset([dataset_digit])
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                         num_workers=4, pin_memory=True, drop_last=False)
    return loader
def get_par_val_dataloader(args, patches=False):
    """Validation loader for the PAR task over the target's test split."""
    names, labels = _dataset_info(join(dirname(__file__), 'txt_lists',
                                       '%s_test.txt' % args.target))
    img_transformer = transforms.Compose(
        [transforms.Resize((args.image_size, args.image_size))])
    tile_transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_dataset = PARDataset(names, labels, patches=patches,
                             img_transformer=img_transformer,
                             tile_transformer=tile_transformer)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    return torch.utils.data.DataLoader(ConcatDataset([val_dataset]),
                                       batch_size=args.batch_size, shuffle=False,
                                       num_workers=4, pin_memory=True, drop_last=False)
def get_val_dataloader(args, patches=False):
    """Test loader for the target domain (PACS or ImageNet family).

    Raises:
        ValueError: if ``args.target`` belongs to neither dataset family
        (the original printed an error and then crashed with a NameError).
    """
    if args.target in pacs_datasets:
        names, labels = _dataset_info(join(dirname(__file__), 'txt_lists',
                                           '%s_test.txt' % args.target))
    elif args.target in imagenet_datasets:
        # fix: dropped the redundant nested join() wrappers around this path
        names, labels = _dataset_info(join(dirname(__file__), 'txt_lists',
                                           'imagenet_testDataPath.txt'))
    else:
        raise ValueError('test dataset not found: %r' % args.target)
    img_tr = get_val_transformer(args)
    val_dataset = PARTestDataset(names, labels, patches=patches, img_transformer=img_tr)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                         num_workers=4, pin_memory=True, drop_last=False)
    return loader
def get_single_test_dataloader(args, dname, patches=False):
    """Test loader for one named domain using its *_test.txt list."""
    names, labels = _dataset_info(join(dirname(__file__), 'txt_lists',
                                       '%s_test.txt' % dname))
    img_tr = get_val_transformer(args)
    # JigsawTestDataset yields [unsorted_image, permutation_order=0, class_label]
    test_dataset = JigsawTestDataset(names, labels, patches=patches, img_transformer=img_tr,
                                     jig_classes=args.jigsaw_n_classes)
    if args.limit_target and len(test_dataset) > args.limit_target:
        test_dataset = Subset(test_dataset, args.limit_target)
        print("Using %d subset of test dataset" % args.limit_target)
    return torch.utils.data.DataLoader(ConcatDataset([test_dataset]),
                                       batch_size=args.batch_size, shuffle=False,
                                       num_workers=4, pin_memory=True, drop_last=False)
def load_sentences(self, sentences, args, framework: str, language: str):
    """Build a loader that feeds ``sentences`` only to the matching
    (framework, language) child dataset; every other child receives an
    empty list, so the concatenation contains just the requested data.
    """
    parts = []
    for (f, l), child in self.child_datasets.items():
        payload = sentences if (f == framework and l == language) else []
        parts.append(child.load_sentences(payload, args, language))
    return torch.utils.data.DataLoader(ConcatDataset(parts), batch_size=1,
                                       shuffle=False, collate_fn=Collate())
def get_train_dataloader(args):
    """Train/val loaders; digit benchmarks delegate to the digital loader."""
    source_names = args.source
    assert isinstance(source_names, list)
    train_parts, val_parts = [], []
    img_tr = get_train_transformers(args)
    limit = args.limit_source
    for dname in source_names:
        # NOTE: returns immediately for a digits dataset, ignoring any
        # remaining entries in args.source (same behavior as before)
        if dname in digits_datasets:
            return get_digital_train_dataloader(args, dname)
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(
            join(dirname(__file__), 'txt_lists', '%s_train.txt' % dname), args.val_size)
        train_ds = JigsawDataset(name_train, labels_train, img_transformer=img_tr)
        if limit:
            train_ds = Subset(train_ds, limit)
        train_parts.append(train_ds)
        val_parts.append(JigsawDataset(name_val, labels_val,
                                       img_transformer=get_val_transformer(args)))
    loader = torch.utils.data.DataLoader(ConcatDataset(train_parts), batch_size=args.batch_size,
                                         shuffle=True, num_workers=4, pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(ConcatDataset(val_parts), batch_size=args.batch_size,
                                             shuffle=False, num_workers=4, pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
def append_adversarial_samples(args, data_loader, adv_data, adv_labels):
    """Return a new training loader whose dataset also contains the given
    adversarial samples.

    Args:
        data_loader: loader whose ``dataset`` is a ConcatDataset.
        adv_data / adv_labels: adversarial images and their labels.
    """
    # Fix: copy the child-dataset list.  Appending to
    # data_loader.dataset.datasets in place mutated the caller's
    # ConcatDataset without updating its cumulative sizes, silently
    # corrupting the original loader.
    datasets = list(data_loader.dataset.datasets)
    datasets.append(AdvDataset(adv_data, adv_labels))
    dataset = ConcatDataset(datasets)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                                         num_workers=4, pin_memory=True, drop_last=True)
    return loader
def get_val_dataloader(args):
    """Baseline validation loader over the target's test list."""
    names, labels = _dataset_info(join(dirname(__file__), 'txt_lists',
                                       '%s_test.txt' % args.target))
    val_dataset = BaselineDataset(names, labels, img_transformer=get_val_transformer(args))
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    return torch.utils.data.DataLoader(ConcatDataset([val_dataset]),
                                       batch_size=args.batch_size, shuffle=False,
                                       num_workers=4, pin_memory=True, drop_last=False)
def get_test_dataloader(args):
    """Test loader for the auxiliary-task dataset named by ``args.name``.

    Raises:
        ValueError: if ``args.type`` is not 'jigsaw' or 'rotate' (the
        original fell through to a NameError on ``val_dataset``).
    """
    img_tr = get_val_transformer(args)
    name = args.name
    if args.type == 'jigsaw':
        val_dataset = JigsawTestDataset(name, split='test', img_transformer=img_tr,
                                        jig_classes=args.aux_classes)
    elif args.type == 'rotate':
        val_dataset = RotateTestDataset(name, split='test', img_transformer=img_tr,
                                        rot_classes=args.aux_classes)
    else:
        raise ValueError("Unsupported auxiliary task type: %r" % args.type)
    if args.limit and len(val_dataset) > args.limit:
        val_dataset = Subset(val_dataset, args.limit)
        print("Using %d subset of dataset" % args.limit)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                         num_workers=4, pin_memory=True, drop_last=False)
    return loader
def get_val_dataloader(args):
    """Validation loader over the target list (no subsetting applied)."""
    names, labels = _dataset_info(join(dirname(__file__), 'txt_lists',
                                       args.target + '.txt'))
    val_dataset = TestDataset(names, labels, args.path_dataset,
                              img_transformer=get_val_transformer(args))
    return torch.utils.data.DataLoader(ConcatDataset([val_dataset]),
                                       batch_size=args.batch_size, shuffle=False,
                                       num_workers=4, pin_memory=True, drop_last=False)
def get_test_dataloader(args):
    """Test loader(s) for the dataset named by ``args.name``.

    With ``multi_crop`` enabled, returns one loader per crop transformer;
    otherwise a single loader built from the first transformer.

    Raises:
        ValueError: if ``args.type`` is not 'jigsaw', 'rotate' or 'image'
        (the original fell through to a NameError on ``val_dataset``).
    """
    name = args.name
    mode = args.get('mode', 'RGB')
    loaders = []
    for img_tr in get_multi_crop_transformers(args):
        if args.type == 'jigsaw':
            # NOTE(review): 'mode' is not forwarded in this branch, unlike
            # the other two — confirm whether that is intentional
            val_dataset = JigsawTestDataset(name, split='test', img_transformer=img_tr,
                                            jig_classes=args.aux_classes)
        elif args.type == 'rotate':
            val_dataset = RotateTestDataset(name, split='test', img_transformer=img_tr,
                                            rot_classes=args.aux_classes, mode=mode)
        elif args.type == 'image':
            val_dataset = ImageTestDataset(name, split='test', img_transformer=img_tr,
                                           mode=mode)
        else:
            raise ValueError("Unsupported test dataset type: %r" % args.type)
        if args.limit and len(val_dataset) > args.limit:
            val_dataset = Subset(val_dataset, args.limit)
            print("Using %d subset of dataset" % args.limit)
        dataset = ConcatDataset([val_dataset])
        loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                             num_workers=4, pin_memory=True, drop_last=False)
        if args.get('multi_crop', False):
            loaders.append(loader)
        else:
            return loader
    return loaders
def get_trainTargetAsSource_dataloader(args):
    """Loader that treats the target domain's images as a jigsaw training
    source — used only when training with domain adaptation."""
    names, labels = _dataset_info(join(dirname(__file__), 'txt_lists',
                                       args.target + '.txt'))
    train_ds = Dataset(names, labels, args.path_dataset,
                       img_transformer=get_train_transformers(args),
                       betaJigen=args.betaJigen, rotation=args.rotation,
                       oddOneOut=args.oddOneOut)
    return torch.utils.data.DataLoader(ConcatDataset([train_ds]),
                                       batch_size=args.batch_size, shuffle=True,
                                       num_workers=4, pin_memory=True, drop_last=True)
def get_jigsaw_val_dataloader(args, patches=False):
    """Jigsaw-task loader over the target test list (stylized or vanilla)."""
    if args.stylized:
        list_path = join(dirname(__file__), 'txt_lists', 'Stylized' + args.dataset,
                         "{}_target".format(args.target), '%s_test.txt' % args.target)
    else:
        list_path = join(dirname(__file__), 'txt_lists', 'Vanilla' + args.dataset,
                         '%s_test.txt' % args.target)
    names, labels = _dataset_info(list_path)
    img_transformer = transforms.Compose(
        [transforms.Resize((args.image_size, args.image_size))])
    tile_transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_dataset = JigsawDataset(names, labels, patches=patches,
                                img_transformer=img_transformer,
                                tile_transformer=tile_transformer,
                                jig_classes=args.jigsaw_n_classes,
                                bias_whole_image=args.bias_whole_image,
                                grid_size=args.grid_size)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    return torch.utils.data.DataLoader(ConcatDataset([val_dataset]),
                                       batch_size=args.batch_size, shuffle=False,
                                       num_workers=4, pin_memory=True, drop_last=False)
def get_train_dataloader(args, patches):
    """Train/val loaders for jigsaw training with stylized/vanilla lists.

    With ``args.jig_only`` set, validation also uses the training-style
    jigsaw dataset instead of the plain test dataset.
    """
    source_names = args.source
    assert isinstance(source_names, list)
    train_parts, val_parts = [], []
    img_tr, tile_tr = get_train_transformers(args)
    limit = args.limit_source
    for dname in source_names:
        if args.stylized:
            list_path = join(dirname(__file__), 'txt_lists', 'Stylized' + args.dataset,
                             "{}_target".format(args.target), '%s_train.txt' % dname)
        else:
            list_path = join(dirname(__file__), 'txt_lists', 'Vanilla' + args.dataset,
                             '%s_train.txt' % dname)
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(list_path,
                                                                               args.val_size)
        train_ds = JigsawDataset(name_train, labels_train, patches=patches,
                                 img_transformer=img_tr, tile_transformer=tile_tr,
                                 jig_classes=args.jigsaw_n_classes,
                                 bias_whole_image=args.bias_whole_image,
                                 grid_size=args.grid_size)
        if limit:
            train_ds = Subset(train_ds, limit)
        train_parts.append(train_ds)
        if args.jig_only:
            val_parts.append(JigsawDataset(name_val, labels_val, patches=patches,
                                           img_transformer=img_tr, tile_transformer=tile_tr,
                                           jig_classes=args.jigsaw_n_classes,
                                           bias_whole_image=args.bias_whole_image,
                                           grid_size=args.grid_size))
        else:
            val_parts.append(JigsawTestDataset(name_val, labels_val,
                                               img_transformer=get_val_transformer(args),
                                               patches=patches,
                                               jig_classes=args.jigsaw_n_classes))
    loader = torch.utils.data.DataLoader(ConcatDataset(train_parts), batch_size=args.batch_size,
                                         shuffle=True, num_workers=4, pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(ConcatDataset(val_parts), batch_size=args.batch_size,
                                             shuffle=False, num_workers=4, pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
def get_train_dataloader_JiGen(args, device, task="DG"):
    """Train/val loaders for JiGen in domain-generalization ("DG") or
    domain-adaptation ("DA") mode.

    In DA mode the target domain is excluded from the source list.  When
    the target still appears among the sources, its held-out val images are
    folded back into training and no validation dataset is built for it.

    Returns:
        (train_loader, val_loader)
    """
    # Fix: work on a copy — the original called .remove() on args.source
    # itself, permanently mutating the caller's configuration.
    dataset_list = list(args.source)
    assert isinstance(dataset_list, list)
    target = args.target
    if task == "DA" and target in dataset_list:
        dataset_list.remove(target)
    datasets = []
    val_datasets = []
    img_transformer, patch_transformer = get_train_transformers(args)
    val_transformer = get_val_transformer(args)
    for dname in dataset_list:
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(
            join(dirname(__file__), 'txt_lists', dname + '.txt'), args.val_size)
        if target == dname:
            # target used as source: keep all of its images for training
            name_train += name_val
            labels_train += labels_val
        datasets.append(JigsawDataset(name_train, labels_train, args.path_dataset,
                                      args.scrambled, args.rotated, args.odd,
                                      args.jigen_transf, args.grid_size,
                                      args.permutation_number, device, task,
                                      target_name=args.target,
                                      img_transformer=img_transformer,
                                      patch_transformer=patch_transformer))
        if target != dname:
            val_datasets.append(TestDataset(name_val, labels_val, args.path_dataset,
                                            img_transformer=val_transformer))
    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                                         num_workers=0, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                             shuffle=False, num_workers=0, pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
def get_train_dataloader(args, patches):
    """Train/val loaders for the PAR task over PACS or ImageNet sources.

    Raises:
        ValueError: if the first source belongs to neither dataset family
        (the original printed an error and then crashed on ConcatDataset([])).
    """
    dataset_list = args.source
    assert isinstance(dataset_list, list)
    datasets = []
    val_datasets = []
    img_transformer, tile_transformer = get_train_transformers(args)
    limit = args.limit_source
    if dataset_list[0] in pacs_datasets:
        for dname in dataset_list:
            name_train, name_val, labels_train, labels_val = get_split_dataset_info(
                join(dirname(__file__), 'txt_lists', '%s_train.txt' % dname), args.val_size)
            train_dataset = PARDataset(name_train, labels_train, patches=patches,
                                       img_transformer=img_transformer,
                                       tile_transformer=tile_transformer)
            if limit:
                train_dataset = Subset(train_dataset, limit)
            datasets.append(train_dataset)
            val_datasets.append(PARTestDataset(name_val, labels_val,
                                               img_transformer=get_val_transformer(args),
                                               patches=patches))
    elif dataset_list[0] in imagenet_datasets:
        # fix: dropped the redundant nested join() wrappers around these paths
        name_train, labels_train = _dataset_info(
            join(dirname(__file__), 'txt_lists', 'imagenet_trainDataPath.txt'))
        datasets.append(PARDataset(name_train, labels_train, patches=patches,
                                   img_transformer=img_transformer,
                                   tile_transformer=tile_transformer))
        name_val, labels_val = _dataset_info(
            join(dirname(__file__), 'txt_lists', 'imagenet_valDataPath.txt'))
        val_datasets.append(PARTestDataset(name_val, labels_val,
                                           img_transformer=get_val_transformer(args),
                                           patches=patches))
    else:
        raise ValueError('training dataset not found: %r' % dataset_list[0])
    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                                         num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                             shuffle=False, num_workers=4, pin_memory=True,
                                             drop_last=False)
    return loader, val_loader