Example #1
0
    def _get_session_config(mode, use_xla):
        """Build the ``tf.ConfigProto`` used for the estimator session.

        Args:
            mode: One of ``'train'``, ``'validation'`` or ``'benchmark'``.
            use_xla: When true, log a notice and enable XLA JIT at level ON_1.

        Returns:
            The configured ``tf.ConfigProto``.

        Raises:
            ValueError: If ``mode`` is not one of the allowed values.
        """
        if mode not in ("train", "validation", "benchmark"):
            raise ValueError(
                "Unknown mode received: %s (allowed: 'train', 'validation', 'benchmark')"
                % mode)

        session_cfg = tf.ConfigProto()

        # Allow fallback placement and keep placement logging off.
        session_cfg.allow_soft_placement = True
        session_cfg.log_device_placement = False

        # Allocate GPU memory on demand.
        session_cfg.gpu_options.allow_growth = True

        if hvd_utils.is_using_hvd():
            # Expose only the GPU that matches this process's local rank.
            session_cfg.gpu_options.visible_device_list = str(hvd.local_rank())

        if use_xla:
            LOGGER.log("XLA is activated - Experimental Feature")
            session_cfg.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

        session_cfg.gpu_options.force_gpu_compatible = True  # pinned host memory

        if mode != 'train':
            return session_cfg

        # A single intra-op thread avoids spinning up a pool of Eigen threads.
        session_cfg.intra_op_parallelism_threads = 1
        if hvd_utils.is_using_hvd():
            cpus_per_worker = multiprocessing.cpu_count() // hvd.size()
            session_cfg.inter_op_parallelism_threads = max(2, cpus_per_worker - 2)
        else:
            session_cfg.inter_op_parallelism_threads = 4

        return session_cfg
Example #2
0
def get_tfrecords_input_fn(filenames, batch_size, height, width, training,
                           distort_color, num_threads, deterministic):
    """Create a ``tf.data`` pipeline that reads and preprocesses TFRecord images.

    Args:
        filenames: TFRecord file paths.
        batch_size: Number of examples per batch; incomplete batches are dropped.
        height: Target image height passed to the preprocessing function.
        width: Target image width passed to the preprocessing function.
        training: Enables Horovod sharding (when Horovod is active) and shuffling.
        distort_color: Accepted for interface compatibility; not used here.
        num_threads: Parallel calls for the preprocessing ``map``.
        deterministic: When true, shuffle with a fixed seed (13, or
            ``13 * (1 + rank)`` under Horovod).

    Returns:
        An infinitely repeating, batched, prefetched ``tf.data.Dataset``.
    """
    shuffle_buffer = 4096

    if not deterministic:
        shuffle_seed = None
    elif hvd_utils.is_using_hvd():
        shuffle_seed = 13 * (1 + hvd.rank())
    else:
        shuffle_seed = 13

    def _preprocess(serialized):
        # Decode and resize a single serialized record.
        return image_processing.preprocess_image_record(
            serialized, height, width, _NUM_CHANNELS, training)

    dataset = tf.data.Dataset.from_tensor_slices(filenames)

    # Each Horovod worker reads a disjoint subset of the files when training.
    if hvd_utils.is_using_hvd() and training:
        dataset = dataset.shard(hvd.size(), hvd.rank())

    dataset = dataset.interleave(
        tf.data.TFRecordDataset, cycle_length=10, block_length=8)

    if training:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer, seed=shuffle_seed)

    dataset = dataset.repeat()
    dataset = dataset.map(_preprocess, num_parallel_calls=num_threads)
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
    return dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
Example #3
0
    def _get_run_config(mode,
                        model_dir,
                        use_xla,
                        use_dali,
                        use_cpu,
                        gpu_memory_fraction,
                        gpu_id=0,
                        seed=None):
        """Build the ``tf.estimator.RunConfig`` for a given run mode.

        Args:
            mode: One of 'train', 'validation', 'benchmark' or 'inference'.
            model_dir: Passed to ``RunConfig(model_dir=...)``.
            use_xla: Forwarded to ``Runner._get_session_config``.
            use_dali: Forwarded to ``Runner._get_session_config``.
            use_cpu: Forwarded to ``Runner._get_session_config``.
            gpu_memory_fraction: Forwarded to ``Runner._get_session_config``.
            gpu_id: Forwarded to ``Runner._get_session_config``.
            seed: Optional base seed. The graph seed becomes ``2 * seed``, or
                ``2 * (seed + hvd.rank())`` under Horovod.

        Returns:
            A ``tf.estimator.RunConfig``. In 'train' mode it is replaced with
            ``save_checkpoints_steps=1000`` (``None`` on non-zero Horovod ranks)
            and ``keep_checkpoint_every_n_hours=3``.

        Raises:
            ValueError: If ``mode`` is not recognised.
        """
        if mode not in ["train", 'validation', 'benchmark', 'inference']:
            raise ValueError(
                "Unknown mode received: %s (allowed: 'train', 'validation', 'benchmark', 'inference')"
                % mode)

        if seed is None:
            tf_random_seed = None
        elif hvd_utils.is_using_hvd():
            tf_random_seed = 2 * (seed + hvd.rank())
        else:
            tf_random_seed = 2 * seed

        # Summaries only for train/validation; 1e9 steps effectively disables them.
        summary_every = 100 if mode in ['train', 'validation'] else 1e9

        session_config = Runner._get_session_config(
            mode=mode,
            use_xla=use_xla,
            use_dali=use_dali,
            use_cpu=use_cpu,
            gpu_memory_fraction=gpu_memory_fraction,
            gpu_id=gpu_id)

        run_config = tf.estimator.RunConfig(
            model_dir=model_dir,
            tf_random_seed=tf_random_seed,
            save_summary_steps=summary_every,
            save_checkpoints_steps=None,
            save_checkpoints_secs=None,
            session_config=session_config,
            keep_checkpoint_max=5,
            keep_checkpoint_every_n_hours=1e6,  # effectively disabled
            log_step_count_steps=1e9,
            train_distribute=None,
            device_fn=None,
            protocol=None,
            eval_distribute=None,
            experimental_distribute=None)

        if mode != 'train':
            return run_config

        if hvd_utils.is_using_hvd():
            # Only rank 0 writes checkpoints.
            checkpoint_every = 1000 if hvd.rank() == 0 else None
        else:
            checkpoint_every = 1000

        return run_config.replace(save_checkpoints_steps=checkpoint_every,
                                  keep_checkpoint_every_n_hours=3)
Example #4
0
def get_tfrecords_input_fn(filenames, batch_size, height, width, training,
                           distort_color, num_threads, deterministic):
    """Build the TFRecord ``tf.data`` input pipeline (parallel-interleave variant).

    Args:
        filenames: TFRecord file paths.
        batch_size: Number of examples per batch; incomplete batches are dropped.
        height: Target image height passed to the preprocessing function.
        width: Target image width passed to the preprocessing function.
        training: Enables Horovod sharding (when Horovod is active) and
            shuffle-and-repeat; otherwise the dataset is simply repeated.
        distort_color: Accepted for interface compatibility; not used here.
        num_threads: Parallel calls for the fused map-and-batch step.
        deterministic: When true, use a fixed shuffle seed (13, or
            ``13 * (1 + rank)`` under Horovod) and ordered (non-sloppy)
            interleaving.

    Returns:
        An infinitely repeating, batched, prefetched ``tf.data.Dataset``.
    """
    shuffle_buffer_size = 4096

    if deterministic:
        if hvd_utils.is_using_hvd():
            seed = 13 * (1 + hvd.rank())
        else:
            seed = 13
    else:
        seed = None

    ds = tf.data.Dataset.from_tensor_slices(filenames)

    # Each Horovod worker reads a disjoint subset of the files when training.
    if hvd_utils.is_using_hvd() and training:
        ds = ds.shard(hvd.size(), hvd.rank())

    # `sloppy` allows out-of-order interleaving only when determinism is not required.
    ds = ds.apply(
        tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset,
                                                 cycle_length=10,
                                                 block_length=8,
                                                 sloppy=not deterministic,
                                                 prefetch_input_elements=16))

    # Pair every record with a running index; `preproc_func` currently ignores it.
    counter = tf.data.Dataset.range(sys.maxsize)
    ds = tf.data.Dataset.zip((ds, counter))

    def preproc_func(record, counter_):
        return image_processing.preprocess_image_record(
            record, height, width, _NUM_CHANNELS, training)

    # NOTE(review): `cache()` without a filename presumably keeps the raw
    # records in memory; confirm the dataset fits.
    ds = ds.cache()

    if training:

        ds = ds.apply(
            tf.data.experimental.shuffle_and_repeat(
                buffer_size=shuffle_buffer_size, seed=seed))

    else:
        ds = ds.repeat()

    ds = ds.apply(
        tf.data.experimental.map_and_batch(
            map_func=preproc_func,
            num_parallel_calls=num_threads,
            batch_size=batch_size,
            drop_remainder=True,
        ))

    # Use the tf.data.experimental alias (as the rest of this pipeline does)
    # instead of the deprecated tf.contrib one.
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return ds
Example #5
0
    def _get_session_config(mode,
                            use_xla,
                            use_dali,
                            gpu_memory_fraction,
                            gpu_id=0):
        """Build the ``tf.ConfigProto`` used for the estimator session.

        Args:
            mode: One of 'train', 'validation', 'benchmark' or 'inference'.
            use_xla: Enable XLA JIT compilation at level ON_1.
            use_dali: When true, cap TF's GPU memory at ``gpu_memory_fraction``
                and disable growth; otherwise enable on-demand growth.
            gpu_memory_fraction: Per-process GPU memory fraction (DALI only).
            gpu_id: GPU exposed to the session when Horovod is not in use.

        Returns:
            The configured ``tf.ConfigProto``.

        Raises:
            ValueError: If ``mode`` is not recognised.
        """
        if mode not in ("train", "validation", "benchmark", "inference"):
            raise ValueError(
                "Unknown mode received: %s (allowed: 'train', 'validation', 'benchmark', 'inference')"
                % mode)

        # Limit available GPU memory (tune the size)
        if not use_dali:
            session_cfg = tf.ConfigProto()
            session_cfg.gpu_options.allow_growth = True
        else:
            session_cfg = tf.ConfigProto(
                gpu_options=tf.GPUOptions(
                    per_process_gpu_memory_fraction=gpu_memory_fraction))
            session_cfg.gpu_options.allow_growth = False

        session_cfg.allow_soft_placement = True
        session_cfg.log_device_placement = False

        # Under Horovod the local rank selects the GPU instead of `gpu_id`.
        visible_gpu = str(gpu_id)
        if hvd_utils.is_using_hvd():
            visible_gpu = str(hvd.local_rank())
        session_cfg.gpu_options.visible_device_list = visible_gpu

        if use_xla:
            session_cfg.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

        session_cfg.gpu_options.force_gpu_compatible = True  # pinned host memory

        # Work around a bug by turning off the graph "remapping" rewrite
        # (which performs the BN+ReLU fusion).
        from tensorflow.core.protobuf import rewriter_config_pb2
        session_cfg.graph_options.rewrite_options.remapping = rewriter_config_pb2.RewriterConfig.OFF

        if mode != 'train':
            return session_cfg

        # A single intra-op thread avoids spinning up a pool of Eigen threads.
        session_cfg.intra_op_parallelism_threads = 1
        if hvd_utils.is_using_hvd():
            cpus_per_worker = multiprocessing.cpu_count() // hvd.size()
            session_cfg.inter_op_parallelism_threads = max(2, cpus_per_worker - 2)
        else:
            session_cfg.inter_op_parallelism_threads = 4

        return session_cfg
    def _get_session_config(mode,
                            use_xla,
                            use_dali,
                            gpu_memory_fraction,
                            gpu_id=0):
        """Build the ``tf.ConfigProto`` used for the estimator session.

        Args:
            mode: One of 'train', 'validation', 'benchmark' or 'inference'.
            use_xla: Log a notice and enable XLA JIT compilation at level ON_1.
            use_dali: When true, log the limit, cap TF's GPU memory at
                ``gpu_memory_fraction`` and disable growth; otherwise enable
                on-demand growth.
            gpu_memory_fraction: Per-process GPU memory fraction (DALI only).
            gpu_id: GPU exposed to the session when Horovod is not in use.

        Returns:
            The configured ``tf.ConfigProto``.

        Raises:
            ValueError: If ``mode`` is not recognised.
        """
        if mode not in ("train", "validation", "benchmark", "inference"):
            raise ValueError(
                "Unknown mode received: %s (allowed: 'train', 'validation', 'benchmark', 'inference')"
                % mode)

        # Limit available GPU memory (tune the size)
        if not use_dali:
            session_cfg = tf.ConfigProto()
            session_cfg.gpu_options.allow_growth = True
        else:
            LOGGER.log(
                "DALI is activated, GPU memory fraction used for training is limited to",
                gpu_memory_fraction)
            session_cfg = tf.ConfigProto(
                gpu_options=tf.GPUOptions(
                    per_process_gpu_memory_fraction=gpu_memory_fraction))
            session_cfg.gpu_options.allow_growth = False

        session_cfg.allow_soft_placement = True
        session_cfg.log_device_placement = False

        # Under Horovod the local rank selects the GPU instead of `gpu_id`.
        visible_gpu = str(gpu_id)
        if hvd_utils.is_using_hvd():
            visible_gpu = str(hvd.local_rank())
        session_cfg.gpu_options.visible_device_list = visible_gpu

        if use_xla:
            LOGGER.log("XLA is activated - Experimental Feature")
            session_cfg.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

        session_cfg.gpu_options.force_gpu_compatible = True  # pinned host memory

        if mode != 'train':
            return session_cfg

        # A single intra-op thread avoids spinning up a pool of Eigen threads.
        session_cfg.intra_op_parallelism_threads = 1
        if hvd_utils.is_using_hvd():
            cpus_per_worker = multiprocessing.cpu_count() // hvd.size()
            session_cfg.inter_op_parallelism_threads = max(2, cpus_per_worker - 2)
        else:
            session_cfg.inter_op_parallelism_threads = 4

        return session_cfg
        def _augment_data(input_image, mask_image, label):
            """Apply the same random flip and rotation to an image and its mask.

            Reads ``augment_data`` and ``seed`` from the enclosing scope. When
            ``augment_data`` is falsy, the inputs pass through unchanged.

            Returns:
                ``((input_image, mask_image), label)``.
            """
            if augment_data:

                if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
                    print("Using data augmentation ...")

                # A single random draw drives both tensors so image and mask stay aligned.
                horizontal_flip = tf.random_uniform(shape=(), seed=seed) > 0.5
                input_image = tf.cond(
                    horizontal_flip,
                    lambda: tf.image.flip_left_right(input_image),
                    lambda: input_image)
                mask_image = tf.cond(
                    horizontal_flip,
                    lambda: tf.image.flip_left_right(mask_image),
                    lambda: mask_image)

                # `maxval` is exclusive for integer dtypes: draw k from {0, 1, 2, 3}
                # so that all four quarter-turn rotations (including 270 degrees)
                # can occur. `maxval=3` never produced k == 3.
                n_rots = tf.random_uniform(shape=(),
                                           dtype=tf.int32,
                                           minval=0,
                                           maxval=4,
                                           seed=seed)
                input_image = tf.image.rot90(input_image, k=n_rots)
                mask_image = tf.image.rot90(mask_image, k=n_rots)

            return (input_image, mask_image), label
