def mkval(args):
    """Build the validation dataset and its data loader.

    Args:
        args: Parsed options. Reads ``dataset``, ``datadir``, ``batch_size``,
            ``batch_split`` and ``workers``.

    Returns:
        A ``(valid_set, valid_loader)`` tuple.

    Raises:
        ValueError: If ``args.dataset`` is not one of ``"cifar10"``,
            ``"cifar100"`` or ``"imagenet2012"``.
    """
    # Only the final crop size is needed here; the pre-crop size is used by
    # training-time augmentation, which this function does not build.
    _, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset == "cifar10":
        valid_set = tv.datasets.CIFAR10(
            args.datadir, transform=val_tx, train=False, download=True)
    elif args.dataset == "cifar100":
        valid_set = tv.datasets.CIFAR100(
            args.datadir, transform=val_tx, train=False, download=True)
    elif args.dataset == "imagenet2012":
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    # NOTE: few-shot sub-sampling (args.examples_per_class) is deliberately
    # not applied here. It targets the training set, which this function
    # never creates; the earlier code referenced an undefined `train_set`
    # and crashed with NameError whenever that option was set.

    micro_batch_size = args.batch_size // args.batch_split

    valid_loader = torch.utils.data.DataLoader(
        valid_set,
        batch_size=micro_batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )

    return valid_set, valid_loader
def mktrainval(args, logger):
    """Create the train/validation datasets and a loader for each.

    Args:
        args: Parsed options. Reads ``dataset``, ``datadir``,
            ``examples_per_class``, ``batch_size``, ``batch_split`` and
            ``workers``.
        logger: Logger whose ``info`` method receives progress messages.

    Returns:
        A ``(train_set, valid_set, train_loader, valid_loader)`` tuple.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    # NOTE(review): these look like per-channel statistics of one specific
    # dataset rather than the generic 0.5/0.5 values -- confirm the source.
    norm_mean = (0.43032281, 0.49672744, 0.3134248)
    norm_std = (0.08504857, 0.08000449, 0.10248923)

    # Training: resize, random crop, random horizontal flip.
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(norm_mean, norm_std),
    ])
    # Validation: deterministic resize straight to the crop size.
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(norm_mean, norm_std),
    ])

    name = args.dataset
    if name in ("cifar10", "cifar100"):
        if name == "cifar10":
            cifar_cls = tv.datasets.CIFAR10
        else:
            cifar_cls = tv.datasets.CIFAR100
        train_set = cifar_cls(
            args.datadir, transform=train_tx, train=True, download=True)
        valid_set = cifar_cls(
            args.datadir, transform=val_tx, train=False, download=True)
    elif name == "imagenet2012":
        train_set = tv.datasets.ImageFolder(
            pjoin(args.datadir, "train"), train_tx)
        valid_set = tv.datasets.ImageFolder(
            pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    # Optional few-shot mode: restrict training to the indices returned by
    # fs.find_fewshot_indices.
    if args.examples_per_class is not None:
        logger.info(
            f"Looking for {args.examples_per_class} images per class...")
        fewshot_idx = fs.find_fewshot_indices(
            train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=fewshot_idx)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch_size // args.batch_split
    shared_kwargs = {"num_workers": args.workers, "pin_memory": True}

    valid_loader = torch.utils.data.DataLoader(
        valid_set,
        batch_size=micro_batch_size,
        shuffle=False,
        drop_last=False,
        **shared_kwargs,
    )

    if len(train_set) < micro_batch_size:
        # In the few-shot cases, the total dataset size might be smaller than
        # the batch-size. The default sampler doesn't repeat, so sample with
        # replacement to match the behaviour from the paper.
        repeating_sampler = torch.utils.data.RandomSampler(
            train_set, replacement=True, num_samples=micro_batch_size)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            sampler=repeating_sampler,
            **shared_kwargs,
        )
    else:
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            shuffle=True,
            drop_last=False,
            **shared_kwargs,
        )

    return train_set, valid_set, train_loader, valid_loader
def mktrainval(args, logger):
    """Create train/validation datasets and their data loaders.

    Args:
        args: Parsed options. Reads ``dataset``, ``datadir``,
            ``input_channels``, ``examples_per_class``, ``batch_size``,
            ``batch_split`` and ``workers``.
        logger: Logger whose ``info`` method receives progress messages.

    Returns:
        A ``(train_set, valid_set, train_loader, valid_loader,
        train_loader_val)`` tuple. ``train_loader_val`` iterates the training
        set with the validation micro-batch size (4x the training
        micro-batch size).

    Raises:
        ValueError: If ``args.input_channels`` is not a positive number, or
            if ``args.dataset`` is not a supported dataset name.
    """
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    if args.input_channels < 1:
        raise ValueError(
            f"input_channels must be a positive integer, "
            f"got {args.input_channels!r}")

    # One 0.5 mean/std entry per channel. This replaces three copy-pasted
    # branches for 1, 2 and 3 channels, which left the transforms undefined
    # for any other channel count.
    channel_stats = (0.5,) * args.input_channels

    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(channel_stats, channel_stats),
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(channel_stats, channel_stats),
    ])

    def _folder_dataset(split, transform):
        # Collect "<datadir>/<split>/<class_dir>/<file>" paths in sorted order
        # and map each file's parent directory name to a label via the
        # module-level `class_dict`.
        folder_path = pjoin(args.datadir, split)
        files = sorted(glob.glob("%s/*/*.*" % folder_path))
        labels = [class_dict[path.split("/")[-2]] for path in files]
        return ImageFolder(files, labels, transform, crop)

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(
            args.datadir, transform=train_tx, train=True, download=True)
        valid_set = tv.datasets.CIFAR10(
            args.datadir, transform=val_tx, train=False, download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(
            args.datadir, transform=train_tx, train=True, download=True)
        valid_set = tv.datasets.CIFAR100(
            args.datadir, transform=val_tx, train=False, download=True)
    elif args.dataset == "imagenet2012":
        train_set = _folder_dataset("train", train_tx)
        valid_set = _folder_dataset("val", val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    if args.examples_per_class is not None:
        logger.info(
            f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch_size // args.batch_split
    micro_batch_size_val = 4 * micro_batch_size

    valid_loader = torch.utils.data.DataLoader(
        valid_set,
        batch_size=micro_batch_size_val,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=False,
        )
    else:
        # In the few-shot cases, the total dataset size might be smaller than
        # the batch-size. The default sampler doesn't repeat, so sample with
        # replacement to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            num_workers=args.workers,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(
                train_set, replacement=True, num_samples=micro_batch_size),
        )

    # Built outside the branch above: it was previously created only in the
    # non-few-shot branch, so the few-shot path failed at `return` with
    # UnboundLocalError. Its construction does not depend on that branch.
    train_loader_val = torch.utils.data.DataLoader(
        train_set,
        batch_size=micro_batch_size_val,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )

    return train_set, valid_set, train_loader, valid_loader, train_loader_val