def train_fn(self):
        """Build the training input pipeline from the TFRecord files in ``self._train``.

        The records are sharded with ``hvd.size()`` / ``hvd.rank()``, cached,
        shuffled with ``self._seed``, repeated indefinitely, decoded with
        ``self.parse`` and then passed through the transform list below before
        being batched and prefetched.

        Returns:
            The resulting ``tf.data`` dataset. Every batch holds exactly
            ``self._batch_size`` elements because ``drop_remainder=True``.

        Raises:
            AssertionError: if ``self._train`` is empty.
        """
        assert len(self._train) > 0, "Training data not found."

        autotune = tf.data.experimental.AUTOTUNE

        # Stage order matters: sharding happens before caching so only this
        # worker's slice is cached, and the cache sits before shuffle/repeat
        # so the cached records are what gets reshuffled.
        records = (
            tf.data.TFRecordDataset(filenames=self._train)
            .shard(hvd.size(), hvd.rank())
            .cache()
            .shuffle(buffer_size=self._batch_size * 8, seed=self._seed)
            .repeat()
            .map(self.parse, num_parallel_calls=autotune)
        )

        # The two random transforms are only enabled when augmentation is on;
        # a disabled slot is left as None in the list handed to
        # apply_transforms.
        augment = self.params.augment
        pipeline = [
            RandomCrop3D((128, 128, 128)),
            RandomHorizontalFlip() if augment else None,
            Cast(dtype=tf.float32),
            NormalizeImages(),
            RandomBrightnessCorrection() if augment else None,
            OneHotLabels(n_classes=4),
        ]

        def _transform(image, label, mean, stdev):
            # Each parsed element is a 4-tuple; forward it with the pipeline.
            return apply_transforms(image, label, mean, stdev,
                                    transforms=pipeline)

        batches = (
            records
            .map(map_func=_transform, num_parallel_calls=autotune)
            .batch(batch_size=self._batch_size, drop_remainder=True)
            .prefetch(buffer_size=autotune)
        )
        return batches
    def synth_train_fn(self):
        """Return an endlessly repeating dataset of random tensors for testing.

        Two integer tensors are drawn with ``tf.random_uniform`` (shapes
        ``self._xshape`` and ``self._yshape``, both seeded with
        ``self._seed``), wrapped into a single-element dataset, repeated, run
        through the transform list and finally batched and prefetched.

        Returns:
            The resulting ``tf.data`` dataset, batched by ``self._batch_size``.
        """
        seed = self._seed
        autotune = tf.data.experimental.AUTOTUNE

        # Integer sampling excludes maxval: inputs fall in [0, 255) and
        # masks in [0, 4).
        fake_inputs = tf.random_uniform(self._xshape, dtype=tf.int32,
                                        minval=0, maxval=255, seed=seed,
                                        name='synth_inputs')
        fake_masks = tf.random_uniform(self._yshape, dtype=tf.int32,
                                       minval=0, maxval=4, seed=seed,
                                       name='synth_masks')

        samples = tf.data.Dataset.from_tensors((fake_inputs, fake_masks)).repeat()

        # The two random transforms are only enabled when augmentation is on;
        # a disabled slot is left as None in the list.
        augment = self.params.augment
        pipeline = [
            Cast(dtype=tf.uint8),
            RandomCrop3D((128, 128, 128)),
            RandomHorizontalFlip() if augment else None,
            Cast(dtype=tf.float32),
            NormalizeImages(),
            RandomBrightnessCorrection() if augment else None,
            OneHotLabels(n_classes=4),
        ]

        def _transform(image, mask):
            # NOTE(review): the pipeline is passed as the third positional
            # argument here, whereas train_fn calls
            # apply_transforms(x, y, mean, stdev, transforms=...).
            # Confirm this matches the apply_transforms signature.
            return apply_transforms(image, mask, pipeline)

        batches = (
            samples
            .map(map_func=_transform, num_parallel_calls=autotune)
            .batch(self._batch_size)
            .prefetch(buffer_size=autotune)
        )
        return batches
Esempio n. 3
0
    def train_fn(self):
        """Create the dataset used for training.

        When ``self.params.exec_mode`` contains ``'debug'`` the result of
        ``self.synth_train_fn()`` is returned instead. Otherwise the TFRecord
        files in ``self._train`` are sharded with ``hvd.size()`` /
        ``hvd.rank()``, cached, shuffled with ``self._seed``, repeated
        indefinitely, decoded with ``self.parse``, transformed, batched and
        prefetched.

        Returns:
            The resulting ``tf.data`` dataset. Every batch holds exactly
            ``self._batch_size`` elements because ``drop_remainder=True``.

        Raises:
            AssertionError: if ``self._train`` is empty (non-debug mode only).
        """
        if 'debug' in self.params.exec_mode:
            return self.synth_train_fn()

        assert len(self._train) > 0, "Training data not found."

        autotune = tf.data.experimental.AUTOTUNE

        # Stage order matters: sharding happens before caching so only this
        # worker's slice is cached, and the cache sits before shuffle/repeat
        # so the cached records are what gets reshuffled.
        records = (
            tf.data.TFRecordDataset(filenames=self._train)
            .shard(hvd.size(), hvd.rank())
            .cache()
            .shuffle(buffer_size=self._batch_size * 8, seed=self._seed)
            .repeat()
            .map(self.parse, num_parallel_calls=autotune)
        )

        # The two random transforms are only enabled when augmentation is on;
        # a disabled slot is left as None in the list handed to
        # apply_transforms.
        augment = self.params.augment
        pipeline = [
            RandomCrop3D(self._input_shape),
            RandomHorizontalFlip() if augment else None,
            Cast(dtype=tf.float32),
            NormalizeImages(),
            RandomBrightnessCorrection() if augment else None,
            OneHotLabels(n_classes=4),
        ]

        def _transform(image, label, mean, stdev):
            # Each parsed element is a 4-tuple; forward it with the pipeline.
            return apply_transforms(image, label, mean, stdev,
                                    transforms=pipeline)

        batches = (
            records
            .map(map_func=_transform, num_parallel_calls=autotune)
            .batch(batch_size=self._batch_size, drop_remainder=True)
            .prefetch(buffer_size=autotune)
        )

        if self._batch_size != 1:
            return batches

        # With a batch size of one, the map-and-batch fusion optimisation is
        # switched off. NOTE(review): the motivation is not visible in this
        # file -- confirm before removing.
        opts = batches.options()
        opts.experimental_optimization.map_and_batch_fusion = False
        return batches.with_options(opts)