Example #8
0
        def training_data_fn():
            """Return the training input pipeline, real or synthetic.

            Uses ``self.dataset.dataset_fn`` unless this is a benchmark run
            without a ``data_dir``. In that case it announces synthetic data
            (on rank 0 / without Horovod) and uses
            ``self.dataset.synth_dataset_fn``. Both builders receive the same
            training arguments.
            """
            use_real_data = not is_benchmark or self.run_hparams.data_dir is not None

            if use_real_data:
                build_dataset = self.dataset.dataset_fn
            else:
                if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
                    print("Using Synthetic Data ...")
                build_dataset = self.dataset.synth_dataset_fn

            hparams = self.run_hparams
            return build_dataset(
                batch_size=batch_size,
                training=True,
                only_defective_images=True,
                augment_data=augment_data,
                input_shape=list(hparams.input_shape) + [hparams.n_channels],
                mask_shape=list(hparams.mask_shape) + [hparams.n_channels],
                num_threads=64,
                use_gpu_prefetch=True,
                normalize_data_method="zero_centered",
                seed=hparams.seed)
    def after_run(self, run_context, run_values):
        """Advance the cached step and request a stop when a signal is seen.

        When Horovod is active and two results came back (global step plus the
        all-reduce value), a positive second value requests a stop. Presumably
        that value is the summed per-worker signal flag; verify against
        ``before_run``. Otherwise the local ``signal_recieved`` flag decides.
        """
        results = run_values.results
        self.global_step = results[0] + 1

        if is_using_hvd() and len(results) == 2:
            stop_now = results[1] > 0
        else:
            stop_now = self.signal_recieved

        if stop_now:
            run_context.request_stop()
 def begin(self):
     """Create the Horovod signal all-reduce ops on the CPU (Horovod only)."""
     if not is_using_hvd():
         return

     with tf.device("/cpu:0"):
         # Scalar int32 fed each sync step and summed across workers.
         self.input_op = tf.placeholder(tf.int32, shape=())
         self.allreduce_op = hvd.allreduce(
             self.input_op, op=hvd.Sum, name="signal_handler_all_reduce")
    def before_run(self, run_context):
        """Request the global step and, every ``sync_freq`` steps under Horovod,
        the signal all-reduce fed with this worker's ``signal_recieved`` flag.

        Returns:
            ``tf.train.SessionRunArgs`` with the fetches and optional feed dict.
        """
        step_tensor = tf.train.get_global_step()

        if not (is_using_hvd() and (self.global_step % self.sync_freq) == 0):
            return tf.train.SessionRunArgs([step_tensor], feed_dict=None)

        return tf.train.SessionRunArgs(
            [step_tensor, self.allreduce_op],
            feed_dict={self.input_op: int(self.signal_recieved)})
Example #12
0
def _log_hparams(classname, layername, **kwargs):
    """Print a layer's hyper-parameters, one per line, sorted by name.

    Prints only when Horovod is not in use, or on Horovod rank 0.

    Args:
        classname: Layer class name shown in the header line.
        layername: Layer name shown (back-quoted) in the header line.
        **kwargs: Hyper-parameters; each becomes a ``[*] name: value`` line.
    """
    header = "%s: `%s`" % (classname, layername)
    details = "".join(
        "\n\t[*] {}: {}".format(arg, val) for arg, val in sorted(kwargs.items()))
    log_msg = header + details + "\n"

    if hvd_utils.is_using_hvd() and hvd.rank() != 0:
        return
    print(log_msg)
    def after_run(self, run_context, run_values):
        """Record the global step and stop once an exit has been decided.

        If an exit was already decided, request a stop and return. Otherwise,
        decide whether to exit next step. Under Horovod with three results,
        exit when ``results[2][0]`` equals ``hvd.size()``. Presumably that is
        the all-reduced signal count, i.e. every worker was signalled; verify
        against ``before_run``. Without those three results, exit follows the
        local ``signal_recieved`` flag.
        """
        results = run_values.results
        self.global_step = results[0]

        if self.should_exit:
            run_context.request_stop()
            return

        synced = is_using_hvd() and len(results) == 3
        self.should_exit = (results[2][0] == hvd.size()) if synced else self.signal_recieved
    def before_run(self, run_context):
        """Assemble the fetches for the next ``session.run`` call.

        Always fetches the global step (index 0). Under Horovod, fetches:
        - exactly one signal-handler tensor at index 1: the reset op once
          exiting, the set op once a signal was received, otherwise the plain
          variable;
        - every ``sync_freq`` steps while not exiting, additionally
          ``signal_handler_all_reduce:0`` at index 2.

        Returns:
            ``tf.train.SessionRunArgs`` wrapping the fetch list.
        """
        fetches = [tf.train.get_global_step()]

        if is_using_hvd():
            # Only one signal-handler fetch. A second, unconditional append here
            # pushed the all-reduce result to index 3, while `after_run` reads it
            # from results[2] when len(results) == 3.
            if self.should_exit:
                fetches.append("signal_handler_var_reset:0")
            elif self.signal_recieved:
                fetches.append("signal_handler_var_set:0")
            else:
                fetches.append("signal_handler_var:0")

            if ((self.global_step % self.sync_freq)
                    == 0) and not self.should_exit:
                fetches.append("signal_handler_all_reduce:0")

        run_args = tf.train.SessionRunArgs(fetches)
        return run_args
Example #15
0
    def __init__(
        self,
        # ========= Model HParams ========= #
        n_classes=1001,
        input_format='NHWC',    # NCHW or NHWC
        compute_format='NCHW',  # NCHW or NHWC
        dtype=tf.float32,       # tf.float32 or tf.float16
        n_channels=3,
        height=224,
        width=224,
        distort_colors=False,
        model_dir=None,
        log_dir=None,
        data_dir=None,
        data_idx_dir=None,

        # ======= Optimization HParams ======== #
        use_xla=False,
        use_tf_amp=False,
        use_dali=False,
        gpu_memory_fraction=1.0,

        # ======== Debug Flags ======== #
        debug_verbosity=0,
        seed=None
    ):
        """Validate settings, configure the TF/Horovod environment and build the model.

        Side effects:
        - calls ``hvd.init()``;
        - sets several performance-related environment variables, plus
          ``TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE`` when ``use_tf_amp``
          is requested with FP32;
        - calls ``tf.set_random_seed`` when a seed is given.

        Args:
            n_classes: Number of output classes.
            input_format: Layout of the input tensors, 'NHWC' or 'NCHW'.
            compute_format: Layout used inside the model, 'NHWC' or 'NCHW'.
            dtype: ``tf.float32`` or ``tf.float16``.
            n_channels: 1 (grayscale) or 3 (color).
            height: Input image height.
            width: Input image width.
            distort_colors: Stored in the hyper-parameters for the input pipeline.
            model_dir: Model directory (kept only on Horovod rank 0 / without Horovod).
            log_dir: Log directory (kept only on Horovod rank 0 / without Horovod).
            data_dir: TFRecord directory; must exist if given.
            data_idx_dir: DALI index directory.
            use_xla: Stored in the run hyper-parameters.
            use_tf_amp: Enable the automatic mixed-precision graph rewrite (FP32 only).
            use_dali: Use DALI; also lowers the preprocessing thread count to 4.
            gpu_memory_fraction: Stored in the run hyper-parameters.
            debug_verbosity: Accepted but not used by this constructor.
            seed: Optional base seed; the TF seed is ``2 * (seed + hvd.rank())``.

        Raises:
            ValueError: On an unsupported dtype, format or channel count, or a
                ``data_dir`` that does not exist.
            RuntimeError: If ``use_tf_amp`` is combined with ``tf.float16``.
        """

        if dtype not in [tf.float32, tf.float16]:
            raise ValueError("Unknown dtype received: %s (allowed: `tf.float32` and `tf.float16`)" % dtype)

        if compute_format not in ["NHWC", 'NCHW']:
            raise ValueError("Unknown `compute_format` received: %s (allowed: ['NHWC', 'NCHW'])" % compute_format)

        if input_format not in ["NHWC", 'NCHW']:
            raise ValueError("Unknown `input_format` received: %s (allowed: ['NHWC', 'NCHW'])" % input_format)

        if n_channels not in [1, 3]:
            raise ValueError("Unsupported number of channels: %d (allowed: 1 (grayscale) and 3 (color))" % n_channels)

        if data_dir is not None and not os.path.exists(data_dir):
            raise ValueError("The `data_dir` received does not exists: %s" % data_dir)

        hvd.init()
        # Distinct, rank-dependent seed per Horovod process.
        tf_seed = 2 * (seed + hvd.rank()) if seed is not None else None

        # ============================================
        # Optimisation Flags - Do not remove
        # ============================================

        os.environ['CUDA_CACHE_DISABLE'] = '0'

        os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

        #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

        os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
        os.environ['TF_GPU_THREAD_COUNT'] = '1' if not hvd_utils.is_using_hvd() else str(hvd.size())

        os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

        os.environ['TF_ADJUST_HUE_FUSED'] = '1'
        os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

        os.environ['TF_SYNC_ON_FINISH'] = '0'
        os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'
        os.environ['TF_DISABLE_NVTX_RANGES'] = '1'

        # ============================================
        # TF-AMP Setup - Do not remove
        # ============================================

        if dtype == tf.float16:

            if use_tf_amp:
                raise RuntimeError("TF AMP can not be activated for FP16 precision")

        elif use_tf_amp:

            if hvd.rank() == 0:
                LOGGER.log("TF AMP is activated - Experimental Feature")
            os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

        # =================================================

        # Bug fix: width and height were previously swapped here, which only went
        # unnoticed because both default to 224.
        model_hparams = tf.contrib.training.HParams(
            width=width,
            height=height,
            n_channels=n_channels,
            n_classes=n_classes,
            dtype=dtype,
            input_format=input_format,
            compute_format=compute_format,
            distort_colors=distort_colors,
            seed=tf_seed
        )

        if use_dali:
            num_preprocessing_threads=4
        else:
            num_preprocessing_threads=10

        run_config_performance = tf.contrib.training.HParams(
            num_preprocessing_threads=num_preprocessing_threads,
            use_tf_amp=use_tf_amp,
            use_xla=use_xla,
            use_dali=use_dali,
            gpu_memory_fraction=gpu_memory_fraction
        )

        # Only rank 0 (or a non-Horovod run) keeps the model/log directories.
        run_config_additional = tf.contrib.training.HParams(
            model_dir=model_dir if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None,
            log_dir=log_dir if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None,
            data_dir=data_dir,
            data_idx_dir=data_idx_dir,
            num_preprocessing_threads=num_preprocessing_threads
        )

        self.run_hparams = Runner._build_hparams(model_hparams, run_config_additional, run_config_performance)

        self._model = resnet_v1_5.ResnetModel(
            model_name="resnet50_v1.5",
            n_classes=model_hparams.n_classes,
            input_format=model_hparams.input_format,
            compute_format=model_hparams.compute_format,
            dtype=model_hparams.dtype,
            use_dali=use_dali
        )

        if self.run_hparams.seed is not None:
            if hvd.rank() == 0:
                LOGGER.log("Deterministic Run - Seed: %d" % seed)
            tf.set_random_seed(self.run_hparams.seed)
