Example no. 1
    def train(self, batch_size, fold_idx, normalize=True, augment=False):
        """ Dataset for model training
        Args:
            batch_size: int, number of images in a batch
            fold_idx: int, index of fold, from 0 to n_folds - 1
            normalize: bool, wether to normalize training data
                                    with Welford's online algorithm.
            augment: bool, wether to use augmentation or not
        Returns:
            data: TensroFlow dataset
            steps: int, number of steps in train epoch
        """
        if not 0 <= fold_idx < self.n_folds:
            raise ValueError('Fold index {} is out of expected range: '
                             '[0, {}]'.format(fold_idx, self.n_folds - 1))

        if normalize and augment:
            raise ValueError('Combining augmentation with Welford '
                             'normalization is not supported')

        print(' ... Generating Training Dataset ... ')
        if self.n_folds == 1:
            # No cross-validation: train on the whole dataset
            train_idx = range(len(self.filenames))
        else:
            # Select the training split of the requested fold
            train_idx, _ = list(self.kf.split(self.filenames))[fold_idx]
        filenames = np.array(self.filenames)[train_idx]
        labels = np.array(self.labels)[train_idx]
        steps = math.ceil(len(filenames) / batch_size)
        if normalize:
            # Per-channel statistics computed over the training images
            mean, std = Normalizer.calc_mean_and_std(filenames, self.img_size)
            mean = np.array([mean['red'], mean['green'], mean['blue']])
            std = np.array([std['red'], std['green'], std['blue']])
        else:
            # Fallback: per-channel statistics of the ImageNet dataset
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
        self.normalizer = Normalizer(mean, std)
        data = tf.data.Dataset.from_tensor_slices(
            (tf.constant(filenames), tf.constant(labels)))
        # Shuffle filename/label pairs before decoding, so the shuffle
        # buffer holds lightweight strings rather than decoded images
        data = data.shuffle(buffer_size=len(filenames))
        data = data.map(self.parse_fn, num_parallel_calls=4)
        if augment:
            augs = [self.flip, self.color, self.rotate, self.zoom]
            for f in augs:
                data = data.map(f, num_parallel_calls=4)
            data = data.map(self.drop, num_parallel_calls=4)
        data = data.batch(batch_size)
        data = data.prefetch(1)
        return data, steps
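
The normalization statistics come from Normalizer.calc_mean_and_std, whose implementation is not shown here. As a minimal sketch of how per-channel mean and std could be computed with Welford's online algorithm (the Pillow-based image loading and the red/green/blue dict layout are assumptions inferred from the code above):

# Minimal sketch, not the project's actual Normalizer: per-channel
# mean/std via Welford's online update over every pixel of every image.
import numpy as np
from PIL import Image

def calc_mean_and_std(filenames, img_size):
    count = 0
    mean = np.zeros(3)  # running per-channel mean
    m2 = np.zeros(3)    # running sum of squared deviations from the mean
    for name in filenames:
        img = Image.open(name).convert('RGB').resize((img_size, img_size))
        pixels = np.asarray(img, dtype=np.float64).reshape(-1, 3) / 255.0
        for x in pixels:  # one Welford update per pixel
            count += 1
            delta = x - mean
            mean += delta / count
            m2 += delta * (x - mean)
    std = np.sqrt(m2 / count)  # population standard deviation
    return ({'red': mean[0], 'green': mean[1], 'blue': mean[2]},
            {'red': std[0], 'green': std[1], 'blue': std[2]})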
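
The augmentation ops mapped over the dataset (self.flip, self.color, self.rotate, self.zoom, self.drop) are likewise defined elsewhere in the class. Assuming parse_fn yields (image, label) pairs, one such op might look like this hypothetical flip:

import tensorflow as tf

def flip(self, image, label):
    # Hypothetical augmentation op; the (image, label) signature is an
    # assumption based on parse_fn producing image/label pairs
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    return image, label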
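
A hypothetical call site (the loader class name and its constructor arguments are assumptions; model is any compiled Keras model):

loader = ImageDataset(n_folds=5, img_size=224)  # hypothetical class
train_data, train_steps = loader.train(batch_size=32, fold_idx=0,
                                       normalize=False, augment=True)
model.fit(train_data, steps_per_epoch=train_steps, epochs=10)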