Example 1
    def evaluate(self,
                 iter_unit,
                 num_iter,
                 batch_size,
                 warmup_steps=50,
                 is_benchmark=False,
                 save_eval_results_to_json=False):
        """Run model evaluation with a TF Estimator and log confusion-matrix metrics.

        Args:
            iter_unit: Either "epoch" or "batch"; the unit `num_iter` is expressed in.
            num_iter: Number of iterations (epochs or batches) to evaluate for.
            batch_size: Per-GPU batch size used for evaluation.
            warmup_steps: Steps skipped by the ProfilerHook before timing starts.
            is_benchmark: If True, allow running without a real `data_dir`
                (falls back to the synthetic-data path in `evaluation_data_fn`).
            save_eval_results_to_json: If True, dump IoU/TPR/TNR values at the
                0.75/0.85/0.95/0.99 thresholds to `<model_dir>/../results.json`.

        Raises:
            ValueError: If `iter_unit` is invalid, or `data_dir` is missing
                outside of benchmark mode.
        """

        if iter_unit not in ["epoch", "batch"]:
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for evaluation!')

        # if hvd_utils.is_using_hvd() and hvd.rank() != 0:
        #     raise RuntimeError('Multi-GPU inference is not supported')

        print('Defining Model Estimator ...\n')

        if self.run_hparams.data_dir is not None:
            filenames, num_samples, num_steps, num_epochs = self.dataset.get_dataset_runtime_specs(
                training=False,
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=batch_size)

            # NOTE(review): true division may yield a non-integer value here;
            # the training path uses int(...). Kept as-is to preserve behavior.
            steps_per_epoch = num_steps / num_epochs

        else:
            # Benchmark mode without data: treat the whole run as one epoch.
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = num_steps

        evaluation_hooks = [
            ProfilerHook(global_batch_size=batch_size,
                         log_every=self.run_hparams.log_every_n_steps,
                         warmup_steps=warmup_steps,
                         is_training=False,
                         sample_dir=self.run_hparams.sample_dir)
        ]

        print('Starting Model Evaluation ...\n')

        Logger.log(step=('PARAMETER'),
                   data={"Epochs": num_epochs},
                   verbosity=Logger.Verbosity.DEFAULT)
        Logger.log(step=('PARAMETER'),
                   data={"Total Steps": num_steps},
                   verbosity=Logger.Verbosity.DEFAULT)
        Logger.log(step=('PARAMETER'),
                   data={"Steps per Epoch": steps_per_epoch},
                   verbosity=Logger.Verbosity.DEFAULT)
        Logger.log(step=('PARAMETER'),
                   data={"GPU Batch Size": batch_size},
                   verbosity=Logger.Verbosity.DEFAULT)
        # Log key fixed ("to Processed" -> "to be Processed") for grammar and
        # consistency with the training path's parameter log.
        Logger.log(step=('PARAMETER'),
                   data={"Total Files to be Processed": num_steps * batch_size},
                   verbosity=Logger.Verbosity.DEFAULT)

        print()  # visual spacing

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'loss_fn_name': self.run_hparams.loss_fn_name,
            'debug_verbosity': self.run_hparams.debug_verbosity,
        }

        def evaluation_data_fn():
            # Prefer real data whenever a data_dir is configured; synthetic
            # data is used only in pure benchmark mode with no data_dir.
            if not is_benchmark or self.run_hparams.data_dir is not None:

                return self.dataset.dataset_fn(
                    batch_size=batch_size,
                    training=False,
                    input_shape=list(self.run_hparams.input_shape) +
                    [self.run_hparams.n_channels],
                    mask_shape=list(self.run_hparams.mask_shape) +
                    [self.run_hparams.n_channels],
                    num_threads=64,
                    use_gpu_prefetch=True,
                    normalize_data_method="zero_centered",
                    only_defective_images=False,
                    augment_data=False,
                    seed=self.run_hparams.seed)

            else:
                print("Using Synthetic Data ...")

                return self.dataset.synth_dataset_fn(
                    batch_size=batch_size,
                    training=False,
                    input_shape=list(self.run_hparams.input_shape) +
                    [self.run_hparams.n_channels],
                    mask_shape=list(self.run_hparams.mask_shape) +
                    [self.run_hparams.n_channels],
                    num_threads=64,
                    use_gpu_prefetch=True,
                    normalize_data_method="zero_centered",
                    only_defective_images=False,
                    augment_data=False,
                    seed=self.run_hparams.seed)

        model = self._get_estimator(mode='validation',
                                    run_params=estimator_params,
                                    xla=self.xla)

        try:
            eval_results = model.evaluate(
                input_fn=evaluation_data_fn,
                steps=num_steps,
                hooks=evaluation_hooks,
            )

            print('Ending Model Evaluation ...')

            print(
                '###################################\n\nEvaluation Results:\n')

            # Log every scalar metric, skipping bookkeeping keys and the raw
            # confusion-matrix arrays (those are logged separately below).
            # The exclusion variable was renamed to stop shadowing the metric
            # value `val` bound by the outer comprehension loop.
            excluded_keys = ("loss", "global_step", "Confusion_Matrix")
            data_to_log = {
                "{prefix}.{key}".format(prefix=Logger._stage, key=key):
                float(val)
                for key, val in sorted(eval_results.items(),
                                       key=operator.itemgetter(0))
                if not any(excluded in key for excluded in excluded_keys)
            }
            Logger.log(step=(),
                       data=data_to_log,
                       verbosity=Logger.Verbosity.DEFAULT)

            # Confusion-matrix components — presumably one entry per IoU
            # threshold (arrays, not scalars); TODO confirm against metric def.
            fns = eval_results["Confusion_Matrix_FN"]
            fps = eval_results["Confusion_Matrix_FP"]
            tns = eval_results["Confusion_Matrix_TN"]
            tps = eval_results["Confusion_Matrix_TP"]

            positives = np.add(tps, fns)
            negatives = np.add(tns, fps)

            tpr = np.divide(tps, positives)
            tnr = np.divide(tns, negatives)

            Logger.log(
                step=(num_steps, ),
                data={
                    "{prefix}.true_positives".format(prefix=Logger._stage):
                    str(tps)
                },
                verbosity=Logger.Verbosity.DEFAULT)

            Logger.log(
                step=(num_steps, ),
                data={
                    "{prefix}.true_negatives".format(prefix=Logger._stage):
                    str(tns)
                },
                verbosity=Logger.Verbosity.DEFAULT)

            Logger.log(
                step=(num_steps, ),
                data={
                    "{prefix}.false_positives".format(prefix=Logger._stage):
                    str(fps)
                },
                verbosity=Logger.Verbosity.DEFAULT)

            Logger.log(
                step=(num_steps, ),
                data={
                    "{prefix}.false_negatives".format(prefix=Logger._stage):
                    str(fns)
                },
                verbosity=Logger.Verbosity.DEFAULT)

            Logger.log(
                step=(num_steps, ),
                data={
                    "{prefix}.true_positive_rate".format(prefix=Logger._stage):
                    str(["%.3f" % x for x in tpr])
                },
                verbosity=Logger.Verbosity.DEFAULT)

            Logger.log(
                step=(num_steps, ),
                data={
                    "{prefix}.true_negative_rate".format(prefix=Logger._stage):
                    str(["%.3f" % x for x in tnr])
                },
                verbosity=Logger.Verbosity.DEFAULT)

            if save_eval_results_to_json:

                # Negative indices assume the last four entries of tpr/tnr
                # correspond to thresholds 0.75/0.85/0.95/0.99 — TODO confirm
                # against the metric construction.
                results_dict = {
                    'IoU': {
                        '0.75': str(eval_results["IoU_THS_0.75"]),
                        '0.85': str(eval_results["IoU_THS_0.85"]),
                        '0.95': str(eval_results["IoU_THS_0.95"]),
                        '0.99': str(eval_results["IoU_THS_0.99"]),
                    },
                    'TPR': {
                        '0.75': str(tpr[-4]),
                        '0.85': str(tpr[-3]),
                        '0.95': str(tpr[-2]),
                        '0.99': str(tpr[-1]),
                    },
                    'TNR': {
                        '0.75': str(tnr[-4]),
                        '0.85': str(tnr[-3]),
                        '0.95': str(tnr[-2]),
                        '0.99': str(tnr[-1]),
                    }
                }

                with open(
                        os.path.join(self.run_hparams.model_dir, "..",
                                     "results.json"), 'w') as f:
                    json.dump(results_dict, f)

        except KeyboardInterrupt:
            print("Keyboard interrupt")
    def evaluate(self,
                 iter_unit,
                 num_iter,
                 batch_size,
                 warmup_steps=50,
                 is_benchmark=False,
                 save_eval_results_to_json=False):
        """Run single-GPU model evaluation and log confusion-matrix metrics.

        Args:
            iter_unit: Either "epoch" or "batch"; the unit `num_iter` is expressed in.
            num_iter: Number of iterations (epochs or batches) to evaluate for.
            batch_size: Per-GPU batch size used for evaluation.
            warmup_steps: Steps skipped by the ProfilerHook before timing starts.
            is_benchmark: If True, allow running without a real `data_dir`
                (falls back to the synthetic-data path in `evaluation_data_fn`).
            save_eval_results_to_json: If True, dump IoU/TPR/TNR values at the
                0.75/0.85/0.95/0.99 thresholds to `<model_dir>/../results.json`.

        Raises:
            ValueError: If `iter_unit` is invalid, or `data_dir` is missing
                outside of benchmark mode.
            RuntimeError: If called on a non-zero Horovod rank (multi-GPU
                inference is not supported).
        """

        if iter_unit not in ["epoch", "batch"]:
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for evaluation!')

        if hvd_utils.is_using_hvd() and hvd.rank() != 0:
            raise RuntimeError('Multi-GPU inference is not supported')

        LOGGER.log('Defining Model Estimator ...\n')

        if self.run_hparams.data_dir is not None:
            filenames, num_samples, num_steps, num_epochs = self.dataset.get_dataset_runtime_specs(
                training=False,
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=batch_size)

            # NOTE(review): true division may yield a non-integer value here;
            # the training path uses int(...). Kept as-is to preserve behavior.
            steps_per_epoch = num_steps / num_epochs

        else:
            # Benchmark mode without data: treat the whole run as one epoch.
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = num_steps

        evaluation_hooks = [
            ProfilerHook(global_batch_size=batch_size,
                         log_every=self.run_hparams.log_every_n_steps,
                         warmup_steps=warmup_steps,
                         is_training=False,
                         sample_dir=self.run_hparams.sample_dir)
        ]

        LOGGER.log('Starting Model Evaluation ...\n')

        LOGGER.log("=> Epochs: %d" % num_epochs)
        LOGGER.log("=> Total Steps: %d" % num_steps)
        LOGGER.log("=> Steps per Epoch: %d" % steps_per_epoch)
        LOGGER.log("=> GPU Batch Size: %d" % batch_size)
        LOGGER.log("=> Total Files to Processed: %d\n" %
                   (num_steps * batch_size))

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'loss_fn_name': self.run_hparams.loss_fn_name,
            'debug_verbosity': self.run_hparams.debug_verbosity,
        }

        def evaluation_data_fn():
            # Prefer real data whenever a data_dir is configured; synthetic
            # data is used only in pure benchmark mode with no data_dir.
            if not is_benchmark or self.run_hparams.data_dir is not None:

                return self.dataset.dataset_fn(
                    batch_size=batch_size,
                    training=False,
                    input_shape=list(self.run_hparams.input_shape) +
                    [self.run_hparams.n_channels],
                    mask_shape=list(self.run_hparams.mask_shape) +
                    [self.run_hparams.n_channels],
                    num_threads=64,
                    use_gpu_prefetch=True,
                    normalize_data_method="zero_centered",
                    only_defective_images=False,
                    augment_data=False,
                    seed=self.run_hparams.seed)

            else:
                LOGGER.log("Using Synthetic Data ...")

                return self.dataset.synth_dataset_fn(
                    batch_size=batch_size,
                    training=False,
                    input_shape=list(self.run_hparams.input_shape) +
                    [self.run_hparams.n_channels],
                    mask_shape=list(self.run_hparams.mask_shape) +
                    [self.run_hparams.n_channels],
                    num_threads=64,
                    use_gpu_prefetch=True,
                    normalize_data_method="zero_centered",
                    only_defective_images=False,
                    augment_data=False,
                    seed=self.run_hparams.seed)

        model = self._get_estimator(mode='validation',
                                    run_params=estimator_params,
                                    use_xla=self.use_xla)

        try:
            eval_results = model.evaluate(
                input_fn=evaluation_data_fn,
                steps=num_steps,
                hooks=evaluation_hooks,
            )

            LOGGER.log('Ending Model Evaluation ...')

            LOGGER.log(
                '###################################\n\nEvaluation Results:\n')

            # Log every scalar metric, skipping bookkeeping keys and the raw
            # confusion-matrix arrays (those are logged separately below).
            # The exclusion variable was renamed to stop shadowing the metric
            # value `val` bound by the outer loop.
            for key, val in sorted(eval_results.items(),
                                   key=operator.itemgetter(0)):

                if any(excluded in key for excluded in
                       ["loss", "global_step", "Confusion_Matrix"]):
                    continue

                LOGGER.log('%s: %.3f' % (key, float(val)))

            fns = eval_results["Confusion_Matrix_FN"]
            fps = eval_results["Confusion_Matrix_FP"]
            tns = eval_results["Confusion_Matrix_TN"]
            tps = eval_results["Confusion_Matrix_TP"]

            positives = np.add(tps, fns)
            negatives = np.add(tns, fps)

            tpr = np.divide(tps, positives)
            tnr = np.divide(tns, negatives)

            LOGGER.log('TP', tps)
            LOGGER.log('FN', fns)
            LOGGER.log('TN', tns)
            # BUG FIX: previously logged `tps` (true positives) under the 'FP'
            # label; `fps` was computed above but never reported.
            LOGGER.log('FP', fps)
            LOGGER.log('TPR', tpr)
            LOGGER.log('TNR', tnr)

            if save_eval_results_to_json:

                # Negative indices assume the last four entries of tpr/tnr
                # correspond to thresholds 0.75/0.85/0.95/0.99 — TODO confirm
                # against the metric construction.
                results_dict = {
                    'IoU': {
                        '0.75': str(eval_results["IoU_THS_0.75"]),
                        '0.85': str(eval_results["IoU_THS_0.85"]),
                        '0.95': str(eval_results["IoU_THS_0.95"]),
                        '0.99': str(eval_results["IoU_THS_0.99"]),
                    },
                    'TPR': {
                        '0.75': str(tpr[-4]),
                        '0.85': str(tpr[-3]),
                        '0.95': str(tpr[-2]),
                        '0.99': str(tpr[-1]),
                    },
                    'TNR': {
                        '0.75': str(tnr[-4]),
                        '0.85': str(tnr[-3]),
                        '0.95': str(tnr[-2]),
                        '0.99': str(tnr[-1]),
                    }
                }

                with open(
                        os.path.join(self.run_hparams.model_dir, "..",
                                     "results.json"), 'w') as f:
                    json.dump(results_dict, f)

        except KeyboardInterrupt:
            print("Keyboard interrupt")