Example #16
0
    def train(
        self,
        iter_unit,
        num_iter,
        batch_size,
        warmup_steps=50,
        weight_decay=1e-4,
        lr_init=0.1,
        lr_warmup_epochs=5,
        momentum=0.9,
        log_every_n_steps=1,
        loss_scale=256,
        label_smoothing=0.0,
        use_cosine_lr=False,
        use_static_loss_scaling=False,
        is_benchmark=False
    ):
        """Train the model with the estimator built by ``_get_estimator``.

        Input is taken from DALI (when enabled and both ``data_dir`` and
        ``data_idx_dir`` are set), from TFRecords in ``data_dir``, or from
        synthetic data when there is no ``data_dir`` (benchmark runs only).

        Args:
            iter_unit: "epoch" or "batch". Interpreted by
                ``runner_utils.parse_tfrecords_dataset`` together with
                ``num_iter``. With synthetic data, ``num_iter`` is used directly
                as the step count.
            num_iter: Number of iterations in ``iter_unit`` units.
            batch_size: Per-GPU batch size.
            warmup_steps: Passed to the benchmark logging hook.
            weight_decay: Passed to the estimator params.
            lr_init: Passed to the estimator params.
            lr_warmup_epochs: Passed to the estimator params.
            momentum: Passed to the estimator params.
            log_every_n_steps: Logging frequency of the rank-0 logging hook.
            loss_scale: Passed to the estimator params.
            label_smoothing: Passed to the estimator params.
            use_cosine_lr: Passed to the estimator params.
            use_static_loss_scaling: Only honoured with AMP or FP16; forced to
                False otherwise.
            is_benchmark: Use the benchmark logging hook; allows a missing
                ``data_dir``.

        Raises:
            ValueError: On an unknown ``iter_unit``, or a missing ``data_dir``
                outside benchmark mode.
        """

        if iter_unit not in ["epoch", "batch"]:
            raise ValueError('`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])' % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for training!')

        if self.run_hparams.use_tf_amp or self.run_hparams.dtype == tf.float16:
            if use_static_loss_scaling:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "0"
            else:
                LOGGER.log("TF Loss Auto Scaling is activated")
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"
        else:
            use_static_loss_scaling = False  # Make sure it hasn't been set to True on FP32 training

        num_gpus = 1 if not hvd_utils.is_using_hvd() else hvd.size()
        global_batch_size = batch_size * num_gpus

        if self.run_hparams.data_dir is not None:
            filenames,num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="train",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=global_batch_size,
            )

            steps_per_epoch = num_steps / num_epochs

        else:
            # Synthetic data: `num_iter` is taken directly as the step count.
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = num_steps
            num_decay_steps = num_steps
            num_samples = num_steps * batch_size

        if self.run_hparams.data_idx_dir is not None:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir,
                mode="train"
            )

        training_hooks = []

        if hvd.rank() == 0:
            LOGGER.log('Starting Model Training...')
            LOGGER.log("Training Epochs", num_epochs)
            LOGGER.log("Total Steps", num_steps)
            LOGGER.log("Steps per Epoch", steps_per_epoch)
            LOGGER.log("Decay Steps", num_decay_steps)
            LOGGER.log("Weight Decay Factor", weight_decay)
            LOGGER.log("Init Learning Rate", lr_init)
            LOGGER.log("Momentum", momentum)
            LOGGER.log("Num GPUs", num_gpus)
            LOGGER.log("Per-GPU Batch Size", batch_size)

            if is_benchmark:

                benchmark_logging_hook = hooks.BenchmarkLoggingHook(
                    log_file_path=os.path.join(self.run_hparams.log_dir, "training_benchmark.json"),
                    global_batch_size=global_batch_size,
                    log_every=log_every_n_steps,
                    warmup_steps=warmup_steps
                )

                training_hooks.append(benchmark_logging_hook)

            else:

                training_logging_hook = hooks.TrainingLoggingHook(
                    log_file_path=os.path.join(self.run_hparams.log_dir, "training.json"),
                    global_batch_size=global_batch_size,
                    num_steps=num_steps,
                    num_samples=num_samples,
                    num_epochs=num_epochs,
                    log_every=log_every_n_steps
                )

                training_hooks.append(training_logging_hook)

        if hvd_utils.is_using_hvd():
            # Start all workers from rank 0's initial variables.
            bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
            training_hooks.append(bcast_hook)

        training_hooks.append(hooks.PrefillStagingAreasHook())

        # NVTX
        nvtx_callback = NVTXHook(skip_n_steps=1, name='Train')
        training_hooks.append(nvtx_callback)

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'num_gpus': num_gpus,
            'momentum': momentum,
            'lr_init': lr_init,
            'lr_warmup_epochs': lr_warmup_epochs,
            'weight_decay': weight_decay,
            'loss_scale': loss_scale,
            'apply_loss_scaling': use_static_loss_scaling,
            'label_smoothing': label_smoothing,
            'num_decay_steps': num_decay_steps,
            'use_cosine_lr': use_cosine_lr
        }

        image_classifier = self._get_estimator(
            mode='train',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction
        )

        def training_data_fn():

            # DALI needs both the TFRecord files (parsed from `data_dir`) and
            # their index files. Without `data_dir`, `filenames` is never
            # assigned, so fall through to synthetic data instead of raising
            # NameError.
            if (self.run_hparams.use_dali
                    and self.run_hparams.data_idx_dir is not None
                    and self.run_hparams.data_dir is not None):
                if hvd.rank() == 0:
                    LOGGER.log("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=False if self.run_hparams.seed is None else True
                )

            elif self.run_hparams.data_dir is not None:

                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=False if self.run_hparams.seed is None else True
                )

            else:
                if hvd.rank() == 0:
                    LOGGER.log("Using Synthetic Data ...")
                return data_utils.get_synth_input_fn(
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    num_channels=self.run_hparams.n_channels,
                    data_format=self.run_hparams.input_format,
                    num_classes=self.run_hparams.n_classes,
                    dtype=self.run_hparams.dtype,
                )

        try:
            image_classifier.train(
                input_fn=training_data_fn,
                steps=num_steps,
                hooks=training_hooks,
            )
        except KeyboardInterrupt:
            print("Keyboard interrupt")

        if hvd.rank() == 0:
            LOGGER.log('Ending Model Training ...')
Example #17
0
    def __call__(self, features, labels, mode, params):

        if mode == tf.estimator.ModeKeys.TRAIN:
            mandatory_params = ["batch_size", "lr_init", "num_gpus", "steps_per_epoch",
                                "momentum", "weight_decay", "loss_scale", "label_smoothing"]
            for p in mandatory_params:
                if p not in params:
                    raise RuntimeError("Parameter {} is missing.".format(p))

        if mode == tf.estimator.ModeKeys.TRAIN and not self.model_hparams.use_dali:

            with tf.device('/cpu:0'):
                # Stage inputs on the host
                cpu_prefetch_op, (features, labels) = self._stage([features, labels])

            with tf.device('/gpu:0'):
                # Stage inputs to the device
                gpu_prefetch_op, (features, labels) = self._stage([features, labels])

        with tf.device("/gpu:0"):

            if features.dtype != self.model_hparams.dtype:
                features = tf.cast(features, self.model_hparams.dtype)

            # Subtract mean per channel
            # and enforce values between [-1, 1]
            if not self.model_hparams.use_dali:
                features = normalized_inputs(features)

            mixup = 0
            eta = 0
            
            if mode == tf.estimator.ModeKeys.TRAIN:        
                eta = params['label_smoothing']
                mixup = params['mixup']
                
            if mode != tf.estimator.ModeKeys.PREDICT: 
                one_hot_smoothed_labels = tf.one_hot(labels, 1001, 
                                                     on_value = 1 - eta + eta/1001,
                                                     off_value = eta/1001)
                if mixup != 0:

                    print("Using mixup training with beta=", params['mixup'])
                    beta_distribution = tf.distributions.Beta(params['mixup'], params['mixup'])

                    feature_coefficients = beta_distribution.sample(sample_shape=[params['batch_size'], 1, 1, 1])      

                    reversed_feature_coefficients = tf.subtract(tf.ones(shape=feature_coefficients.shape), feature_coefficients)

                    rotated_features = tf.reverse(features, axis=[0])      

                    features = feature_coefficients * features + reversed_feature_coefficients * rotated_features

                    label_coefficients = tf.squeeze(feature_coefficients, axis=[2, 3])

                    rotated_labels = tf.reverse(one_hot_smoothed_labels, axis=[0])    

                    reversed_label_coefficients = tf.subtract(tf.ones(shape=label_coefficients.shape), label_coefficients)

                    one_hot_smoothed_labels = label_coefficients * one_hot_smoothed_labels + reversed_label_coefficients * rotated_labels
                
                
            # Update Global Step
            global_step = tf.train.get_or_create_global_step()
            tf.identity(global_step, name="global_step_ref")

            tf.identity(features, name="features_ref")
            
            if mode == tf.estimator.ModeKeys.TRAIN:
                tf.identity(labels, name="labels_ref")

            probs, logits = self.build_model(
                features,
                training=mode == tf.estimator.ModeKeys.TRAIN,
                reuse=False
            )

            y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)

            # Check the output dtype, shall be FP32 in training
            assert (probs.dtype == tf.float32)
            assert (logits.dtype == tf.float32)
            assert (y_preds.dtype == tf.int32)

            tf.identity(logits, name="logits_ref")
            tf.identity(probs, name="probs_ref")
            tf.identity(y_preds, name="y_preds_ref")

            #if mode == tf.estimator.ModeKeys.TRAIN:
            #    
            #    assert (len(tf.trainable_variables()) == 161)
            #
            #else:
            #    
            #    assert (len(tf.trainable_variables()) == 0)


        if mode == tf.estimator.ModeKeys.PREDICT:

            predictions = {'classes': y_preds, 'probabilities': probs}

            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs={'predict': tf.estimator.export.PredictOutput(predictions)}
            )

        else:

            with tf.device("/gpu:0"):

                if mode == tf.estimator.ModeKeys.TRAIN:
                    acc_top1 = tf.nn.in_top_k(predictions=logits, targets=labels, k=1)
                    acc_top5 = tf.nn.in_top_k(predictions=logits, targets=labels, k=5)

                else:
                    acc_top1, acc_top1_update_op = tf.metrics.mean(tf.nn.in_top_k(predictions=logits, targets=labels, k=1))
                    acc_top5, acc_top5_update_op = tf.metrics.mean(tf.nn.in_top_k(predictions=logits, targets=labels, k=5))

                tf.identity(acc_top1, name="acc_top1_ref")
                tf.identity(acc_top5, name="acc_top5_ref")

                predictions = {
                    'classes': y_preds,
                    'probabilities': probs,
                    'accuracy_top1': acc_top1,
                    'accuracy_top5': acc_top5
                }
                
                cross_entropy = tf.losses.softmax_cross_entropy(
                    logits=logits, onehot_labels=one_hot_smoothed_labels)

                assert (cross_entropy.dtype == tf.float32)
                tf.identity(cross_entropy, name='cross_entropy_loss_ref')

                def loss_filter_fn(name):
                    """we don't need to compute L2 loss for BN and bias (eq. to add a cste)"""
                    return all([
                        tensor_name not in name.lower()
                        # for tensor_name in ["batchnorm", "batch_norm", "batch_normalization", "bias"]
                        for tensor_name in ["batchnorm", "batch_norm", "batch_normalization"]
                    ])

                filtered_params = [tf.cast(v, tf.float32) for v in tf.trainable_variables() if loss_filter_fn(v.name)]

                if len(filtered_params) != 0:

                    l2_loss_per_vars = [tf.nn.l2_loss(v) for v in filtered_params]
                    l2_loss = tf.multiply(tf.add_n(l2_loss_per_vars), params["weight_decay"])

                else:
                    l2_loss = tf.zeros(shape=(), dtype=tf.float32)

                assert (l2_loss.dtype == tf.float32)
                tf.identity(l2_loss, name='l2_loss_ref')

                total_loss = tf.add(cross_entropy, l2_loss, name="total_loss")

                assert (total_loss.dtype == tf.float32)
                tf.identity(total_loss, name='total_loss_ref')

                tf.summary.scalar('cross_entropy', cross_entropy)
                tf.summary.scalar('l2_loss', l2_loss)
                tf.summary.scalar('total_loss', total_loss)
                
                if mode == tf.estimator.ModeKeys.TRAIN:

                    with tf.device("/cpu:0"):

                        learning_rate = learning_rate_scheduler(
                            lr_init=params["lr_init"],
                            lr_warmup_epochs=params["lr_warmup_epochs"],
                            global_step=global_step,
                            batch_size=params["batch_size"],
                            num_batches_per_epoch=params["steps_per_epoch"],
                            num_decay_steps=params["num_decay_steps"],
                            num_gpus=params["num_gpus"],
                            use_cosine_lr=params["use_cosine_lr"]
                        )

                    tf.identity(learning_rate, name='learning_rate_ref')
                    tf.summary.scalar('learning_rate', learning_rate)

                    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=params["momentum"])

                    if params["apply_loss_scaling"]:
                        optimizer = FixedLossScalerOptimizer(optimizer, scale=params["loss_scale"])

                    if hvd_utils.is_using_hvd():
                        optimizer = hvd.DistributedOptimizer(optimizer)

                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    if mode != tf.estimator.ModeKeys.TRAIN:
                        update_ops += [acc_top1_update_op, acc_top5_update_op]
                    
                    deterministic = True
                    gate_gradients = (tf.train.Optimizer.GATE_OP if deterministic else tf.train.Optimizer.GATE_NONE)

                    backprop_op = optimizer.minimize(total_loss, gate_gradients=gate_gradients, global_step=global_step)

                    
                    if self.model_hparams.use_dali:
                        train_ops = tf.group(backprop_op, update_ops, name='train_ops')
                    else:
                        train_ops = tf.group(backprop_op, cpu_prefetch_op, gpu_prefetch_op, update_ops, name='train_ops')

                    return tf.estimator.EstimatorSpec(mode=mode, loss=total_loss, train_op=train_ops)

                elif mode == tf.estimator.ModeKeys.EVAL:
                    eval_metrics = {
                        "top1_accuracy": (acc_top1, acc_top1_update_op),
                        "top5_accuracy": (acc_top5, acc_top5_update_op)
                    }

                    return tf.estimator.EstimatorSpec(
                        mode=mode,
                        predictions=predictions,
                        loss=total_loss,
                        eval_metric_ops=eval_metrics
                    )

                else:
                    raise NotImplementedError('Unknown mode {}'.format(mode))
Example #18
0
        runner.train(
            iter_unit=RUNNING_CONFIG.iter_unit,
            num_iter=RUNNING_CONFIG.num_iter,
            batch_size=RUNNING_CONFIG.batch_size,
            warmup_steps=RUNNING_CONFIG.warmup_steps,
            log_every_n_steps=RUNNING_CONFIG.log_every_n_steps,
            weight_decay=RUNNING_CONFIG.weight_decay,
            learning_rate_init=RUNNING_CONFIG.learning_rate_init,
            momentum=RUNNING_CONFIG.momentum,
            loss_scale=RUNNING_CONFIG.loss_scale,
            use_static_loss_scaling=FLAGS.use_static_loss_scaling,
            is_benchmark=RUNNING_CONFIG.mode == 'training_benchmark',
        )

    if RUNNING_CONFIG.mode in ["train_and_evaluate", 'evaluate', 'inference_benchmark']:

        if RUNNING_CONFIG.mode == 'inference_benchmark' and hvd_utils.is_using_hvd():
            raise NotImplementedError("Only single GPU inference is implemented.")

        elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:

            runner.evaluate(
                iter_unit=RUNNING_CONFIG.iter_unit if RUNNING_CONFIG.mode != "train_and_evaluate" else "epoch",
                num_iter=RUNNING_CONFIG.num_iter if RUNNING_CONFIG.mode != "train_and_evaluate" else 1,
                warmup_steps=RUNNING_CONFIG.warmup_steps,
                batch_size=RUNNING_CONFIG.batch_size,
                log_every_n_steps=RUNNING_CONFIG.log_every_n_steps,
                is_benchmark=RUNNING_CONFIG.mode == 'inference_benchmark'
            )
