Example #1
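All three examples assume module-level imports that are not shown. A plausible set, inferred from the identifiers used in the snippets (the exact module names for bit_hyperrule and the few-shot helper fs are assumptions):

import glob
from os.path import join as pjoin

import torch
import torchvision as tv

import bit_hyperrule      # provides get_resolution_from_dataset (assumed module name)
import fewshot as fs      # provides find_fewshot_indices (assumed module name)
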
def mkval(args):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if args.dataset == "cifar10":
        valid_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=val_tx,
                                        train=False,
                                        download=True)
    elif args.dataset == "cifar100":
        valid_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=val_tx,
                                         train=False,
                                         download=True)
    elif args.dataset == "imagenet2012":
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    micro_batch_size = args.batch_size // args.batch_split

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=micro_batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)

    return valid_set, valid_loader
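
A minimal usage sketch for mkval, assuming an argparse-style args object; the field names match what the function reads, but the values are illustrative only:

import argparse

# Hypothetical argument values for illustration.
args = argparse.Namespace(
    dataset="cifar10",
    datadir="/tmp/data",
    batch_size=128,
    batch_split=2,   # micro-batch size becomes 128 // 2 = 64
    workers=4,
)

valid_set, valid_loader = mkval(args)
print(f"{len(valid_set)} validation images in {len(valid_loader)} batches")

The batch_size // batch_split division mirrors the training code, where the full batch is presumably split into batch_split micro-batches (e.g. for gradient accumulation); the validation loader simply reuses that micro-batch size.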
Example #2
def mktrainval(args, logger):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        # Dataset-specific per-channel mean/std, replacing the generic
        # (0.5, 0.5, 0.5) normalization used in the reference implementation.
        tv.transforms.Normalize((0.43032281, 0.49672744, 0.3134248),
                                (0.08504857, 0.08000449, 0.10248923)),
    ])

    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.43032281, 0.49672744, 0.3134248),
                                (0.08504857, 0.08000449, 0.10248923)),
    ])

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=train_tx,
                                        train=True,
                                        download=True)
        valid_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=val_tx,
                                        train=False,
                                        download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=train_tx,
                                         train=True,
                                         download=True)
        valid_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=val_tx,
                                         train=False,
                                         download=True)
    elif args.dataset == "imagenet2012":
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"),
                                            train_tx)
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    if args.examples_per_class is not None:
        logger.info(
            f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch_size // args.batch_split

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=micro_batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=micro_batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=False)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            num_workers=args.workers,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(
                train_set, replacement=True, num_samples=micro_batch_size))

    return train_set, valid_set, train_loader, valid_loader
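
The hard-coded Normalize constants in this example look like per-channel mean/std statistics of the training images, presumably precomputed offline. A sketch of how such statistics can be computed (the helper name and batch size are illustrative, and it assumes the dataset yields (image, label) pairs transformed with ToTensor only):

import torch
import torchvision as tv

def channel_stats(dataset, batch_size=256, num_workers=4):
    """Per-channel mean/std over all pixels of a dataset (hypothetical helper)."""
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         num_workers=num_workers, shuffle=False)
    n, total, total_sq = 0, 0.0, 0.0
    for x, _ in loader:
        x = x.flatten(2)                      # (B, C, H*W)
        n += x.shape[0] * x.shape[2]          # pixels counted per channel
        total = total + x.sum(dim=(0, 2))
        total_sq = total_sq + (x ** 2).sum(dim=(0, 2))
    mean = total / n
    std = (total_sq / n - mean ** 2).sqrt()
    return mean, std

# e.g. mean, std = channel_stats(tv.datasets.CIFAR10("/tmp/data", train=True,
#                                                    download=True,
#                                                    transform=tv.transforms.ToTensor()))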
Example #3
def mktrainval(args, logger):
    """Returns train and validation datasets."""
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    if args.input_channels == 3:
        train_tx = tv.transforms.Compose([
            tv.transforms.Resize((precrop, precrop)),
            tv.transforms.RandomCrop((crop, crop)),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        val_tx = tv.transforms.Compose([
            tv.transforms.Resize((crop, crop)),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    elif args.input_channels == 2:
        train_tx = tv.transforms.Compose([
            tv.transforms.Resize((precrop, precrop)),
            tv.transforms.RandomCrop((crop, crop)),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5), (0.5, 0.5)),
        ])

        val_tx = tv.transforms.Compose([
            tv.transforms.Resize((crop, crop)),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5), (0.5, 0.5)),
        ])

    elif args.input_channels == 1:
        train_tx = tv.transforms.Compose([
            tv.transforms.Resize((precrop, precrop)),
            tv.transforms.RandomCrop((crop, crop)),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5,), (0.5,)),
        ])

        val_tx = tv.transforms.Compose([
            tv.transforms.Resize((crop, crop)),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5,), (0.5,)),
        ])

    else:
        raise ValueError(
            f"Unsupported number of input channels: {args.input_channels}")

    if args.dataset == "cifar10":
        train_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=train_tx,
                                        train=True,
                                        download=True)
        valid_set = tv.datasets.CIFAR10(args.datadir,
                                        transform=val_tx,
                                        train=False,
                                        download=True)
    elif args.dataset == "cifar100":
        train_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=train_tx,
                                         train=True,
                                         download=True)
        valid_set = tv.datasets.CIFAR100(args.datadir,
                                         transform=val_tx,
                                         train=False,
                                         download=True)

    elif args.dataset == "imagenet2012":

        folder_path = pjoin(args.datadir, "train")
        files = sorted(glob.glob("%s/*/*.*" % folder_path))
        #labels = [int(file.split("/")[-2]) for file in files]
        labels = [class_dict[file.split("/")[-2]] for file in files]
        train_set = ImageFolder(files, labels, train_tx, crop)

        folder_path = pjoin(args.datadir, "val")
        files = sorted(glob.glob("%s/*/*.*" % folder_path))
        #labels = [int(file.split("/")[-2]) for file in files]
        labels = [class_dict[file.split("/")[-2]] for file in files]
        valid_set = ImageFolder(files, labels, val_tx, crop)
        #train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"), train_tx)
        #valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    if args.examples_per_class is not None:
        logger.info(
            f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    micro_batch_size = args.batch_size // args.batch_split
    # Evaluation loaders can use a larger batch, since no gradients are kept.
    micro_batch_size_val = 4 * micro_batch_size

    valid_loader = torch.utils.data.DataLoader(valid_set,
                                               batch_size=micro_batch_size_val,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)

    if micro_batch_size <= len(train_set):
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=micro_batch_size,
                                                   shuffle=True,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=False)
        train_loader_val = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size_val,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=False)
    else:
        # In the few-shot cases, the total dataset size might be smaller than the batch-size.
        # In these cases, the default sampler doesn't repeat, so we need to make it do that
        # if we want to match the behaviour from the paper.
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size,
            num_workers=args.workers,
            pin_memory=True,
            sampler=torch.utils.data.RandomSampler(
                train_set, replacement=True, num_samples=micro_batch_size))
        # Build the train-set evaluation loader on this path as well, so the
        # return statement below always has a defined train_loader_val.
        train_loader_val = torch.utils.data.DataLoader(
            train_set,
            batch_size=micro_batch_size_val,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=False)

    return train_set, valid_set, train_loader, valid_loader, train_loader_val
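
Example #3 depends on two names defined elsewhere in the module: class_dict, which presumably maps each class directory name to an integer label, and a custom ImageFolder dataset taking (files, labels, transform, crop). A minimal stand-in consistent with those call sites (the real class may handle channels and cropping differently):

from PIL import Image
import torch

class ImageFolder(torch.utils.data.Dataset):
    """Hypothetical stand-in for the custom ImageFolder used above."""

    def __init__(self, files, labels, transform, crop):
        self.files = files
        self.labels = labels
        self.transform = transform
        self.crop = crop  # kept to match the call sites; unused in this sketch

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # The real implementation may adapt the channel count to args.input_channels.
        img = Image.open(self.files[idx]).convert("RGB")
        return self.transform(img), self.labels[idx]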