Example #1
0
def mktrainval(args, logger):
    """Build the train/validation datasets and their data loaders.

    Returns a 4-tuple ``(train_set, valid_set, train_loader, valid_loader)``.
    Supports cifar10, cifar100 and imagenet2012; raises ValueError otherwise.
    """
    precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)

    # Both splits share the same per-channel normalization; the train split
    # additionally gets random-crop + horizontal-flip augmentation.
    normalize = tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_tx = tv.transforms.Compose([
        tv.transforms.Resize((precrop, precrop)),
        tv.transforms.RandomCrop((crop, crop)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        normalize,
    ])
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        normalize,
    ])

    # Dataset dispatch: the CIFAR variants download themselves; ImageNet is
    # expected to already exist on disk as train/ and val/ folders.
    cifar_classes = {"cifar10": tv.datasets.CIFAR10,
                     "cifar100": tv.datasets.CIFAR100}
    if args.dataset in cifar_classes:
        dataset_cls = cifar_classes[args.dataset]
        train_set = dataset_cls(args.datadir, transform=train_tx,
                                train=True, download=True)
        valid_set = dataset_cls(args.datadir, transform=val_tx,
                                train=False, download=True)
    elif args.dataset == "imagenet2012":
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"),
                                            train_tx)
        valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), val_tx)
    else:
        raise ValueError(f"Sorry, we have not spent time implementing the "
                         f"{args.dataset} dataset in the PyTorch codebase. "
                         f"In principle, it should be easy to add :)")

    # Optional few-shot regime: restrict the train set to N images per class.
    if args.examples_per_class is not None:
        logger.info(
            f"Looking for {args.examples_per_class} images per class...")
        indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
        train_set = torch.utils.data.Subset(train_set, indices=indices)

    logger.info(f"Using a training set with {len(train_set)} images.")
    logger.info(f"Using a validation set with {len(valid_set)} images.")

    # Each optimizer step is accumulated over `batch_split` micro-batches.
    micro_batch_size = args.batch // args.batch_split

    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=micro_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    if len(train_set) < micro_batch_size:
        # In the few-shot cases, the total dataset size might be smaller than
        # the batch-size. In these cases, the default sampler doesn't repeat,
        # so we need to make it do that if we want to match the behaviour from
        # the paper.
        repeat_sampler = torch.utils.data.RandomSampler(
            train_set, replacement=True, num_samples=micro_batch_size)
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, num_workers=args.workers,
            pin_memory=True, sampler=repeat_sampler)
    else:
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=micro_batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True, drop_last=False)

    return train_set, valid_set, train_loader, valid_loader