Example #19
0
    def train(self,
              iter_unit,
              num_iter,
              batch_size,
              weight_decay,
              learning_rate,
              learning_rate_decay_factor,
              learning_rate_decay_steps,
              rmsprop_decay,
              rmsprop_momentum,
              use_auto_loss_scaling,
              augment_data,
              warmup_steps=50,
              is_benchmark=False):
        """Train the estimator for `num_iter` epochs or batches.

        Validates the arguments and configures AMP loss scaling. On the chief
        process it logs the run parameters and installs a profiling hook. It
        then calls ``Estimator.train`` on the real dataset or, in benchmark
        mode without a ``data_dir``, on a synthetic dataset. A
        ``KeyboardInterrupt`` during training is reported and not re-raised.

        Raises:
          ValueError: if ``iter_unit`` is not "epoch"/"batch", or if
            ``run_hparams.data_dir`` is None outside of benchmark mode.
        """

        def is_chief():
            # True for a single-process run or for Horovod rank 0.
            return not hvd_utils.is_using_hvd() or hvd.rank() == 0

        if iter_unit not in ("epoch", "batch"):
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for training!')

        # With AMP, either TF scales the loss automatically or the model
        # applies a manual loss scale; without AMP there is no loss scaling.
        apply_manual_loss_scaling = False
        if self.run_hparams.amp:
            if use_auto_loss_scaling and is_chief():
                print(
                    "TF Loss Auto Scaling is activated - Experimental Feature"
                )
            os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = (
                "1" if use_auto_loss_scaling else "0")
            apply_manual_loss_scaling = not use_auto_loss_scaling

        global_batch_size = batch_size * self.num_gpus

        if self.run_hparams.data_dir is None:
            # Synthetic benchmark: there is no dataset to derive the run length
            # from, so a fixed number of steps per epoch is assumed.
            num_epochs, num_steps, steps_per_epoch = 1, num_iter, 625
        else:
            _, _, num_steps, num_epochs = self.dataset.get_dataset_runtime_specs(
                training=True,
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=global_batch_size)
            steps_per_epoch = int(num_steps / num_epochs)

        training_hooks = []

        if hvd_utils.is_using_hvd():
            # Start every worker from rank 0's initial weights.
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        if is_chief():
            training_hooks.append(
                ProfilerHook(global_batch_size=global_batch_size,
                             log_every=self.run_hparams.log_every_n_steps,
                             warmup_steps=warmup_steps,
                             is_training=True,
                             sample_dir=self.run_hparams.sample_dir))

            print("Starting Model Training ...")

            # Logged one entry per call, in this exact order.
            run_parameters = [
                ("Epochs", num_epochs),
                ("Total Steps", num_steps),
                ("Steps per Epoch", steps_per_epoch),
                ("Weight Decay Factor", weight_decay),
                ("Learning Rate", learning_rate),
                ("Learning Rate Decay Factor", learning_rate_decay_factor),
                ("Learning Rate Decay Steps", learning_rate_decay_steps),
                ("RMSProp - Decay", rmsprop_decay),
                ("RMSProp - Momentum", rmsprop_momentum),
                ("Loss Function Name", self.run_hparams.loss_fn_name),
            ]
            if self.run_hparams.amp:
                run_parameters.append(
                    ("Use Auto Loss Scaling", use_auto_loss_scaling))
            run_parameters += [
                ("# GPUs", self.num_gpus),
                ("GPU Batch Size", batch_size),
                ("Global Batch Size", global_batch_size),
                ("Total Files to be Processed", num_steps * global_batch_size),
            ]

            for label, value in run_parameters:
                Logger.log(step='PARAMETER',
                           data={label: value},
                           verbosity=Logger.Verbosity.DEFAULT)

            print()  # visual spacing

        estimator_params = dict(
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            learning_rate=learning_rate,
            learning_rate_decay_steps=learning_rate_decay_steps,
            learning_rate_decay_factor=learning_rate_decay_factor,
            rmsprop_decay=rmsprop_decay,
            rmsprop_momentum=rmsprop_momentum,
            weight_decay=weight_decay,
            apply_manual_loss_scaling=apply_manual_loss_scaling,
            loss_fn_name=self.run_hparams.loss_fn_name,
            debug_verbosity=self.run_hparams.debug_verbosity,
        )

        def training_data_fn():
            # Real and synthetic pipelines take the same configuration.
            n_channels = self.run_hparams.n_channels
            dataset_kwargs = dict(
                batch_size=batch_size,
                training=True,
                input_shape=list(self.run_hparams.input_shape) + [n_channels],
                mask_shape=list(self.run_hparams.mask_shape) + [n_channels],
                num_threads=64,
                use_gpu_prefetch=True,
                normalize_data_method="zero_centered",
                only_defective_images=True,
                augment_data=augment_data,
                seed=self.run_hparams.seed)

            if is_benchmark and self.run_hparams.data_dir is None:
                if is_chief():
                    print("Using Synthetic Data ...")
                return self.dataset.synth_dataset_fn(**dataset_kwargs)

            return self.dataset.dataset_fn(**dataset_kwargs)

        model = self._get_estimator(mode='train',
                                    run_params=estimator_params,
                                    xla=self.xla)

        try:
            model.train(input_fn=training_data_fn,
                        steps=num_steps,
                        hooks=training_hooks)
        except KeyboardInterrupt:
            print("Keyboard interrupt")

        if is_chief():
            print('Ending Model Training ...')