Example 3
    def train(self,
              iter_unit,
              num_iter,
              batch_size,
              weight_decay,
              learning_rate,
              learning_rate_decay_factor,
              learning_rate_decay_steps,
              rmsprop_decay,
              rmsprop_momentum,
              use_auto_loss_scaling,
              augment_data,
              warmup_steps=50,
              is_benchmark=False):
        """Train the model via a TF Estimator, with optional Horovod multi-GPU.

        Args:
            iter_unit: Either "epoch" or "batch"; the unit `num_iter` is expressed in.
            num_iter: Number of iterations (epochs or batches) to train for.
            batch_size: Per-GPU batch size; global batch size is this times
                `self.num_gpus`.
            weight_decay: Weight decay factor passed to the estimator params.
            learning_rate: Initial learning rate.
            learning_rate_decay_factor: Multiplicative LR decay factor.
            learning_rate_decay_steps: Steps between LR decays.
            rmsprop_decay: RMSProp optimizer decay.
            rmsprop_momentum: RMSProp optimizer momentum.
            use_auto_loss_scaling: If True (and AMP is on), enable TF automatic
                mixed-precision loss scaling via environment variable.
            augment_data: Whether the training data pipeline applies augmentation.
            warmup_steps: Steps skipped by the ProfilerHook before timing starts.
            is_benchmark: If True, allow running without a real `data_dir`
                (synthetic data path in `training_data_fn`).

        Raises:
            ValueError: If `iter_unit` is invalid, or `data_dir` is missing
                outside of benchmark mode.
        """

        if iter_unit not in ["epoch", "batch"]:
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for training!')

        # Mixed-precision loss scaling: either delegate to TF's automatic
        # loss scaling (env var below) or scale manually in the model fn.
        if self.run_hparams.amp:
            if use_auto_loss_scaling:

                # Only rank 0 prints, to avoid duplicate output under Horovod.
                if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
                    print(
                        "TF Loss Auto Scaling is activated - Experimental Feature"
                    )

                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"
                apply_manual_loss_scaling = False

            else:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "0"
                apply_manual_loss_scaling = True
        else:
            # FP32 training: no loss scaling of any kind.
            apply_manual_loss_scaling = False

        global_batch_size = batch_size * self.num_gpus

        if self.run_hparams.data_dir is not None:
            filenames, num_samples, num_steps, num_epochs = self.dataset.get_dataset_runtime_specs(
                training=True,
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=global_batch_size)

            steps_per_epoch = int(num_steps / num_epochs)

        else:
            # Benchmark mode without data; 625 is a hard-coded placeholder
            # epoch length — presumably tuned for the benchmark dataset size,
            # TODO confirm.
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = 625

        training_hooks = []

        if hvd_utils.is_using_hvd():
            # Sync initial variable state from rank 0 to all workers.
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        # Profiling and parameter logging happen on rank 0 only.
        if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
            training_hooks.append(
                ProfilerHook(global_batch_size=global_batch_size,
                             log_every=self.run_hparams.log_every_n_steps,
                             warmup_steps=warmup_steps,
                             is_training=True,
                             sample_dir=self.run_hparams.sample_dir))

            print("Starting Model Training ...")

            Logger.log(step=('PARAMETER'),
                       data={"Epochs": num_epochs},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={"Total Steps": num_steps},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={"Steps per Epoch": steps_per_epoch},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={"Weight Decay Factor": weight_decay},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={"Learning Rate": learning_rate},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={
                           "Learning Rate Decay Factor":
                           learning_rate_decay_factor
                       },
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(
                step=('PARAMETER'),
                data={"Learning Rate Decay Steps": learning_rate_decay_steps},
                verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={"RMSProp - Decay": rmsprop_decay},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={"RMSProp - Momentum": rmsprop_momentum},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(
                step=('PARAMETER'),
                data={"Loss Function Name": self.run_hparams.loss_fn_name},
                verbosity=Logger.Verbosity.DEFAULT)

            if self.run_hparams.amp:
                Logger.log(
                    step=('PARAMETER'),
                    data={"Use Auto Loss Scaling": use_auto_loss_scaling},
                    verbosity=Logger.Verbosity.DEFAULT)

            Logger.log(step=('PARAMETER'),
                       data={"# GPUs": self.num_gpus},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={"GPU Batch Size": batch_size},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={"Global Batch Size": global_batch_size},
                       verbosity=Logger.Verbosity.DEFAULT)
            Logger.log(step=('PARAMETER'),
                       data={
                           "Total Files to be Processed":
                           num_steps * global_batch_size
                       },
                       verbosity=Logger.Verbosity.DEFAULT)

            print()  # visual spacing

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'learning_rate': learning_rate,
            'learning_rate_decay_steps': learning_rate_decay_steps,
            'learning_rate_decay_factor': learning_rate_decay_factor,
            'rmsprop_decay': rmsprop_decay,
            'rmsprop_momentum': rmsprop_momentum,
            'weight_decay': weight_decay,
            'apply_manual_loss_scaling': apply_manual_loss_scaling,
            'loss_fn_name': self.run_hparams.loss_fn_name,
            'debug_verbosity': self.run_hparams.debug_verbosity,
        }

        def training_data_fn():
            # Prefer real data whenever a data_dir is configured; synthetic
            # data is used only in pure benchmark mode with no data_dir.
            if not is_benchmark or self.run_hparams.data_dir is not None:

                return self.dataset.dataset_fn(
                    batch_size=batch_size,
                    training=True,
                    only_defective_images=True,
                    augment_data=augment_data,
                    input_shape=list(self.run_hparams.input_shape) +
                    [self.run_hparams.n_channels],
                    mask_shape=list(self.run_hparams.mask_shape) +
                    [self.run_hparams.n_channels],
                    num_threads=64,
                    use_gpu_prefetch=True,
                    normalize_data_method="zero_centered",
                    seed=self.run_hparams.seed)

            else:
                if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
                    print("Using Synthetic Data ...")

                return self.dataset.synth_dataset_fn(
                    batch_size=batch_size,
                    training=True,
                    input_shape=list(self.run_hparams.input_shape) +
                    [self.run_hparams.n_channels],
                    mask_shape=list(self.run_hparams.mask_shape) +
                    [self.run_hparams.n_channels],
                    num_threads=64,
                    use_gpu_prefetch=True,
                    normalize_data_method="zero_centered",
                    only_defective_images=True,
                    augment_data=augment_data,
                    seed=self.run_hparams.seed)

        model = self._get_estimator(mode='train',
                                    run_params=estimator_params,
                                    xla=self.xla)

        try:
            model.train(
                input_fn=training_data_fn,
                steps=num_steps,
                hooks=training_hooks,
            )
        except KeyboardInterrupt:
            # Allow the user to stop training cleanly without a traceback.
            print("Keyboard interrupt")

        if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
            print('Ending Model Training ...')
    def train(self,
              iter_unit,
              num_iter,
              batch_size,
              weight_decay,
              learning_rate,
              learning_rate_decay_factor,
              learning_rate_decay_steps,
              rmsprop_decay,
              rmsprop_momentum,
              use_auto_loss_scaling,
              augment_data,
              warmup_steps=50,
              is_benchmark=False):
        """Train the model via a TF Estimator, with optional Horovod multi-GPU.

        Args:
            iter_unit: Either "epoch" or "batch"; the unit `num_iter` is expressed in.
            num_iter: Number of iterations (epochs or batches) to train for.
            batch_size: Per-GPU batch size; global batch size is this times
                `self.num_gpus`.
            weight_decay: Weight decay factor passed to the estimator params.
            learning_rate: Initial learning rate.
            learning_rate_decay_factor: Multiplicative LR decay factor.
            learning_rate_decay_steps: Steps between LR decays.
            rmsprop_decay: RMSProp optimizer decay.
            rmsprop_momentum: RMSProp optimizer momentum.
            use_auto_loss_scaling: If True (and TF AMP is on), enable TF
                automatic mixed-precision loss scaling via environment variable.
            augment_data: Whether the training data pipeline applies augmentation.
            warmup_steps: Steps skipped by the ProfilerHook before timing starts.
            is_benchmark: If True, allow running without a real `data_dir`
                (synthetic data path in `training_data_fn`).

        Raises:
            ValueError: If `iter_unit` is invalid, or `data_dir` is missing
                outside of benchmark mode.
        """

        if iter_unit not in ["epoch", "batch"]:
            raise ValueError(
                '`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])'
                % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for training!')

        # Mixed-precision loss scaling: either delegate to TF's automatic
        # loss scaling (env var below) or scale manually in the model fn.
        if self.run_hparams.use_tf_amp:
            if use_auto_loss_scaling:

                # NOTE(review): gates on hvd.local_rank() (first GPU of each
                # node) rather than hvd.rank(); multi-node runs will log once
                # per node — confirm this is intended.
                if not hvd_utils.is_using_hvd() or hvd.local_rank() == 0:
                    LOGGER.log(
                        "TF Loss Auto Scaling is activated - Experimental Feature"
                    )

                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"
                apply_manual_loss_scaling = False

            else:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "0"
                apply_manual_loss_scaling = True
        else:
            # FP32 training: no loss scaling of any kind.
            apply_manual_loss_scaling = False

        if not hvd_utils.is_using_hvd() or hvd.local_rank() == 0:
            LOGGER.log('Defining Model Estimator ...\n')

        global_batch_size = batch_size * self.num_gpus

        if self.run_hparams.data_dir is not None:
            filenames, num_samples, num_steps, num_epochs = self.dataset.get_dataset_runtime_specs(
                training=True,
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=global_batch_size)

            steps_per_epoch = int(num_steps / num_epochs)

        else:
            # Benchmark mode without data; 625 is a hard-coded placeholder
            # epoch length — presumably tuned for the benchmark dataset size,
            # TODO confirm.
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = 625

        training_hooks = []

        if hvd_utils.is_using_hvd():
            # Sync initial variable state from rank 0 to all workers.
            training_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

        # Profiling and parameter logging happen on the first local rank only.
        if not hvd_utils.is_using_hvd() or hvd.local_rank() == 0:
            training_hooks.append(
                ProfilerHook(global_batch_size=global_batch_size,
                             log_every=self.run_hparams.log_every_n_steps,
                             warmup_steps=warmup_steps,
                             is_training=True,
                             sample_dir=self.run_hparams.sample_dir))

            LOGGER.log('Starting Model Training ...\n')

            LOGGER.log("=> Epochs: %d" % num_epochs)
            LOGGER.log("=> Total Steps: %d" % num_steps)
            LOGGER.log("=> Steps per Epoch: %d" % steps_per_epoch)
            LOGGER.log("=> Weight Decay Factor: %.1e" % weight_decay)
            LOGGER.log("=> Learning Rate: %.1e" % learning_rate)
            LOGGER.log("=> Learning Rate Decay Factor: %.2f" %
                       learning_rate_decay_factor)
            LOGGER.log("=> Learning Rate Decay Steps: %d" %
                       learning_rate_decay_steps)
            LOGGER.log("=> RMSProp - Decay: %.1f" % rmsprop_decay)
            LOGGER.log("=> RMSProp - Momentum: %.1f" % rmsprop_momentum)
            LOGGER.log("=> Loss Function Name: %s" %
                       self.run_hparams.loss_fn_name)

            if self.run_hparams.use_tf_amp:
                LOGGER.log("=> Use Auto Loss Scaling: %s" %
                           use_auto_loss_scaling)

            LOGGER.log("=> # GPUs: %d" % self.num_gpus)
            LOGGER.log("=> GPU Batch Size: %d" % batch_size)
            LOGGER.log("=> Global Batch Size: %d" % global_batch_size)
            LOGGER.log("=> Total Files to Processed: %d\n" %
                       (num_steps * global_batch_size))

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'learning_rate': learning_rate,
            'learning_rate_decay_steps': learning_rate_decay_steps,
            'learning_rate_decay_factor': learning_rate_decay_factor,
            'rmsprop_decay': rmsprop_decay,
            'rmsprop_momentum': rmsprop_momentum,
            'weight_decay': weight_decay,
            'apply_manual_loss_scaling': apply_manual_loss_scaling,
            'loss_fn_name': self.run_hparams.loss_fn_name,
            'debug_verbosity': self.run_hparams.debug_verbosity,
        }

        def training_data_fn():
            # Prefer real data whenever a data_dir is configured; synthetic
            # data is used only in pure benchmark mode with no data_dir.
            if not is_benchmark or self.run_hparams.data_dir is not None:

                return self.dataset.dataset_fn(
                    batch_size=batch_size,
                    training=True,
                    only_defective_images=True,
                    augment_data=augment_data,
                    input_shape=list(self.run_hparams.input_shape) +
                    [self.run_hparams.n_channels],
                    mask_shape=list(self.run_hparams.mask_shape) +
                    [self.run_hparams.n_channels],
                    num_threads=64,
                    use_gpu_prefetch=True,
                    normalize_data_method="zero_centered",
                    seed=self.run_hparams.seed)

            else:
                if not hvd_utils.is_using_hvd() or hvd.local_rank() == 0:
                    LOGGER.log("Using Synthetic Data ...")

                return self.dataset.synth_dataset_fn(
                    batch_size=batch_size,
                    training=True,
                    input_shape=list(self.run_hparams.input_shape) +
                    [self.run_hparams.n_channels],
                    mask_shape=list(self.run_hparams.mask_shape) +
                    [self.run_hparams.n_channels],
                    num_threads=64,
                    use_gpu_prefetch=True,
                    normalize_data_method="zero_centered",
                    only_defective_images=True,
                    augment_data=augment_data,
                    seed=self.run_hparams.seed)

        model = self._get_estimator(mode='train',
                                    run_params=estimator_params,
                                    use_xla=self.use_xla)

        try:
            model.train(
                input_fn=training_data_fn,
                steps=num_steps,
                hooks=training_hooks,
            )
        except KeyboardInterrupt:
            # Allow the user to stop training cleanly without a traceback.
            print("Keyboard interrupt")

        if not hvd_utils.is_using_hvd() or hvd.local_rank() == 0:
            LOGGER.log('Ending Model Training ...')