Example #2
0
def _mktrainval(args, logger):
  """Builds train/validation datasets and data loaders.

  Returns:
    (n_train, n_classes, train_loader, valid_loader): the effective number of
    training samples per epoch, the number of classes in the validation set,
    and the two loaders.
  """
  precrop, crop = bit_hyperrule.get_resolution_from_dataset(args.dataset)
  if args.test_run: # save memory: use tiny resolutions for smoke tests
    precrop, crop = 64, 56

  # Train split: resize -> random crop -> flip augmentation. ToTensor maps
  # pixels to [0, 1]; Normalize with mean=std=0.5 then maps them to [-1, 1].
  train_tx = tv.transforms.Compose([
      tv.transforms.Resize((precrop, precrop)),
      tv.transforms.RandomCrop((crop, crop)),
      tv.transforms.RandomHorizontalFlip(),
      tv.transforms.ToTensor(),
      tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  ])
  val_tx = tv.transforms.Compose([
      tv.transforms.Resize((crop, crop)),
      tv.transforms.ToTensor(),
      tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  ])

  # Defaults; only the inpainting branch below overrides all three.
  collate_fn = None
  n_train = None
  micro_batch_size = args.batch // args.batch_split
  if args.dataset == "cifar10":
    train_set = tv.datasets.CIFAR10(args.datadir, transform=train_tx, train=True, download=True)
    valid_set = tv.datasets.CIFAR10(args.datadir, transform=val_tx, train=False, download=True)
  elif args.dataset == "cifar100":
    train_set = tv.datasets.CIFAR100(args.datadir, transform=train_tx, train=True, download=True)
    valid_set = tv.datasets.CIFAR100(args.datadir, transform=val_tx, train=False, download=True)
  elif args.dataset == "imagenet2012":
    train_set = tv.datasets.ImageFolder(pjoin(args.datadir, "train"), transform=train_tx)
    valid_set = tv.datasets.ImageFolder(pjoin(args.datadir, "val"), transform=val_tx)
  elif args.dataset.startswith('objectnet') or args.dataset.startswith('imageneta'): # objectnet and objectnet_bbox and objectnet_no_bbox
    # Evaluation always uses the full objectnet/imageneta folder; only the
    # training data varies between the sub-variants below.
    # NOTE(review): '../datasets/...' paths are hard-coded relative to the
    # working directory rather than derived from args.datadir.
    identifier = 'objectnet' if args.dataset.startswith('objectnet') else 'imageneta'
    valid_set = tv.datasets.ImageFolder(f"../datasets/{identifier}/", transform=val_tx)

    if args.inpaint == 'none':
      if args.dataset == 'objectnet' or args.dataset == 'imageneta':
        # Plain variant: train on the full per-dataset training folder.
        train_set = tv.datasets.ImageFolder(pjoin(args.datadir, f"train_{args.dataset}"),
                                            transform=train_tx)
      else: # For only images with or w/o bounding box
        # The CSV lists image ids with their bbox-to-image area ratio;
        # images whose ratio is at most bbox_max_ratio count as "has bbox".
        train_bbox_file = '../datasets/imagenet/LOC_train_solution_size.csv'
        df = pd.read_csv(train_bbox_file)
        filenames = set(df[df.bbox_ratio <= args.bbox_max_ratio].ImageId)
        # ImageId is compared against the file name without its extension.
        if args.dataset == f"{identifier}_no_bbox":
          is_valid_file = lambda path: os.path.basename(path).split('.')[0] not in filenames
        elif args.dataset == f"{identifier}_bbox":
          is_valid_file = lambda path: os.path.basename(path).split('.')[0] in filenames
        else:
          raise NotImplementedError()

        train_set = tv.datasets.ImageFolder(
          pjoin(args.datadir, f"train_{identifier}"),
          is_valid_file=is_valid_file,
          transform=train_tx)
    else: # do inpainting
      # Rebuild the train transform with data_utils counterparts — presumably
      # bbox-aware versions of the torchvision transforms above so the boxes
      # stay in sync with the image ops; confirm in data_utils.
      train_tx = tv.transforms.Compose([
        data_utils.Resize((precrop, precrop)),
        data_utils.RandomCrop((crop, crop)),
        data_utils.RandomHorizontalFlip(),
        data_utils.ToTensor(),
        data_utils.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
      ])

      train_set = ImagenetBoundingBoxFolder(
        root=f"../datasets/imagenet/train_{identifier}",
        bbox_file='../datasets/imagenet/LOC_train_solution.csv',
        transform=train_tx)
      # bbox_collate presumably batches (image, bbox) samples; each image
      # appears to yield two training samples (hence n_train doubling and the
      # halved micro-batch to keep the per-step sample count) — TODO confirm.
      collate_fn = bbox_collate
      n_train = len(train_set) * 2
      micro_batch_size //= 2

  else:
    raise ValueError(f"Sorry, we have not spent time implementing the "
                     f"{args.dataset} dataset in the PyTorch codebase. "
                     f"In principle, it should be easy to add :)")

  # Optional few-shot regime: keep only N images per class.
  # NOTE(review): if this runs after the inpainting branch, n_train was
  # already computed from the full train_set and becomes stale — confirm
  # these two options are never combined.
  if args.examples_per_class is not None:
    logger.info(f"Looking for {args.examples_per_class} images per class...")
    indices = fs.find_fewshot_indices(train_set, args.examples_per_class)
    train_set = torch.utils.data.Subset(train_set, indices=indices)

  logger.info(f"Using a training set with {len(train_set)} images.")
  logger.info(f"Using a validation set with {len(valid_set)} images.")

  valid_loader = torch.utils.data.DataLoader(
      valid_set, batch_size=micro_batch_size, shuffle=False,
      num_workers=args.workers, pin_memory=True, drop_last=False)

  if micro_batch_size <= len(train_set):
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=micro_batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=False,
        collate_fn=collate_fn)
  else:
    # In the few-shot cases, the total dataset size might be smaller than the batch-size.
    # In these cases, the default sampler doesn't repeat, so we need to make it do that
    # if we want to match the behaviour from the paper.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=micro_batch_size, num_workers=args.workers, pin_memory=True,
        sampler=torch.utils.data.RandomSampler(train_set, replacement=True, num_samples=micro_batch_size),
        collate_fn=collate_fn)

  if n_train is None:
    n_train = len(train_set)
  return n_train, len(valid_set.classes), train_loader, valid_loader