Example #20
0
    def __init__(
            self,
            # ========= Model HParams ========= #
            n_classes=1001,
            architecture='resnet50',
            input_format='NHWC',  # NCHW or NHWC
            compute_format='NCHW',  # NCHW or NHWC
            dtype=tf.float32,  # tf.float32 or tf.float16
            n_channels=3,
            height=224,
            width=224,
            distort_colors=False,
            model_dir=None,
            log_dir=None,
            data_dir=None,
            data_idx_dir=None,
            weight_init="fan_out",

            # ======= Optimization HParams ======== #
            use_xla=False,
            use_tf_amp=False,
            use_dali=False,
            gpu_memory_fraction=1.0,
            gpu_id=0,

            # ======== Debug Flags ======== #
            debug_verbosity=0,
            seed=None):
        """Validate the run configuration, set TF/Horovod env flags and build the model.

        Args:
          n_classes: number of classifier outputs.
          architecture: key into ``resnet.model_architectures``. The entry
            supplies ``layers``, ``widths`` and ``expansions``, and optionally
            ``cardinality`` (default 1), ``use_se`` (default False) and
            ``se_ratio`` (default 1).
          input_format: layout of the input data, 'NHWC' or 'NCHW'.
          compute_format: layout used inside the model, 'NHWC' or 'NCHW'.
          dtype: ``tf.float32`` or ``tf.float16``.
          n_channels: 1 (grayscale) or 3 (color).
          height: input image height.
          width: input image width.
          distort_colors: stored in the model hparams.
          model_dir: kept only when not using Horovod or on rank 0; None
            on every other rank.
          log_dir: kept only when not using Horovod or on rank 0; None on
            every other rank.
          data_dir: dataset location, stored in the run hparams.
          data_idx_dir: dataset index location, stored in the run hparams.
          weight_init: forwarded to ``resnet.ResnetModel``.
          use_xla: stored in the run hparams.
          use_tf_amp: enables the TF automatic-mixed-precision graph rewrite.
            Not allowed together with ``tf.float16``.
          use_dali: selects 4 preprocessing threads instead of 10 and is
            forwarded to the model.
          gpu_memory_fraction: stored in the run hparams.
          gpu_id: stored in the run hparams.
          debug_verbosity: accepted for interface compatibility; not used in
            this method.
          seed: base seed. It is turned into ``2 * (seed + hvd.rank())`` so
            that each worker gets its own seed; None disables seeding.

        Raises:
          ValueError: on an unsupported dtype, data format or channel count.
          RuntimeError: if TF-AMP is requested together with ``tf.float16``.
        """

        if dtype not in [tf.float32, tf.float16]:
            raise ValueError(
                "Unknown dtype received: %s (allowed: `tf.float32` and `tf.float16`)"
                % dtype)

        if compute_format not in ["NHWC", 'NCHW']:
            raise ValueError(
                "Unknown `compute_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % compute_format)

        if input_format not in ["NHWC", 'NCHW']:
            raise ValueError(
                "Unknown `input_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % input_format)

        if n_channels not in [1, 3]:
            raise ValueError(
                "Unsupported number of channels: %d (allowed: 1 (grayscale) and 3 (color))"
                % n_channels)

        # NOTE(review): hvd.rank() is called unconditionally, so Horovod must
        # already be initialized here even for single-GPU runs -- confirm.
        tf_seed = 2 * (seed + hvd.rank()) if seed is not None else None

        # ============================================
        # Optimisation Flags - Do not remove
        # ============================================

        os.environ['CUDA_CACHE_DISABLE'] = '0'

        os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

        #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

        os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
        # One dedicated GPU thread per Horovod worker.
        os.environ['TF_GPU_THREAD_COUNT'] = '1' if not hvd_utils.is_using_hvd(
        ) else str(hvd.size())

        os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

        os.environ['TF_ADJUST_HUE_FUSED'] = '1'
        os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

        os.environ['TF_SYNC_ON_FINISH'] = '0'
        os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'
        os.environ['TF_DISABLE_NVTX_RANGES'] = '1'
        # Append (do not overwrite) any XLA flags the user already set.
        os.environ["TF_XLA_FLAGS"] = (
            os.environ.get("TF_XLA_FLAGS", "") +
            " --tf_xla_enable_lazy_compilation=false")

        # ============================================
        # TF-AMP Setup - Do not remove
        # ============================================

        if dtype == tf.float16 and use_tf_amp:
            raise RuntimeError(
                "TF AMP can not be activated for FP16 precision")

        # Always write the flag explicitly, including for FP16. Otherwise a
        # value inherited from the environment could turn the AMP rewrite on
        # for a graph that is already FP16.
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = (
            "1" if use_tf_amp else "0")

        # =================================================

        model_hparams = tf.contrib.training.HParams(
            width=width,
            height=height,
            n_channels=n_channels,
            n_classes=n_classes,
            dtype=dtype,
            input_format=input_format,
            compute_format=compute_format,
            distort_colors=distort_colors,
            seed=tf_seed)

        num_preprocessing_threads = 10 if not use_dali else 4
        run_config_performance = tf.contrib.training.HParams(
            num_preprocessing_threads=num_preprocessing_threads,
            use_tf_amp=use_tf_amp,
            use_xla=use_xla,
            use_dali=use_dali,
            gpu_memory_fraction=gpu_memory_fraction,
            gpu_id=gpu_id)

        # Only the chief process keeps the model/log directories.
        is_chief = not hvd_utils.is_using_hvd() or hvd.rank() == 0
        run_config_additional = tf.contrib.training.HParams(
            model_dir=model_dir if is_chief else None,
            log_dir=log_dir if is_chief else None,
            data_dir=data_dir,
            data_idx_dir=data_idx_dir,
            num_preprocessing_threads=num_preprocessing_threads)

        self.run_hparams = Runner._build_hparams(model_hparams,
                                                 run_config_additional,
                                                 run_config_performance)

        model_name = architecture
        arch_spec = resnet.model_architectures[architecture]

        self._model = resnet.ResnetModel(
            model_name=model_name,
            n_classes=model_hparams.n_classes,
            layers_count=arch_spec["layers"],
            layers_depth=arch_spec["widths"],
            expansions=arch_spec["expansions"],
            input_format=model_hparams.input_format,
            compute_format=model_hparams.compute_format,
            dtype=model_hparams.dtype,
            weight_init=weight_init,
            use_dali=use_dali,
            cardinality=arch_spec['cardinality']
            if 'cardinality' in arch_spec else 1,
            use_se=arch_spec['use_se']
            if 'use_se' in arch_spec else False,
            se_ratio=arch_spec['se_ratio']
            if 'se_ratio' in arch_spec else 1)

        # Presumably run_hparams.seed is the per-rank tf_seed merged in from
        # model_hparams by _build_hparams -- verify against that helper.
        if self.run_hparams.seed is not None:
            np.random.seed(self.run_hparams.seed)
            tf.set_random_seed(self.run_hparams.seed)

        self.training_logging_hook = None
        self.eval_logging_hook = None
def build_stats(history,
                validation_output,
                train_callbacks,
                eval_callbacks,
                logger,
                comment=''):
    """Gather final training/evaluation metrics and log them on the chief rank.

    Args:
      history: object whose ``.history`` dict maps metric names to per-epoch
        lists (presumably the ``History`` returned by Keras ``fit``). It may
        be None or empty, in which case no training metrics are recorded.
      validation_output: indexable ``(loss, top1_accuracy, top5_accuracy)``
        from evaluation, or a falsy value to skip the eval metrics.
      train_callbacks: iterable of callbacks. ``callbacks.TimeHistory``
        instances with a non-empty ``epoch_runtime_log`` provide training
        throughput.
      eval_callbacks: iterable of callbacks. ``callbacks.EvalTimeHistory``
        instances provide evaluation throughput and per-batch times
        (``batch_time``).
      logger: object exposing ``log(step=..., data=...)``.
      comment: free-form text stored under the ``'comment'`` key.

    Returns:
      The stats dict. It is logged only when not using Horovod or on rank 0,
      but it is returned on every rank.

    Notes:
      Training metrics are averaged with ``hvd.allreduce``, which is a
      collective op. Every rank must therefore call this function with the
      same history keys.

      Latency keys (all in ms): ``latency_pct`` is the mean batch time
      excluding the slowest batch. ``latency_XXpct`` is the mean over the
      fastest XX% of batches. A latency key is omitted when too few batches
      were timed to compute it.
    """

    def _allreduce_mean(value):
        # Average one per-rank scalar across all Horovod workers.
        return float(
            hvd.allreduce(tf.constant(value, dtype=tf.float32), average=True))

    def _mean_of_first(sorted_times, count):
        # Mean of the `count` smallest entries, or None when count <= 0.
        if count <= 0:
            return None
        return sum(sorted_times[:count]) / count

    stats = {'comment': comment}

    if validation_output:
        stats['eval_loss'] = float(validation_output[0])
        stats['eval_accuracy_top_1'] = float(validation_output[1])
        stats['eval_accuracy_top_5'] = float(validation_output[2])

    # Final-epoch training metrics, averaged over all ranks.
    if history and history.history:
        train_hist = history.history
        stats['training_loss'] = _allreduce_mean(train_hist['loss'][-1])
        if 'categorical_accuracy' in train_hist:
            stats['training_accuracy_top_1'] = _allreduce_mean(
                train_hist['categorical_accuracy'][-1])
        elif 'sparse_categorical_accuracy' in train_hist:
            stats['training_accuracy_top_1'] = _allreduce_mean(
                train_hist['sparse_categorical_accuracy'][-1])
        elif 'accuracy' in train_hist:
            stats['training_accuracy_top_1'] = _allreduce_mean(
                train_hist['accuracy'][-1])
            # 'top_5_accuracy' is not guaranteed to be recorded alongside
            # 'accuracy'; skip it instead of raising KeyError.
            if 'top_5_accuracy' in train_hist:
                stats['training_accuracy_top_5'] = _allreduce_mean(
                    train_hist['top_5_accuracy'][-1])

    # Training throughput from the time-history callback used during fit.
    for callback in train_callbacks or ():
        if (isinstance(callback, callbacks.TimeHistory)
                and callback.epoch_runtime_log):
            train_exp_per_sec = callback.average_examples_per_second
            stats['avg_exp_per_second_training'] = train_exp_per_sec
            stats['avg_exp_per_second_training_per_GPU'] = (
                train_exp_per_sec / hvd.size())

    for eval_callback in eval_callbacks or ():
        if not isinstance(eval_callback, callbacks.EvalTimeHistory):
            continue

        # Evaluation runs on a single GPU, so the global and per-GPU rates
        # are the same value (no * hvd.size()).
        eval_exp_per_sec = float(eval_callback.average_examples_per_second)
        stats['avg_exp_per_second_eval'] = eval_exp_per_sec
        stats['avg_exp_per_second_eval_per_GPU'] = eval_exp_per_sec
        if eval_exp_per_sec > 0:
            stats['avg_time_per_exp_eval'] = 1000. / eval_exp_per_sec

        # Sort a copy so the callback's own list is left untouched.
        batch_time = sorted(eval_callback.batch_time)
        n_batches = len(batch_time)
        latency_means = (
            ('latency_pct', _mean_of_first(batch_time, n_batches - 1)),
            ('latency_90pct', _mean_of_first(batch_time, int(0.9 * n_batches))),
            ('latency_95pct', _mean_of_first(batch_time, int(0.95 * n_batches))),
            ('latency_99pct', _mean_of_first(batch_time, int(0.99 * n_batches))),
        )
        for key, mean_seconds in latency_means:
            if mean_seconds is not None:
                stats[key] = 1000.0 * mean_seconds

    if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
        logger.log(step=(), data=stats)

    return stats
Example #22
0
                     loss_scale=FLAGS.static_loss_scale,
                     label_smoothing=FLAGS.label_smoothing,
                     mixup=FLAGS.mixup,
                     use_static_loss_scaling=(FLAGS.static_loss_scale != -1),
                     use_cosine_lr=FLAGS.cosine_lr,
                     is_benchmark=FLAGS.mode == 'training_benchmark',
                     use_final_conv=FLAGS.use_final_conv,
                     quantize=FLAGS.quantize,
                     symmetric=FLAGS.symmetric,
                     quant_delay=FLAGS.quant_delay,
                     use_qdq=FLAGS.use_qdq,
                     finetune_checkpoint=FLAGS.finetune_checkpoint)

    if FLAGS.mode in ["train_and_evaluate", 'evaluate', 'inference_benchmark']:

        if FLAGS.mode == 'inference_benchmark' and hvd_utils.is_using_hvd():
            raise NotImplementedError(
                "Only single GPU inference is implemented.")

        elif not hvd_utils.is_using_hvd() or hvd.rank() == 0:

            runner.evaluate(iter_unit=FLAGS.iter_unit
                            if FLAGS.mode != "train_and_evaluate" else "epoch",
                            num_iter=FLAGS.num_iter
                            if FLAGS.mode != "train_and_evaluate" else 1,
                            warmup_steps=FLAGS.warmup_steps,
                            batch_size=FLAGS.batch_size,
                            log_every_n_steps=FLAGS.display_every,
                            is_benchmark=FLAGS.mode == 'inference_benchmark',
                            export_dir=FLAGS.export_dir,
                            quantize=FLAGS.quantize,
Example #23
0
    def __call__(self, features, labels, mode, params):

        # print(params)

        if mode == tf.estimator.ModeKeys.TRAIN:

            if "batch_size" not in params.keys():
                raise RuntimeError("Parameter `batch_size` is missing...")

            if "learning_rate_init" not in params.keys():
                raise RuntimeError("Parameter `learning_rate` is missing...")

            if "num_gpus" not in params.keys():
                raise RuntimeError("Parameter `num_gpus` is missing...")

            if "steps_per_epoch" not in params.keys():
                raise RuntimeError("Parameter `steps_per_epoch` is missing...")

            if "momentum" not in params.keys():
                raise RuntimeError("Parameter `momentum` is missing...")

            if "weight_decay" not in params.keys():
                raise RuntimeError("Parameter `weight_decay` is missing...")

            if "loss_scale" not in params.keys():
                raise RuntimeError("Parameter `loss_scale` is missing...")

        if mode == tf.estimator.ModeKeys.TRAIN:
            with tf.device('/cpu:0'):
                # Stage inputs on the host
                cpu_prefetch_op, (features, labels) = ResnetModel._stage(
                    [features, labels])

            with tf.device('/gpu:0'):
                # Stage inputs to the device
                gpu_prefetch_op, (features, labels) = ResnetModel._stage(
                    [features, labels])

        with tf.device("/gpu:0"):

            if True:  # not params['use_trt']:

                if features.dtype != self.model_hparams.dtype:
                    features = tf.cast(features, self.model_hparams.dtype)

                # Subtract mean per channel
                # and enforce values between [-1, 1]
                # features = normalized_inputs(features)

                # Update Global Step
                global_step = tf.train.get_or_create_global_step()
                tf.identity(global_step, name="global_step_ref")

                # tf.identity(features, name="features_ref")
                # tf.identity(labels, name="labels_ref")

                probs, logits = self.build_model(
                    features,
                    training=mode == tf.estimator.ModeKeys.TRAIN,
                    reuse=False)

            else:

                trt_graph = trt.create_inference_graph(
                    input_graph_def=None,
                    outputs=None,
                    input_saved_model_dir=os.path.join(
                        self.model_hparams.model_dir, '1554216247'),
                    input_saved_model_tags=['serve'],
                    max_batch_size=params["batch_size"],
                    max_workspace_size_bytes=1 << 20,
                    precision_mode="FP32")

                for node in trt_graph.node:
                    print(node.name)

                y_preds = tf.import_graph_def(
                    trt_graph,
                    return_elements=['resnet50_v1.5/output/softmax:0'])

                predictions = {'classes': y_preds[0]}

                return tf.estimator.EstimatorSpec(
                    mode=tf.estimator.ModeKeys.PREDICT,
                    predictions=predictions)

            y_preds = tf.argmax(logits, axis=1, output_type=tf.int32)

            # Check the output dtype, shall be FP32 in training
            assert (probs.dtype == tf.float32)
            assert (logits.dtype == tf.float32)
            assert (y_preds.dtype == tf.int32)

            tf.identity(logits, name="logits_ref")
            tf.identity(probs, name="probs_ref")
            tf.identity(y_preds, name="y_preds_ref")

            if mode == tf.estimator.ModeKeys.TRAIN:

                assert (len(tf.trainable_variables()) == 161)

            else:

                assert (len(tf.trainable_variables()) == 0)

        if mode == tf.estimator.ModeKeys.PREDICT:

            predictions = {'classes': y_preds, 'probabilities': probs}

            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs={
                    'predict': tf.estimator.export.PredictOutput(predictions)
                })

        else:

            with tf.device("/gpu:0"):

                if mode == tf.estimator.ModeKeys.TRAIN:
                    acc_top1 = tf.nn.in_top_k(predictions=logits,
                                              targets=labels,
                                              k=1)
                    acc_top5 = tf.nn.in_top_k(predictions=logits,
                                              targets=labels,
                                              k=5)

                else:
                    acc_top1, acc_top1_update_op = tf.metrics.mean(
                        tf.nn.in_top_k(predictions=logits, targets=labels,
                                       k=1))
                    acc_top5, acc_top5_update_op = tf.metrics.mean(
                        tf.nn.in_top_k(predictions=logits, targets=labels,
                                       k=5))

                tf.identity(acc_top1, name="acc_top1_ref")
                tf.identity(acc_top5, name="acc_top5_ref")

                predictions = {
                    'classes': y_preds,
                    'probabilities': probs,
                    'accuracy_top1': acc_top1,
                    'accuracy_top5': acc_top5
                }

                cross_entropy = tf.losses.sparse_softmax_cross_entropy(
                    logits=logits, labels=labels)

                assert (cross_entropy.dtype == tf.float32)
                tf.identity(cross_entropy, name='cross_entropy_loss_ref')

                def loss_filter_fn(name):
                    """we don't need to compute L2 loss for BN and bias (eq. to add a cste)"""
                    return all([
                        tensor_name not in name.lower()
                        # for tensor_name in ["batchnorm", "batch_norm", "batch_normalization", "bias"]
                        for tensor_name in
                        ["batchnorm", "batch_norm", "batch_normalization"]
                    ])

                filtered_params = [
                    tf.cast(v, tf.float32) for v in tf.trainable_variables()
                    if loss_filter_fn(v.name)
                ]

                if len(filtered_params) != 0:

                    l2_loss_per_vars = [
                        tf.nn.l2_loss(v) for v in filtered_params
                    ]
                    l2_loss = tf.multiply(tf.add_n(l2_loss_per_vars),
                                          params["weight_decay"])

                else:
                    l2_loss = tf.zeros(shape=(), dtype=tf.float32)

                assert (l2_loss.dtype == tf.float32)
                tf.identity(l2_loss, name='l2_loss_ref')

                total_loss = tf.add(cross_entropy, l2_loss, name="total_loss")

                assert (total_loss.dtype == tf.float32)
                tf.identity(total_loss, name='total_loss_ref')

                tf.summary.scalar('cross_entropy', cross_entropy)
                tf.summary.scalar('l2_loss', l2_loss)
                tf.summary.scalar('total_loss', total_loss)

                if mode == tf.estimator.ModeKeys.TRAIN:

                    with tf.device("/cpu:0"):

                        learning_rate = learning_rate_scheduler(
                            learning_rate_init=params["learning_rate_init"],
                            global_step=global_step,
                            batch_size=params["batch_size"],
                            num_batches_per_epoch=params["steps_per_epoch"],
                            num_gpus=params["num_gpus"])

                    tf.identity(learning_rate, name='learning_rate_ref')
                    tf.summary.scalar('learning_rate', learning_rate)

                    optimizer = tf.train.MomentumOptimizer(
                        learning_rate=learning_rate,
                        momentum=params["momentum"])

                    if params["apply_loss_scaling"]:
                        optimizer = FixedLossScalerOptimizer(
                            optimizer, scale=params["loss_scale"])

                    if hvd_utils.is_using_hvd():
                        optimizer = hvd.DistributedOptimizer(optimizer)

                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    if mode != tf.estimator.ModeKeys.TRAIN:
                        update_ops += [acc_top1_update_op, acc_top5_update_op]

                    deterministic = True
                    gate_gradients = (tf.train.Optimizer.GATE_OP
                                      if deterministic else
                                      tf.train.Optimizer.GATE_NONE)

                    backprop_op = optimizer.minimize(
                        total_loss,
                        gate_gradients=gate_gradients,
                        global_step=global_step)

                    train_ops = tf.group(backprop_op,
                                         cpu_prefetch_op,
                                         gpu_prefetch_op,
                                         update_ops,
                                         name='train_ops')

                    return tf.estimator.EstimatorSpec(mode=mode,
                                                      loss=total_loss,
                                                      train_op=train_ops)

                elif mode == tf.estimator.ModeKeys.EVAL:

                    eval_metrics = {
                        "top1_accuracy": (acc_top1, acc_top1_update_op),
                        "top5_accuracy": (acc_top5, acc_top5_update_op)
                    }

                    return tf.estimator.EstimatorSpec(
                        mode=mode,
                        predictions=predictions,
                        loss=total_loss,
                        eval_metric_ops=eval_metrics)

                else:
                    raise NotImplementedError('Unknown mode {}'.format(mode))
Example #24
0
    def evaluate(
        self,
        iter_unit,
        num_iter,
        batch_size,
        warmup_steps=50,
        log_every_n_steps=1,
        is_benchmark=False
    ):
        """Evaluate the image classifier on the validation split.

        When ``run_hparams.data_dir`` is None (only allowed in benchmark
        mode), synthetic data is used instead of TFRecords.

        Args:
            iter_unit: Either "epoch" or "batch"; the unit ``num_iter`` is
                expressed in.
            num_iter: Amount of evaluation to run, in ``iter_unit``. With
                synthetic data it is used directly as the step count.
            batch_size: Evaluation batch size (also passed as the global batch
                size when parsing the TFRecord dataset).
            warmup_steps: Forwarded to ``hooks.BenchmarkLoggingHook`` (benchmark
                mode only).
            log_every_n_steps: Forwarded to ``hooks.BenchmarkLoggingHook`` as
                ``log_every`` (benchmark mode only).
            is_benchmark: Attach a benchmark logging hook writing
                ``<log_dir>/eval_benchmark.json`` and allow ``data_dir`` to be
                omitted.

        Raises:
            ValueError: If ``iter_unit`` is unknown, or ``data_dir`` is missing
                outside benchmark mode.
            RuntimeError: If called from a Horovod process other than rank 0.
        """

        if iter_unit not in ["epoch", "batch"]:
            raise ValueError('`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])' % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for evaluation!')

        if hvd_utils.is_using_hvd() and hvd.rank() != 0:
            raise RuntimeError('Multi-GPU inference is not supported')

        # hvd.rank() is only meaningful when Horovod is in use; a plain
        # single-process run is always the primary process.
        is_primary = not hvd_utils.is_using_hvd() or hvd.rank() == 0

        estimator_params = {}

        image_classifier = self._get_estimator(
            mode='validation',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction
        )

        if self.run_hparams.data_dir is not None:
            filenames, num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="validation",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=batch_size,
            )

        else:
            # Synthetic data: num_iter is used directly as the step count.
            num_epochs = 1
            num_decay_steps = -1
            num_steps = num_iter

        # The DALI pipeline consumes both the TFRecord files (`filenames`) and
        # their index files, so it can only be used when `data_dir` is set.
        use_dali_input = (
            self.run_hparams.use_dali
            and self.run_hparams.data_idx_dir is not None
            and self.run_hparams.data_dir is not None
        )

        if use_dali_input:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir,
                mode="validation"
            )

        # A user-provided seed requests a deterministic input pipeline.
        deterministic = self.run_hparams.seed is not None

        eval_hooks = []

        if is_primary:
            if is_benchmark:
                benchmark_logging_hook = hooks.BenchmarkLoggingHook(
                    log_file_path=os.path.join(self.run_hparams.log_dir, "eval_benchmark.json"),
                    global_batch_size=batch_size,
                    log_every=log_every_n_steps,
                    warmup_steps=warmup_steps
                )
                eval_hooks.append(benchmark_logging_hook)

            LOGGER.log('Starting Model Evaluation...')
            LOGGER.log("Evaluation Epochs", num_epochs)
            LOGGER.log("Evaluation Steps", num_steps)
            LOGGER.log("Decay Steps", num_decay_steps)
            LOGGER.log("Global Batch Size", batch_size)

        def evaluation_data_fn():
            """Build the evaluation input pipeline (DALI, TFRecords or synthetic)."""

            if use_dali_input:
                if is_primary:
                    LOGGER.log("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=False,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic
                )

            elif self.run_hparams.data_dir is not None:
                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=False,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic
                )

            else:
                LOGGER.log("Using Synthetic Data ...\n")
                return data_utils.get_synth_input_fn(
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    num_channels=self.run_hparams.n_channels,
                    data_format=self.run_hparams.input_format,
                    num_classes=self.run_hparams.n_classes,
                    dtype=self.run_hparams.dtype,
                )

        try:
            eval_results = image_classifier.evaluate(
                input_fn=evaluation_data_fn,
                steps=num_steps,
                hooks=eval_hooks,
            )
            LOGGER.log('Top-1 Accuracy: %.3f' % float(eval_results['top1_accuracy'] * 100))
            LOGGER.log('Top-5 Accuracy: %.3f' % float(eval_results['top5_accuracy'] * 100))

        except KeyboardInterrupt:
            print("Keyboard interrupt")

        LOGGER.log('Ending Model Evaluation ...')
