def _create_mapping_loader(config, dataset_class, tf3, partition, truncate=False, truncate_pc=None, tencrop=False, shuffle=False): if truncate: print("Note: creating mapping loader with truncate == True") if tencrop: assert (tf3 is None) imgs_list = [] if config.test_on_all_frame and partition == "test": for i in xrange(10): imgs_curr = dataset_class(root=config.dataset_root, transform=tf3, frame=i, crop=config.crop_by_bb, partition=partition) if truncate: print("shrinking dataset from %d" % len(imgs_curr)) imgs_curr = TruncatedDataset(imgs_curr, pc=truncate_pc) print("... to %d" % len(imgs_curr)) if tencrop: imgs_curr = TenCropAndFinish(imgs_curr, input_sz=config.input_sz, include_rgb=config.include_rgb) imgs_list.append(imgs_curr) else: for i in xrange(config.base_num): imgs_curr = dataset_class(root=config.dataset_root, transform=tf3, frame=config.base_frame + config.base_interval * i, crop=config.crop_by_bb, partition=partition) if truncate: print("shrinking dataset from %d" % len(imgs_curr)) imgs_curr = TruncatedDataset(imgs_curr, pc=truncate_pc) print("... to %d" % len(imgs_curr)) if tencrop: imgs_curr = TenCropAndFinish(imgs_curr, input_sz=config.input_sz, include_rgb=config.include_rgb) imgs_list.append(imgs_curr) imgs = ConcatDataset(imgs_list) dataloader = torch.utils.data.DataLoader(imgs, batch_size=config.batch_sz, # full batch shuffle=shuffle, num_workers=0, drop_last=False) if not shuffle: assert (isinstance(dataloader.sampler, torch.utils.data.sampler.SequentialSampler)) return dataloader
def _create_mapping_loader(config, dataset_class, tf3, partitions, target_transform=None, truncate=False, truncate_pc=None, tencrop=False, shuffle=False): if truncate: print("Note: creating mapping loader with truncate == True") if tencrop: assert (tf3 is None) imgs_list = [] for partition in partitions: if "STL10" == config.dataset: imgs_curr = dataset_class(root=config.dataset_root, transform=tf3, split=partition, target_transform=target_transform) elif config.dataset == "MNIST-adv": imgs_curr = dataset_class( root=config.dataset_root, transform=tf3, train=partition, target_transform=target_transform) + AdversarialDataset( config.adv_path, config.adv_n) else: imgs_curr = dataset_class(root=config.dataset_root, transform=tf3, train=partition, target_transform=target_transform) if truncate: print("shrinking dataset from %d" % len(imgs_curr)) imgs_curr = TruncatedDataset(imgs_curr, pc=truncate_pc) print("... to %d" % len(imgs_curr)) if tencrop: imgs_curr = TenCropAndFinish(imgs_curr, input_sz=config.input_sz, include_rgb=config.include_rgb) imgs_list.append(imgs_curr) imgs = ConcatDataset(imgs_list) dataloader = torch.utils.data.DataLoader( imgs, batch_size=config.batch_sz, # full batch shuffle=shuffle, num_workers=0, drop_last=False) if not shuffle: assert (isinstance(dataloader.sampler, torch.utils.data.sampler.SequentialSampler)) return dataloader