def next_batch(self):
    """Return the next minibatch of augmented (images, labels).

    Advances the internal cursor by one batch; when the epoch is
    exhausted, resets the dataset while preserving the epoch count.
    """
    next_train_index = self.curr_train_index + self.hparams.batch_size
    if next_train_index > self.num_train:
        # reset() clears the epoch counter, so carry the increment across it.
        completed_epochs = self.epochs + 1
        self.reset()
        self.epochs = completed_epochs
    start = self.curr_train_index
    end = start + self.hparams.batch_size
    images = self.train_images[start:end]
    labels = self.train_labels[start:end]
    augmented = []
    for image in images:
        # One randomly chosen AutoAugment policy per image.
        policy = self.good_policies[np.random.choice(len(self.good_policies))]
        out = augmentation_transforms.apply_policy(policy, image)
        out = augmentation_transforms.random_flip(
            augmentation_transforms.zero_pad_and_crop(out, 4))
        # Cutout goes last, after all geometric transforms.
        out = augmentation_transforms.cutout_numpy(out)
        augmented.append(out)
    self.curr_train_index += self.hparams.batch_size
    return np.array(augmented, np.float32), labels
def data_augmentation(unsup):
    """Normalize images and build one augmented copy of each.

    Returns a tuple ``(augmented_images, normalized_originals)``; the
    input is scaled to [0, 1] and standardized with the dataset
    mean/std before augmentation.
    """
    normalized = unsup / 255.0
    mean, std = augmentation_transforms.get_mean_and_std()
    normalized = (normalized - mean) / std
    aug_policies = found_policies.cifar10_policies()

    def _augment(image):
        # Random AutoAugment policy, then cutout on top of it.
        policy = aug_policies[np.random.choice(len(aug_policies))]
        return augmentation_transforms.cutout_numpy(
            augmentation_transforms.apply_policy(policy, image))

    augs = [_augment(image) for image in normalized]
    return np.array(augs), normalized
def proc_and_dump_unsup_data(sub_set_data, aug_copy_num):
    """Shuffle, normalize, augment and dump one copy of the unsup set.

    Writes a tfrecord file containing, per image, both the normalized
    original and one augmented version (policy + cutout), each
    flattened into a float feature.

    Args:
      sub_set_data: dict with an "images" entry of images in [0, 255].
      aug_copy_num: index of this augmentation copy; used to format the
        output filename.

    Raises:
      ValueError: if FLAGS.task_name has no known augmentation policies.
    """
    ori_images = sub_set_data["images"].copy()
    image_idx = np.arange(len(ori_images))
    np.random.shuffle(image_idx)
    ori_images = ori_images[image_idx]
    # tf.logging.info("first 5 indexes after shuffling: {}".format(
    #     str(image_idx[:5])))
    # Scale to [0, 1] then standardize with the dataset mean/std.
    ori_images = ori_images / 255.0
    mean, std = augmentation_transforms.get_mean_and_std()
    ori_images = (ori_images - mean) / std
    if FLAGS.task_name == "cifar10":
        aug_policies = found_policies.cifar10_policies()
    elif FLAGS.task_name == "svhn":
        aug_policies = found_policies.svhn_policies()
    else:
        # Fail fast with a clear message instead of hitting an
        # UnboundLocalError on aug_policies in the loop below.
        raise ValueError(
            "Unsupported task_name for augmentation: {}".format(
                FLAGS.task_name))
    example_list = []
    for image in ori_images:
        chosen_policy = aug_policies[np.random.choice(len(aug_policies))]
        aug_image = augmentation_transforms.apply_policy(chosen_policy, image)
        aug_image = augmentation_transforms.cutout_numpy(aug_image)
        # Store both the clean and the augmented image, flattened.
        example = tf.train.Example(features=tf.train.Features(
            feature={
                "ori_image": _float_feature(image.reshape(-1)),
                "aug_image": _float_feature(aug_image.reshape(-1)),
            }))
        example_list.append(example)
    out_path = os.path.join(
        FLAGS.output_base_dir,
        format_unsup_filename(aug_copy_num),
    )
    save_tfrecord(example_list, out_path)