Example #25
0
    def _get_global_batch_size(worker_batch_size):
        """Return the batch size summed over all workers.

        Without Horovod there is a single worker, so ``worker_batch_size`` is
        returned as-is; with Horovod it is multiplied by ``hvd.size()``.
        """
        if not hvd_utils.is_using_hvd():
            return worker_batch_size

        return worker_batch_size * hvd.size()
Example #26
0
    def __call__(self, features, labels, mode, params):
        """Estimator ``model_fn`` for the segmentation model.

        Args:
            features: In PREDICT mode, the tensor fed straight to
                ``build_model``; in TRAIN/EVAL mode an
                ``(input_image, mask_image)`` pair.
            labels: Per-image labels; cast to float32 and compared against the
                per-image maximum of the predicted mask for the confusion
                matrix metrics.
            mode: A ``tf.estimator.ModeKeys`` value.
            params: Hyper-parameter dict. ``debug_verbosity`` is always
                required. TRAIN mode additionally requires the RMSProp,
                learning-rate, ``weight_decay`` and ``loss_fn_name`` entries
                validated below, and reads ``apply_manual_loss_scaling``.
                EVAL mode reads ``loss_fn_name``.

        Returns:
            A ``tf.estimator.EstimatorSpec`` for ``mode``.

        Raises:
            RuntimeError: If a required entry of ``params`` is missing.
            ValueError: If ``params["loss_fn_name"]`` names an unknown loss.
            NotImplementedError: If ``mode`` is not TRAIN, EVAL or PREDICT.
        """

        if "debug_verbosity" not in params.keys():
            raise RuntimeError("Parameter `debug_verbosity` is missing...")

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Validated in this order; the error names the missing key.
            for param_name in ("rmsprop_decay",
                               "rmsprop_momentum",
                               "learning_rate",
                               "learning_rate_decay_steps",
                               "learning_rate_decay_factor",
                               "weight_decay",
                               "loss_fn_name"):
                if param_name not in params.keys():
                    raise RuntimeError(
                        "Parameter `%s` is missing..." % param_name)

        if mode == tf.estimator.ModeKeys.PREDICT:
            y_pred, y_pred_logits = self.build_model(
                features,
                training=False,
                reuse=False,
                debug_verbosity=params["debug_verbosity"])

            # NOTE(review): the first output of build_model (`y_pred`, not
            # `y_pred_logits`) is exported under the 'logits' key.
            predictions = {'logits': y_pred}
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=predictions)

        input_image, mask_image = features

        # Thresholds shared by the output samples, the IoU metrics and the
        # confusion matrix.
        thresholds = [0.05, 0.125, 0.25, 0.5, 0.75, 0.85, 0.95, 0.99]
        # Index of the 0.5 threshold, used for the confusion-matrix summaries.
        ths_05_idx = thresholds.index(0.5)

        with tf.device("/gpu:0"):

            tf.identity(input_image, name="input_image_ref")
            tf.identity(mask_image, name="mask_image_ref")
            tf.identity(labels, name="labels_ref")

            y_pred, y_pred_logits = self.build_model(
                input_image,
                training=mode == tf.estimator.ModeKeys.TRAIN,
                reuse=False,
                debug_verbosity=params["debug_verbosity"])

            all_trainable_vars = tf.reduce_sum(
                [tf.reduce_prod(v.shape) for v in tf.trainable_variables()])
            tf.identity(all_trainable_vars,
                        name='trainable_parameters_count_ref')

            if mode == tf.estimator.ModeKeys.EVAL:
                eval_metrics = dict()

            # ==================== Samples ==================== #

            # Maps inputs from [-1, 1] back to [0, 255] for JPEG encoding
            # (assumes the input is normalized to [-1, 1] — TODO confirm).
            image_uint8 = tf.cast((input_image + 1) * 127.5, dtype=tf.uint8)
            input_image_jpeg = tf.image.encode_jpeg(image_uint8[0],
                                                    format='grayscale',
                                                    quality=100)
            tf.identity(input_image_jpeg, name="input_image_jpeg_ref")

            # `None` additionally emits the non-binarized output sample.
            for threshold in [None] + thresholds:
                binarize_img, binarize_img_jpeg = image_processing.binarize_output(
                    y_pred[0], threshold=threshold)

                tf.identity(binarize_img_jpeg,
                            name="output_sample_ths_%s_ref" % threshold)
                tf.summary.image('output_sample_ths_%s' % threshold,
                                 binarize_img, 10)

            # ==============+ Evaluation Metrics ==================== #

            with tf.name_scope("IoU_Metrics"):

                for threshold in thresholds:

                    iou_score = metrics.iou_score(y_pred=y_pred,
                                                  y_true=mask_image,
                                                  threshold=threshold)

                    tf.identity(iou_score,
                                name='iou_score_ths_%s_ref' % threshold)
                    tf.summary.scalar('iou_score_ths_%s' % threshold,
                                      iou_score)

                    if mode == tf.estimator.ModeKeys.EVAL:
                        eval_metrics["IoU_THS_%s" %
                                     threshold] = tf.metrics.mean(iou_score)

            labels = tf.cast(labels, tf.float32)
            labels_preds = tf.reduce_max(y_pred, axis=(1, 2, 3))

            # A Python `assert` cannot inspect a graph tensor, so the
            # "clipping is (almost) a no-op" check runs as a graph op and is
            # wired in through a control dependency.
            clipped_labels_preds = tf.clip_by_value(labels_preds, 0, 1)
            clipping_check = tf.debugging.assert_less(
                tf.reduce_max(tf.abs(labels_preds - clipped_labels_preds)),
                tf.constant(0.00001, dtype=labels_preds.dtype),
                message="Clipping labels_preds introduces non-trivial loss.")
            with tf.control_dependencies([clipping_check]):
                labels_preds = tf.identity(clipped_labels_preds)

            with tf.variable_scope("Confusion_Matrix") as scope:

                tp, update_tp = tf.metrics.true_positives_at_thresholds(
                    labels=labels,
                    predictions=labels_preds,
                    thresholds=thresholds,
                )

                tn, update_tn = tf.metrics.true_negatives_at_thresholds(
                    labels=labels,
                    predictions=labels_preds,
                    thresholds=thresholds,
                )

                fp, update_fp = tf.metrics.false_positives_at_thresholds(
                    labels=labels,
                    predictions=labels_preds,
                    thresholds=thresholds,
                )

                fn, update_fn = tf.metrics.false_negatives_at_thresholds(
                    labels=labels,
                    predictions=labels_preds,
                    thresholds=thresholds,
                )

                if mode == tf.estimator.ModeKeys.TRAIN:
                    # In training the streaming counters are re-initialized
                    # before being updated, so each value reflects the
                    # current batch rather than an accumulation.
                    local_vars = tf.get_collection(
                        tf.GraphKeys.LOCAL_VARIABLES, scope=scope.name)
                    confusion_matrix_reset_op = tf.initializers.variables(
                        local_vars, name='reset_op')

                    with tf.control_dependencies([confusion_matrix_reset_op]):
                        with tf.control_dependencies(
                            [update_tp, update_tn, update_fp, update_fn]):
                            tp = tf.identity(tp)
                            tn = tf.identity(tn)
                            fp = tf.identity(fp)
                            fn = tf.identity(fn)

                else:
                    eval_metrics["Confusion_Matrix_TP"] = tp, update_tp
                    eval_metrics["Confusion_Matrix_TN"] = tn, update_tn
                    eval_metrics["Confusion_Matrix_FP"] = fp, update_fp
                    eval_metrics["Confusion_Matrix_FN"] = fn, update_fn

                tf.identity(tp, name='true_positives_ref'
                            )  # Confusion_Matrix/true_positives_ref:0
                tf.identity(tn, name='true_negatives_ref'
                            )  # Confusion_Matrix/true_negatives_ref:0
                tf.identity(fp, name='false_positives_ref'
                            )  # Confusion_Matrix/false_positives_ref:0
                tf.identity(fn, name='false_negatives_ref'
                            )  # Confusion_Matrix/false_negatives_ref:0

                tf.summary.scalar('true_positives', tp[ths_05_idx])  # For Ths = 0.5
                tf.summary.scalar('true_negatives', tn[ths_05_idx])  # For Ths = 0.5
                tf.summary.scalar('false_positives', fp[ths_05_idx])  # For Ths = 0.5
                tf.summary.scalar('false_negatives', fn[ths_05_idx])  # For Ths = 0.5

            binarized_mask, binarized_mask_jpeg = image_processing.binarize_output(
                mask_image[0], threshold=0.5)
            tf.identity(binarized_mask_jpeg, name="mask_sample_ref")
            tf.summary.image('sample_mask', binarized_mask, 10)

            ##########################

            mask_max_val = tf.reduce_max(mask_image)
            tf.identity(mask_max_val, name='mask_max_val_ref')

            mask_min_val = tf.reduce_min(mask_image)
            tf.identity(mask_min_val, name='mask_min_val_ref')

            mask_mean_val = tf.reduce_mean(mask_image)
            tf.identity(mask_mean_val, name='mask_mean_val_ref')

            mask_std_val = tf.math.reduce_std(mask_image)
            tf.identity(mask_std_val, name='mask_std_val_ref')

            ##########################

            output_max_val = tf.reduce_max(y_pred)
            tf.identity(output_max_val, name='output_max_val_ref')

            output_min_val = tf.reduce_min(y_pred)
            tf.identity(output_min_val, name='output_min_val_ref')

            output_mean_val = tf.reduce_mean(y_pred)
            tf.identity(output_mean_val, name='output_mean_val_ref')

            output_std_val = tf.math.reduce_std(y_pred)
            tf.identity(output_std_val, name='output_std_val_ref')

            with tf.variable_scope("losses"):

                # ==============+ Reconstruction Loss ==================== #

                if params["loss_fn_name"] == "x-entropy":
                    reconstruction_loss = losses.reconstruction_x_entropy(
                        y_pred=y_pred, y_true=mask_image)

                elif params["loss_fn_name"] == "l2_loss":
                    reconstruction_loss = losses.reconstruction_l2loss(
                        y_pred=y_pred, y_true=mask_image)

                elif params["loss_fn_name"] == "dice_sorensen":
                    reconstruction_loss = 1 - losses.dice_coe(
                        y_pred=y_pred, y_true=mask_image, loss_type='sorensen')

                elif params["loss_fn_name"] == "dice_jaccard":
                    reconstruction_loss = 1 - losses.dice_coe(
                        y_pred=y_pred, y_true=mask_image, loss_type='jaccard')

                elif params["loss_fn_name"] == "adaptive_loss":
                    reconstruction_loss = losses.adaptive_loss(
                        y_pred=y_pred,
                        y_pred_logits=y_pred_logits,
                        y_true=mask_image,
                        switch_at_threshold=0.3,
                        loss_type='sorensen')

                else:
                    raise ValueError("Unknown loss function received: %s" %
                                     params["loss_fn_name"])

                tf.identity(reconstruction_loss,
                            name='reconstruction_loss_ref')
                tf.summary.scalar('reconstruction_loss', reconstruction_loss)

                if mode == tf.estimator.ModeKeys.TRAIN:

                    # ============== Regularization Loss ==================== #

                    l2_loss = losses.regularization_l2loss(
                        weight_decay=params["weight_decay"])

                    tf.identity(l2_loss, name='l2_loss_ref')
                    tf.summary.scalar('l2_loss', l2_loss)

                    total_loss = tf.add(reconstruction_loss,
                                        l2_loss,
                                        name="total_loss")

                else:
                    # The weight-decay term is only added during training.
                    total_loss = reconstruction_loss

                tf.identity(total_loss, name='total_loss_ref')
                tf.summary.scalar('total_loss', total_loss)

            if mode == tf.estimator.ModeKeys.TRAIN:

                with tf.variable_scope("optimizers"):

                    # Update Global Step
                    global_step = tf.train.get_or_create_global_step()
                    tf.identity(global_step, name="global_step_ref")

                    learning_rate = tf.train.exponential_decay(
                        learning_rate=params["learning_rate"],
                        decay_steps=params["learning_rate_decay_steps"],
                        decay_rate=params["learning_rate_decay_factor"],
                        global_step=global_step,
                        staircase=True)

                    tf.identity(learning_rate, name="learning_rate_ref")
                    tf.summary.scalar('learning_rate_ref', learning_rate)

                    opt = tf.train.RMSPropOptimizer(
                        learning_rate=learning_rate,
                        use_locking=False,
                        centered=True,
                        decay=params["rmsprop_decay"],
                        momentum=params["rmsprop_momentum"],
                    )

                    if hvd_utils.is_using_hvd():
                        # Apply gradient compression using GRACE.
                        from grace_dl.tensorflow.communicator.allgather import Allgather
                        from grace_dl.tensorflow.compressor.topk import TopKCompressor
                        from grace_dl.tensorflow.memory.residual import ResidualMemory

                        world_size = hvd.size()
                        grc = Allgather(TopKCompressor(0.3), ResidualMemory(),
                                        world_size)
                        opt = hvd.DistributedOptimizer(opt,
                                                       grace=grc,
                                                       device_dense='/gpu:0')

                    if params["apply_manual_loss_scaling"]:

                        loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
                            init_loss_scale=2**32,  # 4,294,967,296
                            incr_every_n_steps=1000)
                        opt = tf.contrib.mixed_precision.LossScaleOptimizer(
                            opt, loss_scale_manager)

                    # GATE_OP is requested for deterministic gradient
                    # computation; GATE_NONE would allow more parallelism.
                    deterministic = True
                    gate_gradients = (tf.train.Optimizer.GATE_OP
                                      if deterministic else
                                      tf.train.Optimizer.GATE_NONE)

                    backprop_op = opt.minimize(total_loss,
                                               gate_gradients=gate_gradients,
                                               global_step=global_step)

                    # Grouping with UPDATE_OPS runs e.g. batch-norm moving
                    # statistics updates alongside each training step.
                    train_op = tf.group(
                        backprop_op,
                        tf.get_collection(tf.GraphKeys.UPDATE_OPS))

                    return tf.estimator.EstimatorSpec(
                        mode,
                        loss=total_loss,
                        train_op=train_op,
                    )

            elif mode == tf.estimator.ModeKeys.EVAL:

                return tf.estimator.EstimatorSpec(
                    mode,
                    loss=total_loss,
                    eval_metric_ops=eval_metrics,
                    predictions={"output": y_pred})

            else:
                raise NotImplementedError('Unknown mode {}'.format(mode))
