Example #1
import numpy as np

# found_policies and augmentation_transforms come from the AutoAugment
# reference implementation; the exact import paths depend on the checkout.
import augmentation_transforms
import found_policies


def AutoAugment(img):
    """
    Apply a randomly chosen augmentation policy learned by RL.  From:
        https://arxiv.org/abs/1805.09501
    """
    # Pick one of the learned sub-policies uniformly at random.
    good_policies = found_policies.good_policies()
    policy = good_policies[np.random.choice(len(good_policies))]
    # Apply the chosen sequence of (operation, probability, magnitude) steps.
    final_img = augmentation_transforms.apply_policy(policy, img)
    return final_img
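
For orientation, here is a minimal usage sketch (not part of the original listing). It assumes a 32x32x3 CIFAR-style image and that, as in Example #2 below, `augmentation_transforms.apply_policy` expects a float image already normalized with `augmentation_transforms.MEANS` and `augmentation_transforms.STDS`.

import numpy as np

# Hypothetical stand-in image; a real pipeline would load a CIFAR image here.
raw = np.random.randint(0, 256, size=(32, 32, 3)).astype(np.float32)
# Normalize the same way the loader in Example #2 does (assumption).
img = (raw / 255.0 - augmentation_transforms.MEANS) / augmentation_transforms.STDS
augmented = AutoAugment(img)
print(augmented.shape)  # expected to match the input spatial size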
Example #2
    def __init__(self, hparams):
        self.hparams = hparams
        self.epochs = 0
        self.curr_train_index = 0

        all_labels = []

        self.good_policies = found_policies.good_policies()

        # Determine how many data batches to load
        num_data_batches_to_load = 5
        total_batches_to_load = num_data_batches_to_load
        train_batches_to_load = total_batches_to_load
        assert hparams.train_size + hparams.validation_size <= 50000
        if hparams.eval_test:
            total_batches_to_load += 1
        # Determine how many images we have loaded
        total_dataset_size = 10000 * num_data_batches_to_load
        train_dataset_size = total_dataset_size
        if hparams.eval_test:
            total_dataset_size += 10000

        if hparams.dataset == 'cifar10':
            all_data = np.empty((total_batches_to_load, 10000, 3072),
                                dtype=np.uint8)
        elif hparams.dataset == 'cifar100':
            assert num_data_batches_to_load == 5
            all_data = np.empty((1, 50000, 3072), dtype=np.uint8)
            if hparams.eval_test:
                test_data = np.empty((1, 10000, 3072), dtype=np.uint8)
        if hparams.dataset == 'cifar10':
            tf.logging.info('Cifar10')
            datafiles = [
                'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4',
                'data_batch_5'
            ]

            datafiles = datafiles[:train_batches_to_load]
            if hparams.eval_test:
                datafiles.append('test_batch')
            num_classes = 10
        elif hparams.dataset == 'cifar100':
            datafiles = ['train']
            if hparams.eval_test:
                datafiles.append('test')
            num_classes = 100
        else:
            raise NotImplementedError(
                'Unimplemented dataset: {}'.format(hparams.dataset))
        if hparams.dataset != 'test':
            for file_num, f in enumerate(datafiles):
                d = unpickle(os.path.join(hparams.data_path, f))
                if f == 'test':
                    test_data[0] = copy.deepcopy(d['data'])
                    all_data = np.concatenate([all_data, test_data], axis=1)
                else:
                    all_data[file_num] = copy.deepcopy(d['data'])
                if hparams.dataset == 'cifar10':
                    labels = np.array(d['labels'])
                else:
                    labels = np.array(d['fine_labels'])
                all_labels.extend(labels)

        # Flatten the batches, reshape to NHWC and scale pixels to [0, 1].
        all_data = all_data.reshape(total_dataset_size, 3072)
        all_data = all_data.reshape(-1, 3, 32, 32)
        all_data = all_data.transpose(0, 2, 3, 1).copy()
        all_data = all_data / 255.0
        mean = augmentation_transforms.MEANS
        std = augmentation_transforms.STDS
        tf.logging.info('mean:{}    std: {}'.format(mean, std))

        # Standardize per channel, then one-hot encode the labels.
        all_data = (all_data - mean) / std
        all_labels = np.eye(num_classes)[np.array(all_labels, dtype=np.int32)]
        assert len(all_data) == len(all_labels)
        tf.logging.info('In CIFAR data loader, number of images: {}'.format(
            len(all_data)))

        # Break off test data
        if hparams.eval_test:
            self.test_images = all_data[train_dataset_size:]
            self.test_labels = all_labels[train_dataset_size:]

        # Shuffle the rest of the data
        all_data = all_data[:train_dataset_size]
        all_labels = all_labels[:train_dataset_size]
        np.random.seed(0)
        perm = np.arange(len(all_data))
        np.random.shuffle(perm)
        all_data = all_data[perm]
        all_labels = all_labels[perm]

        # Break into train and val
        train_size, val_size = hparams.train_size, hparams.validation_size
        assert 50000 >= train_size + val_size
        self.train_images = all_data[:train_size]
        self.train_labels = all_labels[:train_size]
        self.val_images = all_data[train_size:train_size + val_size]
        self.val_labels = all_labels[train_size:train_size + val_size]
        self.num_train = self.train_images.shape[0]
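
As a rough construction sketch (not from the original source), the loader above only reads a handful of hparams fields: `dataset`, `data_path`, `train_size`, `validation_size`, and `eval_test`. The snippet below shows how it might be instantiated; the `DataSet` class name and the use of `tf.contrib.training.HParams` are assumptions based on the TF 1.x-style `tf.logging` calls in the code.

import tensorflow as tf

# Hypothetical hparams; field names mirror the attributes read in __init__ above.
hparams = tf.contrib.training.HParams(
    dataset='cifar10',              # or 'cifar100'
    data_path='/tmp/cifar10_data',  # directory containing the pickled batch files
    train_size=45000,
    validation_size=5000,
    eval_test=True)

data_loader = DataSet(hparams)  # class name assumed; the excerpt shows only __init__
print(data_loader.num_train, data_loader.val_images.shape)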