Example #1
0
    def next_batch(self):
        """Fetch the next minibatch and return it with augmentation applied."""
        batch_size = self.hparams.batch_size
        if self.curr_train_index + batch_size > self.num_train:
            # Epoch boundary reached. reset() rewinds curr_train_index but
            # also clears the epoch counter, so preserve it across the call.
            completed_epochs = self.epochs + 1
            self.reset()
            self.epochs = completed_epochs

        start = self.curr_train_index
        stop = start + batch_size
        images = self.train_images[start:stop]
        labels = self.train_labels[start:stop]

        augmented = []
        for img in images:
            # Sample one of the learned augmentation policies uniformly.
            policy = self.good_policies[np.random.choice(
                len(self.good_policies))]
            out = augmentation_transforms.apply_policy(policy, img)
            out = augmentation_transforms.random_flip(
                augmentation_transforms.zero_pad_and_crop(out, 4))
            # Cutout is applied last, after flip/pad/crop.
            out = augmentation_transforms.cutout_numpy(out)
            augmented.append(out)

        self.curr_train_index += batch_size
        return (np.array(augmented, np.float32), labels)
Example #2
0
def data_augmentation(unsup):
    """Normalize a batch of images and produce policy-augmented copies.

    Returns a tuple (augmented_images, normalized_images); both are derived
    from `unsup` scaled to [0, 1] and standardized by the dataset mean/std.
    """
    normalized = unsup / 255.0
    mean, std = augmentation_transforms.get_mean_and_std()
    normalized = (normalized - mean) / std

    policies = found_policies.cifar10_policies()
    augmented = []
    for image in normalized:
        # Uniformly pick one learned CIFAR-10 policy per image.
        policy = policies[np.random.choice(len(policies))]
        out = augmentation_transforms.apply_policy(policy, image)
        out = augmentation_transforms.cutout_numpy(out)
        augmented.append(out)
    return np.array(augmented), normalized
Example #3
0
def proc_and_dump_unsup_data(sub_set_data, aug_copy_num):
  """Shuffle, normalize, augment unsupervised images and dump them as TFRecords.

  Args:
    sub_set_data: dict with an "images" entry (pixel values in [0, 255];
      presumably uint8 — confirm against the producer).
    aug_copy_num: integer identifying this augmentation copy; used to build
      the output filename via format_unsup_filename.

  Raises:
    ValueError: if FLAGS.task_name is not "cifar10" or "svhn".
  """
  ori_images = sub_set_data["images"].copy()

  # Shuffle so each augmentation copy visits images in a different order.
  image_idx = np.arange(len(ori_images))
  np.random.shuffle(image_idx)
  ori_images = ori_images[image_idx]

  # Scale to [0, 1], then standardize with the dataset mean/std.
  ori_images = ori_images / 255.0
  mean, std = augmentation_transforms.get_mean_and_std()
  ori_images = (ori_images - mean) / std

  if FLAGS.task_name == "cifar10":
    aug_policies = found_policies.cifar10_policies()
  elif FLAGS.task_name == "svhn":
    aug_policies = found_policies.svhn_policies()
  else:
    # Previously an unknown task fell through and the loop below raised a
    # confusing NameError on aug_policies; fail fast with a clear message.
    raise ValueError("Unsupported task_name: {}".format(FLAGS.task_name))

  example_list = []
  for image in ori_images:
    # One randomly chosen policy per image, followed by cutout.
    chosen_policy = aug_policies[np.random.choice(len(aug_policies))]
    aug_image = augmentation_transforms.apply_policy(chosen_policy, image)
    aug_image = augmentation_transforms.cutout_numpy(aug_image)

    # Store both views so a consumer can pair (original, augmented) images.
    example = tf.train.Example(features=tf.train.Features(
        feature={
            "ori_image": _float_feature(image.reshape(-1)),
            "aug_image": _float_feature(aug_image.reshape(-1)),
        }))
    example_list.append(example)

  out_path = os.path.join(
      FLAGS.output_base_dir,
      format_unsup_filename(aug_copy_num),
  )
  save_tfrecord(example_list, out_path)