Example #27
0
    def train(self,
              iter_unit,
              num_iter,
              run_iter,
              batch_size,
              warmup_steps=50,
              weight_decay=1e-4,
              lr_init=0.1,
              lr_warmup_epochs=5,
              momentum=0.9,
              log_every_n_steps=1,
              loss_scale=256,
              label_smoothing=0.0,
              mixup=0.0,
              use_cosine_lr=False,
              use_static_loss_scaling=False,
              is_benchmark=False,
              quantize=False,
              symmetric=False,
              quant_delay=0,
              finetune_checkpoint=None,
              use_final_conv=False,
              use_qdq=False):
        """Train the image classifier, resuming from an existing global step.

        Args:
            iter_unit: Either "epoch" or "batch"; the unit ``num_iter`` and
                ``run_iter`` are expressed in.
            num_iter: Total training length, in ``iter_unit``.
            run_iter: Amount of training to perform in this call, in
                ``iter_unit``. -1 means "up to the total step count". The
                value is clipped so that the global step never exceeds the
                total step count.
            batch_size: Per-GPU batch size. The global batch size is this
                times the number of Horovod processes.
            warmup_steps: Forwarded to ``hooks.BenchmarkLoggingHook``
                (benchmark mode only).
            use_static_loss_scaling: Under AMP/float16, selects static loss
                scaling (disables TF's automatic loss scaling through
                ``TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING``). Forced to
                False for FP32 training.
            is_benchmark: Use a benchmark logging hook instead of the training
                logging hook, and allow ``data_dir`` to be omitted (synthetic
                data).
            finetune_checkpoint: When truthy, passed to the estimator as
                ``finetune_checkpoint``.
            weight_decay, lr_init, lr_warmup_epochs, momentum, loss_scale,
            label_smoothing, mixup, use_cosine_lr, quantize, symmetric,
            quant_delay, use_final_conv, use_qdq: Forwarded unchanged to the
                estimator through ``estimator_params``.
            log_every_n_steps: Not used by this method.

        Raises:
            ValueError: If ``iter_unit`` is unknown, or ``data_dir`` is missing
                outside benchmark mode.
        """

        if iter_unit not in ["epoch", "batch"]:
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for training!')

        if self.run_hparams.use_tf_amp or self.run_hparams.dtype == tf.float16:
            # Static loss scaling and TF's automatic loss scaling are mutually
            # exclusive.
            if use_static_loss_scaling:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "0"
            else:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"
        else:
            use_static_loss_scaling = False  # Make sure it hasn't been set to True on FP32 training

        using_hvd = hvd_utils.is_using_hvd()
        # hvd.rank() is only consulted when Horovod is in use; a single-process
        # run is always the primary process.
        is_primary = not using_hvd or hvd.rank() == 0

        num_gpus = 1 if not using_hvd else hvd.size()
        global_batch_size = batch_size * num_gpus

        if self.run_hparams.data_dir is not None:
            filenames, num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="train",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=global_batch_size,
            )

            steps_per_epoch = num_steps / num_epochs

        else:
            # Synthetic data: num_iter is used directly as the step count.
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = num_steps
            num_decay_steps = num_steps
            num_samples = num_steps * batch_size

        if run_iter == -1:
            run_iter = num_steps
        else:
            run_iter = steps_per_epoch * run_iter if iter_unit == "epoch" else run_iter

        # The DALI pipeline consumes both the TFRecord files (`filenames`) and
        # their index files, so it can only be used when `data_dir` is set.
        use_dali_input = (
            self.run_hparams.use_dali
            and self.run_hparams.data_idx_dir is not None
            and self.run_hparams.data_dir is not None
        )

        if use_dali_input:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir, mode="train")

        # A user-provided seed requests a deterministic input pipeline.
        deterministic = self.run_hparams.seed is not None

        training_hooks = []

        if is_primary:
            print('Starting Model Training...')
            print("Training Epochs", num_epochs)
            print("Total Steps", num_steps)
            print("Steps per Epoch", steps_per_epoch)
            print("Decay Steps", num_decay_steps)
            print("Weight Decay Factor", weight_decay)
            print("Init Learning Rate", lr_init)
            print("Momentum", momentum)
            print("Num GPUs", num_gpus)
            print("Per-GPU Batch Size", batch_size)

            if is_benchmark:
                self.training_logging_hook = hooks.BenchmarkLoggingHook(
                    global_batch_size=global_batch_size,
                    warmup_steps=warmup_steps)
            else:
                self.training_logging_hook = hooks.TrainingLoggingHook(
                    global_batch_size=global_batch_size,
                    num_steps=num_steps,
                    num_samples=num_samples,
                    num_epochs=num_epochs,
                    steps_per_epoch=steps_per_epoch)
            training_hooks.append(self.training_logging_hook)

        if using_hvd:
            # Start every worker from rank 0's initial variables.
            bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
            training_hooks.append(bcast_hook)

        training_hooks.append(hooks.PrefillStagingAreasHook())
        training_hooks.append(hooks.TrainingPartitionHook())

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'num_gpus': num_gpus,
            'momentum': momentum,
            'lr_init': lr_init,
            'lr_warmup_epochs': lr_warmup_epochs,
            'weight_decay': weight_decay,
            'loss_scale': loss_scale,
            'apply_loss_scaling': use_static_loss_scaling,
            'label_smoothing': label_smoothing,
            'mixup': mixup,
            'num_decay_steps': num_decay_steps,
            'use_cosine_lr': use_cosine_lr,
            'use_final_conv': use_final_conv,
            'quantize': quantize,
            'use_qdq': use_qdq,
            'symmetric': symmetric,
            'quant_delay': quant_delay
        }

        if finetune_checkpoint:
            estimator_params['finetune_checkpoint'] = finetune_checkpoint

        image_classifier = self._get_estimator(
            mode='train',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction,
            gpu_id=self.run_hparams.gpu_id)

        def training_data_fn():
            """Build the training input pipeline (DALI, TFRecords or synthetic)."""

            if use_dali_input:
                if is_primary:
                    print("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic)

            elif self.run_hparams.data_dir is not None:

                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic)

            else:
                if is_primary:
                    print("Using Synthetic Data ...")
                return data_utils.get_synth_input_fn(
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    num_channels=self.run_hparams.n_channels,
                    data_format=self.run_hparams.input_format,
                    num_classes=self.run_hparams.n_classes,
                    dtype=self.run_hparams.dtype,
                )

        # A fresh model has no checkpointed `global_step` yet.
        try:
            current_step = image_classifier.get_variable_value("global_step")
        except ValueError:
            current_step = 0

        # Never train past the total step count when resuming.
        run_iter = max(0, min(run_iter, num_steps - current_step))
        print("Current step:", current_step)

        if run_iter > 0:
            try:
                image_classifier.train(
                    input_fn=training_data_fn,
                    steps=run_iter,
                    hooks=training_hooks,
                )
            except KeyboardInterrupt:
                print("Keyboard interrupt")

        if is_primary:
            if run_iter > 0:
                print('Ending Model Training ...')
                train_throughput = self.training_logging_hook.mean_throughput.value(
                )
                train_time = self.training_logging_hook.train_time
                dllogger.log(data={'train_throughput': train_throughput},
                             step=tuple())
                dllogger.log(data={'Total Training time': train_time},
                             step=tuple())
            else:
                print(
                    'Model already trained required number of steps. Skipped')
    def dataset_fn(self,
                   batch_size,
                   training,
                   input_shape,
                   mask_shape,
                   num_threads,
                   use_gpu_prefetch,
                   normalize_data_method,
                   only_defective_images,
                   augment_data,
                   seed=None):
        """Build the DAGM2007 ``tf.data`` input pipeline.

        Reads the CSV index returned by ``self._get_data_dirs(training=...)``
        (header line skipped; columns: input image name, mask image name,
        integer label). It then decodes and resizes the PNG images and masks,
        caches the decoded samples, shuffles/repeats, optionally augments, and
        batches the result.

        Args:
            batch_size: Number of samples per batch. Incomplete final batches
                are dropped.
            training: If True, shard the CSV rows across Horovod workers (when
                Horovod is used) and shuffle+repeat; otherwise only repeat.
            input_shape: Target image shape. ``[:2]`` is used as the resize
                size and ``[-1]`` as the number of decoded channels.
            mask_shape: Same convention for mask images. It is also the shape
                of the all-zero mask used when a row has an empty mask name.
            num_threads: ``num_parallel_calls`` for the augment+batch map.
            use_gpu_prefetch: If True, prefetch batches to ``/gpu:0``.
            normalize_data_method: None, "zero_centered" (pixels mapped to
                [-1, 1]) or "zero_one" (pixels mapped to [0, 1]). Masks are
                always normalized with "zero_one".
            only_defective_images: If True, keep only CSV rows whose last
                character is not "0" (presumably rows with a non-zero label).
            augment_data: If True, apply the same random horizontal flip and
                random 90-degree rotation to each image and its mask.
            seed: Optional op-level seed for shuffling and augmentation.

        Returns:
            A ``tf.data.Dataset`` yielding ``((input_image, mask_image), label)``
            batches.
        """

        super(DAGM2007_Dataset, self).dataset_fn(
            batch_size=batch_size,
            training=training,
            input_shape=input_shape,
            mask_shape=mask_shape,
            num_threads=num_threads,
            use_gpu_prefetch=use_gpu_prefetch,
            normalize_data_method=
            normalize_data_method,  # [None, "zero_centered", "zero_one"]
            only_defective_images=only_defective_images,
            augment_data=augment_data,
            seed=seed)

        shuffle_buffer_size = 10000

        image_dir, csv_file = self._get_data_dirs(training=training)

        mask_image_dir = os.path.join(image_dir, "Label")

        dataset = tf.data.TextLineDataset(csv_file)

        dataset = dataset.skip(1)  # Skip CSV Header

        if only_defective_images:
            # Keep rows whose last character (the label column) is not "0".
            dataset = dataset.filter(
                lambda line: tf.not_equal(tf.strings.substr(line, -1, 1), "0"))

        if hvd_utils.is_using_hvd() and training:
            # Each worker reads a disjoint subset of the CSV rows.
            dataset = dataset.shard(hvd.size(), hvd.rank())

        def _load_dagm_data(line):
            """Parse one CSV row into a single-element (image, mask, label) dataset."""

            input_image_name, image_mask_name, label = tf.decode_csv(
                line, record_defaults=[[""], [""], [0]], field_delim=',')

            def decode_image(filepath, resize_shape, normalize_data_method):
                """Read a PNG, resize it to ``resize_shape[:2]``, cast and normalize it."""
                image_content = tf.read_file(filepath)

                image = tf.image.decode_png(contents=image_content,
                                            channels=resize_shape[-1],
                                            dtype=tf.uint8)

                image = tf.image.resize_images(
                    image,
                    size=resize_shape[:2],
                    method=tf.image.ResizeMethod.
                    BILINEAR,  # [BILINEAR, NEAREST_NEIGHBOR, BICUBIC, AREA]
                    align_corners=False,
                    preserve_aspect_ratio=True)

                # Pin the static shape so downstream batching sees a fixed shape.
                image.set_shape(resize_shape)
                image = tf.cast(image, tf.float32)

                if normalize_data_method == "zero_centered":
                    image = tf.divide(image, 127.5) - 1

                elif normalize_data_method == "zero_one":
                    image = tf.divide(image, 255.0)

                return image

            input_image = decode_image(
                filepath=tf.strings.join([image_dir, input_image_name],
                                         separator='/'),
                resize_shape=input_shape,
                normalize_data_method=normalize_data_method,
            )

            # Rows without a mask file get an all-zero mask.
            mask_image = tf.cond(
                tf.equal(image_mask_name, ""),
                true_fn=lambda: tf.zeros(mask_shape, dtype=tf.float32),
                false_fn=lambda: decode_image(
                    filepath=tf.strings.join([mask_image_dir, image_mask_name],
                                             separator='/'),
                    resize_shape=mask_shape,
                    normalize_data_method="zero_one",
                ),
            )

            label = tf.cast(label, tf.int32)

            return tf.data.Dataset.from_tensor_slices(
                ([input_image], [mask_image], [label]))

        dataset = dataset.apply(
            tf.data.experimental.parallel_interleave(
                _load_dagm_data,
                cycle_length=batch_size * 8,
                block_length=4,
                buffer_output_elements=batch_size * 8))

        # Cache decoded samples so files are read and decoded only once.
        dataset = dataset.cache()

        if training:
            dataset = dataset.apply(
                tf.data.experimental.shuffle_and_repeat(
                    buffer_size=shuffle_buffer_size, seed=seed))

        else:
            dataset = dataset.repeat()

        def _augment_data(input_image, mask_image, label):
            """Optionally apply a joint random flip/rotation; pack (image, mask) as features."""

            if augment_data:

                if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
                    print("Using data augmentation ...")

                # NOTE(review): the flip and the rotation ops below are given the
                # same op-level seed; confirm correlated draws are acceptable.
                horizontal_flip = tf.random_uniform(shape=(), seed=seed) > 0.5
                input_image = tf.cond(
                    horizontal_flip,
                    lambda: tf.image.flip_left_right(input_image),
                    lambda: input_image)
                mask_image = tf.cond(
                    horizontal_flip,
                    lambda: tf.image.flip_left_right(mask_image),
                    lambda: mask_image)

                # NOTE(review): maxval is exclusive for integer random_uniform, so
                # k is drawn from {0, 1, 2}; a 270-degree rotation is never
                # sampled. Confirm this is intended.
                n_rots = tf.random_uniform(shape=(),
                                           dtype=tf.int32,
                                           minval=0,
                                           maxval=3,
                                           seed=seed)
                input_image = tf.image.rot90(input_image, k=n_rots)
                mask_image = tf.image.rot90(mask_image, k=n_rots)

            return (input_image, mask_image), label

        dataset = dataset.apply(
            tf.data.experimental.map_and_batch(
                map_func=_augment_data,
                num_parallel_calls=num_threads,
                batch_size=batch_size,
                drop_remainder=True,
            ))

        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

        if use_gpu_prefetch:
            # `apply` returns a new dataset; the result must be kept, otherwise
            # the GPU prefetch is silently discarded. prefetch_to_device must be
            # the final transformation of the pipeline.
            dataset = dataset.apply(
                tf.data.experimental.prefetch_to_device(device="/gpu:0",
                                                        buffer_size=4))

        return dataset
