def eval_fn(self):
    """Create the evaluation ``tf.data`` pipeline.

    Parses the TFRecord files listed in ``self._eval``, applies the
    deterministic evaluation transforms, and returns a batched,
    prefetched dataset. The last partial batch is kept
    (``drop_remainder=False``) so every example is evaluated.
    """
    # Fail fast BEFORE building the pipeline (consistent with train_fn);
    # the original asserted only after constructing the dataset from the
    # possibly-empty file list.
    assert len(
        self._eval
    ) > 0, "Evaluation data not found. Did you specify --fold flag?"

    dataset = tf.data.TFRecordDataset(filenames=self._eval)
    dataset = dataset.cache()
    dataset = dataset.map(self.parse,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Deterministic preprocessing only — no random augmentation at eval time.
    transforms = [
        CenterCrop((224, 224, 155)),
        Cast(dtype=tf.float32),
        NormalizeImages(),
        OneHotLabels(n_classes=4),
        PadXYZ()
    ]

    dataset = dataset.map(
        map_func=lambda x, y, mean, stdev: apply_transforms(
            x, y, mean, stdev, transforms=transforms),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size=self._batch_size,
                            drop_remainder=False)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset
def train_fn(self):
    """Build the training ``tf.data`` pipeline (sharded per Horovod rank)."""
    # NOTE(review): SOURCE defines a second `def train_fn` further below; if
    # both live in the same class body, that later definition overrides this
    # one, making this version dead code — confirm and remove one of them.
    # NOTE(review): the crop size is hard-coded to (128, 128, 128) here,
    # whereas the other train_fn uses self._input_shape — verify which is
    # intended.
    assert len(self._train) > 0, "Training data not found."
    ds = tf.data.TFRecordDataset(filenames=self._train)
    # Each Horovod worker reads a disjoint shard of the record files.
    ds = ds.shard(hvd.size(), hvd.rank())
    ds = ds.cache()
    ds = ds.shuffle(buffer_size=self._batch_size * 8, seed=self._seed)
    ds = ds.repeat()
    ds = ds.map(self.parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    transforms = [
        RandomCrop3D((128, 128, 128)),
        # Entries are None when augmentation is disabled; presumably
        # apply_transforms skips None entries — verify in its definition.
        RandomHorizontalFlip() if self.params.augment else None,
        Cast(dtype=tf.float32),
        NormalizeImages(),
        RandomBrightnessCorrection() if self.params.augment else None,
        OneHotLabels(n_classes=4),
    ]
    ds = ds.map(map_func=lambda x, y, mean, stdev: apply_transforms(x, y, mean, stdev, transforms=transforms),
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # drop_remainder=True keeps batch shapes static during training.
    ds = ds.batch(batch_size=self._batch_size, drop_remainder=True)
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return ds
def synth_train_fn(self):
    """Synthetic data function for testing.

    Generates one random (inputs, masks) pair, repeats it indefinitely,
    and runs it through the same augmentation transforms as ``train_fn``,
    so the model can be exercised without real data.
    """
    # Random volumes in [0, 255) and masks in [0, 4), matching the value
    # ranges of the real data before casting/normalization.
    inputs = tf.random_uniform(self._xshape, dtype=tf.int32,
                               minval=0, maxval=255, seed=self._seed,
                               name='synth_inputs')
    masks = tf.random_uniform(self._yshape, dtype=tf.int32,
                              minval=0, maxval=4, seed=self._seed,
                              name='synth_masks')

    ds = tf.data.Dataset.from_tensors((inputs, masks))
    ds = ds.repeat()

    transforms = [
        Cast(dtype=tf.uint8),
        RandomCrop3D((128, 128, 128)),
        RandomHorizontalFlip() if self.params.augment else None,
        Cast(dtype=tf.float32),
        NormalizeImages(),
        RandomBrightnessCorrection() if self.params.augment else None,
        OneHotLabels(n_classes=4),
    ]

    # FIX: the original called apply_transforms(x, y, transforms), passing
    # the transform list POSITIONALLY as the third argument. Every other
    # call site in this file uses the signature
    # (x, y, mean, stdev, transforms=...), so the list would silently bind
    # to `mean`. Passing it by keyword is correct for either signature.
    # NOTE(review): unlike the real-data pipelines, no mean/stdev are
    # supplied here — confirm apply_transforms defaults them sensibly.
    ds = ds.map(
        map_func=lambda x, y: apply_transforms(x, y, transforms=transforms),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(self._batch_size)
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return ds
def train_fn(self):
    """ Create dataset for training """
    # Debug execution modes bypass real data entirely.
    if 'debug' in self.params.exec_mode:
        return self.synth_train_fn()

    assert len(self._train) > 0, "Training data not found."

    augment = self.params.augment

    pipeline = tf.data.TFRecordDataset(filenames=self._train)
    # Shard the record files across Horovod workers, then cache, shuffle,
    # repeat, and parse.
    pipeline = (pipeline
                .shard(hvd.size(), hvd.rank())
                .cache()
                .shuffle(buffer_size=self._batch_size * 8, seed=self._seed)
                .repeat()
                .map(self.parse,
                     num_parallel_calls=tf.data.experimental.AUTOTUNE))

    # Augmentation entries are None when augmentation is disabled.
    transforms = [RandomCrop3D(self._input_shape)]
    transforms.append(RandomHorizontalFlip() if augment else None)
    transforms.append(Cast(dtype=tf.float32))
    transforms.append(NormalizeImages())
    transforms.append(RandomBrightnessCorrection() if augment else None)
    transforms.append(OneHotLabels(n_classes=4))

    pipeline = pipeline.map(
        map_func=lambda x, y, mean, stdev: apply_transforms(
            x, y, mean, stdev, transforms=transforms),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Static batch shapes for training; keep GPUs fed with prefetch.
    pipeline = pipeline.batch(batch_size=self._batch_size,
                              drop_remainder=True)
    pipeline = pipeline.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    # Presumably a workaround for map-and-batch fusion misbehaving at
    # batch size 1 — kept exactly as in the original.
    if self._batch_size == 1:
        opts = pipeline.options()
        opts.experimental_optimization.map_and_batch_fusion = False
        pipeline = pipeline.with_options(opts)

    return pipeline