def mktrainval(args, logger):
    """Build train/validation datasets and loaders for ``args.dataset``.

    Resolution comes from the BiT hyper-rule; training uses the usual
    resize / random-crop / flip augmentation, validation only a resize.
    Returns ``(train_set, valid_set, train_loader, valid_loader)``.
    """
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    # Both pipelines share the same [-1, 1] normalization.
    normalize = tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        normalize,
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        normalize,
    ])

    if args.dataset in ("cifar10", "cifar100"):
        dataset_cls = (tv.datasets.CIFAR10 if args.dataset == "cifar10"
                       else tv.datasets.CIFAR100)
        train_set = dataset_cls(args.datadir, transform=train_tx,
                                train=True, download=True)
        valid_set = dataset_cls(args.datadir, transform=val_tx,
                                train=False, download=True)
    elif args.dataset == "imagenet2012":
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"), train_tx)
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    # Optional few-shot restriction of the training set.
    if args.examples_per_class is not None:
        logger.info(f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch // args.batch_split

    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=micro_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, drop_last=False)
    else:
        # Few-shot case: the dataset can be smaller than the batch size.
        # The default sampler would not repeat examples, so sample with
        # replacement to match the behaviour from the paper.
        sampler = torch.utils.data.RandomSampler(
            train_set, replacement=True, num_samples=micro_batch_size)
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, num_workers=args.workers,
            pin_memory=True, sampler=sampler)

    return train_set, valid_set, train_loader, valid_loader
def _mktrainval(args, logger):
    """Returns train and validation datasets.

    Extended variant of ``mktrainval``: besides CIFAR/ImageNet it supports
    objectnet/imageneta folders (optionally filtered by bounding-box ratio)
    and an inpainting pipeline with a custom collate function.  Returns
    ``(n_train, num_classes, train_loader, valid_loader)``.
    """
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
    if args.test_run:  # save memory
        precrop, crop = 64, 56

    # Standard augmentation for training; plain resize for validation.
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Defaults; overridden only by the inpainting branch below.
    collate_fn = None
    n_train = None
    micro_batch_size = args.batch // args.batch_split

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(args.datadir, transform=train_tx, train=True, download=True)
        valid_set = tv.datasets.CIFAR10(args.datadir, transform=val_tx, train=False, download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(args.datadir, transform=train_tx, train=True, download=True)
        valid_set = tv.datasets.CIFAR100(args.datadir, transform=val_tx, train=False, download=True)
    elif args.dataset == "imagenet2012":
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"), transform=train_tx)
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), transform=val_tx)
    elif args.dataset.startswith('objectnet') or args.dataset.startswith('imageneta'):
        # objectnet and objectnet_bbox and objectnet_no_bbox
        identifier = 'objectnet' if args.dataset.startswith('objectnet') else 'imageneta'
        # Validation always comes from the full folder for this identifier.
        valid_set = tv.datasets.ImageFolder(f"../datasets/{identifier}/", transform=val_tx)

        if args.inpaint == 'none':
            if args.dataset == 'objectnet' or args.dataset == 'imageneta':
                # Plain variant: use the whole training folder.
                train_set = tv.datasets.ImageFolder(pjoin(args.datadir, f"train_{args.dataset}"), transform=train_tx)
            else:
                # For only images with or w/o bounding box
                # The CSV lists image ids with their bbox size ratio;
                # keep only ids whose ratio is within args.bbox_max_ratio.
                train_bbox_file = '../datasets/imagenet/LOC_train_solution_size.csv'
                df = pd.read_csv(train_bbox_file)
                filenames = set(df[df.bbox_ratio <= args.bbox_max_ratio].ImageId)
                # Select files by membership of their basename (sans
                # extension) in the filtered id set.
                if args.dataset == f"{identifier}_no_bbox":
                    is_valid_file = lambda path: os.path.basename(path).split('.')[0] not in filenames
                elif args.dataset == f"{identifier}_bbox":
                    is_valid_file = lambda path: os.path.basename(path).split('.')[0] in filenames
                else:
                    raise NotImplementedError()

                train_set = tv.datasets.ImageFolder(
                    pjoin(args.datadir, f"train_{identifier}"),
                    is_valid_file=is_valid_file,
                    transform=train_tx)
        else:  # do inpainting
            # Rebuild the train transform with data_utils variants —
            # presumably bbox-aware versions of the torchvision ops;
            # TODO(review): confirm against data_utils.
            train_tx = tv.transforms.Compose([
                data_utils.Resize((precrop, precrop)),
                data_utils.RandomCrop((crop, crop)),
                data_utils.RandomHorizontalFlip(),
                data_utils.ToTensor(),
                data_utils.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
            train_set = ImagenetBoundingBoxFolder(
                root=f"../datasets/imagenet/train_{identifier}",
                bbox_file='../datasets/imagenet/LOC_train_solution.csv',
                transform=train_tx)
            collate_fn = bbox_collate
            # Each sample appears to expand to two images under
            # bbox_collate, hence double the count and halve the
            # micro-batch — confirm against bbox_collate.
            n_train = len(train_set) * 2
            micro_batch_size //= 2
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    # Optional few-shot restriction of the training set.
    if args.examples_per_class is not None:
        logger.info(f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=micro_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, drop_last=False,
            collate_fn=collate_fn)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, num_workers=args.workers,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(train_set, replacement=True,
                                                   num_samples=micro_batch_size),
            collate_fn=collate_fn)

    # n_train was pre-set only by the inpainting branch; otherwise it is
    # just the (possibly subsetted) training-set size.
    if n_train is None:
        n_train = len(train_set)

    return n_train, len(valid_set.classes), train_loader, valid_loader