Example #29
0
    def evaluate(
        self,
        iter_unit,
        num_iter,
        batch_size,
        warmup_steps=50,
        log_every_n_steps=1,
        is_benchmark=False,
        export_dir=None,
        quantize=False,
        symmetric=False,
        use_qdq=False,
        use_final_conv=False,
    ):
        """Run a validation pass of the image classifier and log its metrics.

        Top-1/top-5 accuracy, throughput and latency statistics are written
        through ``dllogger``. A SavedModel is optionally exported afterwards.

        Args:
            iter_unit: "epoch" or "batch"; forwarded to
                ``runner_utils.parse_tfrecords_dataset`` to derive the steps.
            num_iter: Number of ``iter_unit`` units to evaluate. It is used
                directly as the step count when no ``data_dir`` is configured.
            batch_size: Global batch size.
            warmup_steps: Passed to ``hooks.BenchmarkLoggingHook`` (presumably
                steps excluded from the timing statistics; confirm in hooks).
            log_every_n_steps: Accepted for interface compatibility; not used
                in this method.
            is_benchmark: If True, a missing ``data_dir`` is allowed and
                synthetic data is used.
            export_dir: If not None, export a SavedModel there after evaluating.
            quantize, symmetric, use_qdq, use_final_conv: Forwarded to the
                estimator as ``run_params``.

        Raises:
            ValueError: if ``iter_unit`` is unknown, or if ``data_dir`` is
                missing outside benchmark mode.
            RuntimeError: when running under Horovod on a rank other than 0.
        """
        if iter_unit not in ("epoch", "batch"):
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for evaluation!')

        if hvd_utils.is_using_hvd() and hvd.rank() != 0:
            raise RuntimeError('Multi-GPU inference is not supported')

        image_classifier = self._get_estimator(
            mode='validation',
            run_params=dict(quantize=quantize,
                            symmetric=symmetric,
                            use_qdq=use_qdq,
                            use_final_conv=use_final_conv),
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction,
            gpu_id=self.run_hparams.gpu_id)

        hp = self.run_hparams

        if hp.data_dir is None:
            # No TFRecords: one synthetic "epoch" of exactly num_iter steps.
            num_epochs, num_decay_steps, num_steps = 1, -1, num_iter
        else:
            (filenames, _num_samples, num_steps, num_epochs,
             num_decay_steps) = runner_utils.parse_tfrecords_dataset(
                 data_dir=hp.data_dir,
                 mode="validation",
                 iter_unit=iter_unit,
                 num_iter=num_iter,
                 global_batch_size=batch_size,
             )

        # Decided once and reused by the input fn below, so the DALI branch is
        # taken exactly when idx_filenames has been bound.
        use_dali_input = hp.use_dali and hp.data_idx_dir is not None
        if use_dali_input:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=hp.data_idx_dir, mode="validation")

        eval_hooks = []

        if hvd.rank() == 0:
            self.eval_logging_hook = hooks.BenchmarkLoggingHook(
                global_batch_size=batch_size, warmup_steps=warmup_steps)
            eval_hooks.append(self.eval_logging_hook)

            print('Starting Model Evaluation...')
            for caption, value in (("Evaluation Epochs", num_epochs),
                                   ("Evaluation Steps", num_steps),
                                   ("Decay Steps", num_decay_steps),
                                   ("Global Batch Size", batch_size)):
                print(caption, value)

        def evaluation_data_fn():
            # Pick DALI, TFRecords or synthetic input, in that priority order.
            if use_dali_input:
                if hvd.rank() == 0:
                    print("Using DALI input... ")
                build_input_fn = data_utils.get_dali_input_fn
                extra_kwargs = {'idx_filenames': idx_filenames}
            elif hp.data_dir is not None:
                build_input_fn = data_utils.get_tfrecords_input_fn
                extra_kwargs = {}
            else:
                print("Using Synthetic Data ...\n")
                return data_utils.get_synth_input_fn(
                    batch_size=batch_size,
                    height=hp.height,
                    width=hp.width,
                    num_channels=hp.n_channels,
                    data_format=hp.input_format,
                    num_classes=hp.n_classes,
                    dtype=hp.dtype,
                )

            return build_input_fn(
                filenames=filenames,
                batch_size=batch_size,
                height=hp.height,
                width=hp.width,
                training=False,
                distort_color=hp.distort_colors,
                num_threads=hp.num_preprocessing_threads,
                deterministic=hp.seed is not None,
                **extra_kwargs)

        try:
            eval_results = image_classifier.evaluate(
                input_fn=evaluation_data_fn,
                steps=num_steps,
                hooks=eval_hooks,
            )

            logging_hook = self.eval_logging_hook
            eval_throughput = logging_hook.mean_throughput.value()
            latencies_ms = np.array(logging_hook.latencies) * 1000
            p90, p95, p99 = np.quantile(latencies_ms, q=[0.9, 0.95, 0.99])
            latency_avg = np.mean(latencies_ms)

            metrics = {
                'top1_accuracy': float(eval_results['top1_accuracy']),
                'top5_accuracy': float(eval_results['top5_accuracy']),
                'eval_throughput': eval_throughput,
                'eval_latency_avg': latency_avg,
                'eval_latency_p90': p90,
                'eval_latency_p95': p95,
                'eval_latency_p99': p99,
            }
            dllogger.log(data=metrics, step=tuple())

            if export_dir is not None:
                dllogger.log(data={'export_dir': export_dir}, step=tuple())
                serving_input_fn = data_utils.get_serving_input_receiver_fn(
                    batch_size=None,
                    height=hp.height,
                    width=hp.width,
                    num_channels=hp.n_channels,
                    data_format=hp.input_format,
                    dtype=hp.dtype)
                image_classifier.export_savedmodel(export_dir, serving_input_fn)

        except KeyboardInterrupt:
            print("Keyboard interrupt")

        print('Model evaluation finished')
Example #30
0
    def __init__(
            self,

            # Model Params
            input_format,  # NCHW or NHWC
            compute_format,  # NCHW or NHWC
            n_channels,
            activation_fn,
            weight_init_method,
            model_variant,
            input_shape,
            mask_shape,
            input_normalization_method,

            # Training HParams
            augment_data,
            loss_fn_name,

            #  Runtime HParams
            amp,
            xla,

            # Directory Params
            model_dir=None,
            log_dir=None,
            sample_dir=None,
            data_dir=None,
            dataset_name=None,
            dataset_hparams=None,

            # Debug Params
            log_every_n_steps=1,
            debug_verbosity=0,
            seed=None):
        """Set up the runtime, hyper-parameters, UNet model and dataset.

        In order, this constructor:
        - validates the formats, the channel count and ``data_dir``;
        - initializes Horovod when it is in use;
        - exports TensorFlow/Horovod performance environment variables, and
          enables the AMP graph rewrite when ``amp`` is set;
        - creates ``sample_dir`` on the chief process;
        - seeds TensorFlow, NumPy and ``random`` when ``seed`` is given;
        - instantiates ``UNet_v1`` and the dataset class registered under
          ``dataset_name`` in ``known_datasets``.

        Args:
            input_format, compute_format: "NHWC" or "NCHW".
            n_channels: 1 (grayscale) or 3 (color).
            seed: Optional base seed. Under Horovod the effective seed is
                ``2 * (seed + rank)``; otherwise it is ``2 * seed``.
            model_dir, log_dir, sample_dir: Kept only on rank 0 under
                Horovod; other ranks store None.
            dataset_hparams: Extra keyword arguments for the dataset class.
                Defaults to an empty dict.

        Raises:
            ValueError: on an unknown ``compute_format``/``input_format``, an
                unsupported ``n_channels`` or a non-existent ``data_dir``.
            RuntimeError: if ``dataset_name`` is not in ``known_datasets``.
        """

        if dataset_hparams is None:
            dataset_hparams = dict()

        if compute_format not in ["NHWC", 'NCHW']:
            raise ValueError(
                "Unknown `compute_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % compute_format)

        if input_format not in ["NHWC", 'NCHW']:
            raise ValueError(
                "Unknown `input_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % input_format)

        if n_channels not in [1, 3]:
            raise ValueError(
                "Unsupported number of channels: %d (allowed: 1 (grayscale) and 3 (color))"
                % n_channels)

        if data_dir is not None and not os.path.exists(data_dir):
            raise ValueError("The `data_dir` received does not exists: %s" %
                             data_dir)

        if hvd_utils.is_using_hvd():
            hvd.init()

            if hvd.rank() == 0:
                print("Horovod successfully initialized ...")

            tf_seed = 2 * (seed + hvd.rank()) if seed is not None else None

        else:
            tf_seed = 2 * seed if seed is not None else None

        # ============================================
        # Optimisation Flags - Do not remove
        # ============================================

        os.environ['CUDA_CACHE_DISABLE'] = '0'

        os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

        os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
        # hvd.size() is only valid after hvd.init(), which is called above
        # only when Horovod is in use; otherwise the world size is 1.
        world_size = hvd.size() if hvd_utils.is_using_hvd() else 1
        os.environ['TF_GPU_THREAD_COUNT'] = str(world_size)
        print("WORLD_SIZE", world_size)

        os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

        os.environ['TF_ADJUST_HUE_FUSED'] = '1'
        os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

        os.environ['TF_SYNC_ON_FINISH'] = '0'
        os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'

        # =================================================

        self.xla = xla

        if amp:

            if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
                print("TF AMP is activated - Experimental Feature")

            os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

        # =================================================

        model_hparams = tf.contrib.training.HParams(
            # Model Params
            input_format=input_format,
            compute_format=compute_format,
            input_shape=input_shape,
            mask_shape=mask_shape,
            n_channels=n_channels,
            activation_fn=activation_fn,
            weight_init_method=weight_init_method,
            model_variant=model_variant,
            input_normalization_method=input_normalization_method,

            # Training HParams
            augment_data=augment_data,
            loss_fn_name=loss_fn_name,

            # Runtime Params
            amp=amp,

            # Debug Params
            log_every_n_steps=log_every_n_steps,
            debug_verbosity=debug_verbosity,
            seed=tf_seed)

        # Output directories are kept only on the chief process so that other
        # workers do not write checkpoints/logs/samples.
        run_config_additional = tf.contrib.training.HParams(
            dataset_hparams=dataset_hparams,
            model_dir=model_dir
            if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None,
            log_dir=log_dir
            if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None,
            sample_dir=sample_dir
            if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None,
            data_dir=data_dir,
            num_preprocessing_threads=32,
        )

        # sample_dir defaults to None; os.makedirs(None) would raise TypeError.
        if sample_dir is not None and (not hvd_utils.is_using_hvd()
                                       or hvd.rank() == 0):
            try:
                os.makedirs(sample_dir)
            except FileExistsError:
                pass

        self.run_hparams = Runner._build_hparams(model_hparams,
                                                 run_config_additional)

        if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
            print('Defining Model Estimator ...\n')

        self._model = UNet_v1(
            model_name="UNet_v1",
            input_format=self.run_hparams.input_format,
            compute_format=self.run_hparams.compute_format,
            n_output_channels=1,
            unet_variant=self.run_hparams.model_variant,
            weight_init_method=self.run_hparams.weight_init_method,
            activation_fn=self.run_hparams.activation_fn)

        if self.run_hparams.seed is not None:

            if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
                print("Deterministic Run - Seed: %d\n" % seed)

            tf.set_random_seed(self.run_hparams.seed)
            np.random.seed(self.run_hparams.seed)
            random.seed(self.run_hparams.seed)

        if dataset_name not in known_datasets.keys():
            raise RuntimeError(
                "The dataset `%s` is unknown, allowed values: %s ..." %
                (dataset_name, list(known_datasets.keys())))

        self.dataset = known_datasets[dataset_name](
            data_dir=data_dir, **self.run_hparams.dataset_hparams)

        self.num_gpus = 1 if not hvd_utils.is_using_hvd() else hvd.size()