def _calculate_mean_and_var(self, x, axes, keep_dims):
        """Compute the mean and (biased) variance of `x` over `axes`.

        With more than one Horovod worker the statistics are synchronized:
        per-worker sums, sums of squares and element counts are summed with
        `hvd.allreduce(..., average=False)` and combined as
        Var[X] = E[X^2] - E[X]^2. With a single worker a local two-pass
        mean/variance is computed instead.

        Args:
          x: Input tensor. float16 inputs are reduced in float32.
          axes: Axes to reduce over.
          keep_dims: If False, the reduced axes are squeezed from the outputs.

        Returns:
          A `(mean, variance)` tuple, cast back to float16 when `x` is float16.
        """
        with ops.name_scope('moments', values=[x, axes]):
            # The dynamic range of fp16 is too limited to support the collection of
            # sufficient statistics. As a workaround we simply perform the operations
            # on 32-bit floats before converting the mean and variance back to fp16
            y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x

            num_shards = hvd.size() if horovod_enabled() else 1

            if num_shards > 1:
                local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
                local_squared_sum = math_ops.reduce_sum(
                    math_ops.square(y), axis=axes, keepdims=True)
                # Number of elements reduced on this worker. size(y) divided by
                # size(local_sum) (which keeps only the non-reduced dims) equals
                # the product of the sizes of exactly the reduced axes, wherever
                # they sit (e.g. NHWC [0, 1, 2] as well as NCHW [0, 2, 3]).
                local_count = (
                    math_ops.cast(array_ops.size(y), y.dtype) /
                    math_ops.cast(array_ops.size(local_sum), y.dtype))

                y_sum = hvd.allreduce(local_sum, average=False)
                y_squared_sum = hvd.allreduce(local_squared_sum, average=False)
                # Summing the real per-worker counts (instead of assuming
                # batch_size * num_shards) stays correct when workers process
                # different batch sizes.
                global_count = hvd.allreduce(local_count, average=False)

                mean = y_sum / global_count
                y_squared_mean = y_squared_sum / global_count
                # var = E(x^2) - E(x)^2. Floating-point cancellation can make the
                # difference slightly negative, so clamp it at zero.
                variance = math_ops.maximum(
                    y_squared_mean - math_ops.square(mean), 0.)
            else:
                # Compute true mean while keeping the dims for proper broadcasting.
                mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
                # sample variance, not unbiased variance
                # Note: stop_gradient does not change the gradient that gets
                #       backpropagated to the mean from the variance calculation,
                #       because that gradient is zero
                variance = math_ops.reduce_mean(
                    math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
                    axes,
                    keepdims=True,
                    name='variance')
            if not keep_dims:
                mean = array_ops.squeeze(mean, axes)
                variance = array_ops.squeeze(variance, axes)
            if x.dtype == dtypes.float16:
                return (math_ops.cast(mean, dtypes.float16),
                        math_ops.cast(variance, dtypes.float16))
            else:
                return (mean, variance)
# Esempio n. 2 (Example 2)
# 0
        def update(accum_vars):
            """Advance the global step, then apply the accumulated gradients.

            When `allreduce_post_accumulation` is set and Horovod is enabled,
            each accumulated gradient is passed through `hvd.allreduce` first
            (IndexedSlices are converted to dense tensors beforehand).
            """
            step_assign = global_step.assign(new_global_step)
            with tf.control_dependencies([step_assign]):
                if allreduce_post_accumulation and horovod_enabled():
                    reduced_grads = []
                    for grad in accum_vars:
                        if isinstance(grad, tf.IndexedSlices):
                            grad = tf.convert_to_tensor(value=grad)
                        reduced_grads.append(hvd.allreduce(grad))
                    accum_vars = reduced_grads

                grads_and_vars = list(zip(accum_vars, tvars))
                return optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
# Esempio n. 3 (Example 3)
# 0
  def _moments(self, inputs, reduction_axes, keep_dims):
    """Compute the mean and variance: it overrides the original _moments.

    Per-worker moments come from the parent implementation. With more than
    one Horovod worker they are combined across workers using
    Var[X] = E[X^2] - E[X]^2, where each E[.] is reduced with
    `hvd.allreduce` (Horovod's default averaging reduction).
    NOTE(review): averaging per-worker moments assumes every worker reduces
    the same number of elements (equal per-worker batch sizes) — confirm.

    Returns:
      A `(mean, variance)` tuple.
    """
    shard_mean, shard_variance = super(SyncBatchNormalization, self)._moments(
      inputs, reduction_axes, keep_dims=keep_dims)

    num_shards = hvd.size() if horovod_enabled() else 1
    if num_shards > 1:
      # Compute variance using: Var[X]= E[X^2] - E[X]^2.
      shard_square_of_mean = tf.math.square(shard_mean)
      shard_mean_of_square = shard_variance + shard_square_of_mean
      group_mean = hvd.allreduce(shard_mean)
      group_mean_of_square = hvd.allreduce(shard_mean_of_square)
      group_variance = group_mean_of_square - tf.math.square(group_mean)
      # Floating-point cancellation in E[X^2] - E[X]^2 can produce tiny
      # negative values; clamp so downstream normalization never sees a
      # negative variance.
      group_variance = tf.math.maximum(group_variance, 0.)
      return (group_mean, group_variance)
    else:
      return (shard_mean, shard_variance)
 def __init__(self, num_accumulation_steps=1):
     """Build a grouped op that overwrites every trainable variable with
     its `hvd.allreduce`-d value.

     Args:
       num_accumulation_steps: Stored as `self.num_accumulation_steps`.
     """
     super(TrainableVarsAllreducingHookPreOpt, self).__init__()
     # Modify this collection in order to allreduce other set of variables
     variables = tf.compat.v1.trainable_variables()
     assign_ops = []
     for var in variables:
         assign_ops.append(var.assign(hvd.allreduce(var)))
     self.allreduce_trainable_vars_op = tf.group(*assign_ops)
     self.num_accumulation_steps = num_accumulation_steps
     # Iteration counter starts at 1.
     self.current_iteration = 1
# Esempio n. 5 (Example 5)
# 0
  def eval_end(self):
    """See base class.

    Returns:
      Dict with 'test_loss' (the loss metric's result) and 'test_accuracy'.
      When `flags_obj.use_distributed_eval` is set and Horovod is enabled,
      the accuracy result is passed through `hvd.allreduce`.
    """
    distributed = self.flags_obj.use_distributed_eval and horovod_enabled()
    accuracy = self.test_accuracy.result()
    if distributed:
      accuracy = hvd.allreduce(accuracy)
    return dict(test_loss=self.test_loss.result(), test_accuracy=accuracy)
# Esempio n. 6 (Example 6)
# 0
    def begin(self):
        """Resolve the watched tensors into graph ops before the session runs.

        Each entry of `self._named_tensor` is mapped to its graph element
        under the key `'{}'.format(tag)`; when `self._use_all_reduce` is set,
        the element is wrapped in `hvd.allreduce`. The global step tensor is
        registered under the key 'step'.
        """
        use_all_reduce = self._use_all_reduce
        avg_ops = OrderedDict()
        for tag, tensor in self._named_tensor.items():
            graph_element = basic_session_run_hooks._as_graph_element(tensor)
            if use_all_reduce:
                graph_element = hvd.allreduce(graph_element)
            avg_ops['{}'.format(tag)] = graph_element
        self._avg_ops = avg_ops

        self._global_step_tensor = tf.train.get_or_create_global_step()
        self._avg_ops['step'] = self._global_step_tensor
# Esempio n. 7 (Example 7)
# 0
def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purpose.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.

  Returns:
    Dict of results of the run.  Contains the keys `eval_results` and
    `train_hooks`. `eval_results` contains accuracy (top_1) and accuracy_top_5
    (it is `{}` after `train_and_evaluate`, and `None` when the
    `disable_eval` flag is set). `train_hooks` is a list of the hook
    instances used during training. If the `return_before_eval` flag is set,
    an empty dict is returned before the first evaluation.
  """

  experimental_preloading = flags_obj.experimental_preloading

  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  # Creates session config. allow_soft_placement = True, is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=not experimental_preloading)

  if horovod_enabled():
    # The Scoped Allocator Optimization is enabled by default unless disabled
    # by a flag.
    if not condition_env_var('TF_DISABLE_SCOPED_ALLOCATOR', default=False):
      from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error
      session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON
      enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
      del enable_op[:]
      enable_op.append("HorovodAllreduce")

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  # Creates a `RunConfig`. Time-based checkpointing is disabled; checkpoints
  # are saved every `save_checkpoint_steps` steps.
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      log_step_count_steps=flags_obj.display_steps,
      save_checkpoints_secs=None,
      save_checkpoints_steps=flags_obj.save_checkpoint_steps)

  # Initializes model with all but the dense layer from pretrained ResNet.
  # if flags_obj.pretrained_model_checkpoint_path is not None:
  #   warm_start_settings = tf.estimator.WarmStartSettings(
  #       flags_obj.pretrained_model_checkpoint_path,
  #       vars_to_warm_start='^(?!.*dense)')
  # else:
  #   warm_start_settings = None
  warm_start_settings = None

  model_dir = flags_obj.model_dir

  if horovod_enabled():
    # Each Horovod rank gets its own sub-directory.
    model_dir = "{}/rank_{}".format(flags_obj.model_dir, hvd.rank())

  if experimental_preloading:
    SelectedEstimator = HabanaEstimator
  else:
    SelectedEstimator = tf.estimator.Estimator

  if flags.FLAGS.is_mlperf_enabled:
    # Largest eval batch size <= batch_size that evenly divides the validation
    # set. Counting down to 1 (which always divides) guarantees a value even
    # when batch_size < 2.
    num_validation_images = imagenet_main.NUM_IMAGES['validation']
    eval_batch_size = next(
        (candidate for candidate in range(flags_obj.batch_size, 0, -1)
         if num_validation_images % candidate == 0),
        flags_obj.batch_size)
  else:
    eval_batch_size = flags_obj.batch_size

  classifier = SelectedEstimator(
      model_fn=model_function, model_dir=model_dir, config=run_config,
      warm_start_from=warm_start_settings, params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'model_type': flags_obj.model_type,
          'loss_scale': flags_core.get_loss_scale(flags_obj,
                                                  default_for_fp16=128),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune,
          'num_workers': num_workers,
          'train_epochs': flags_obj.train_epochs,
          'warmup_epochs': flags_obj.warmup_epochs,
          'use_cosine_lr': flags_obj.use_cosine_lr,
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'model_type': flags_obj.model_type,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
      'num_workers': num_workers,
  }
  if flags.FLAGS.is_mlperf_enabled:
    run_params['eval_batch_size'] = eval_batch_size

  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=model_dir,
      batch_size=flags_obj.batch_size)

  if flags.FLAGS.is_mlperf_enabled:
    _log_cache = []

    def formatter(x):
      """Abuse side effects to get tensors out of the model_fn."""
      # Keep only the most recent logged values in the cache. The append and
      # return must run on every call, not only when the cache is non-empty
      # (otherwise the cache is never filled and the formatter returns None).
      if _log_cache:
        _log_cache.pop()
      _log_cache.append(x.copy())
      return str(x)

    compliance_hook = tf.estimator.LoggingTensorHook(
      tensors={_NUM_EXAMPLES_NAME: _NUM_EXAMPLES_NAME},
      every_n_iter=int(1e10),
      at_end=True,
      formatter=formatter)
  else:
    compliance_hook = None

  if horovod_enabled():

    if "tf_profiler_hook" not in flags_obj.hooks and os.environ.get("TF_RANGE_TRACE", False):
      from TensorFlow.common.utils import RangeTFProfilerHook
      begin = (imagenet_main.NUM_IMAGES["train"] // (flags_obj.batch_size * hvd.size()) + 100)
      train_hooks.append(RangeTFProfilerHook(begin,20, "./rank-{}".format(hvd.rank())))

    if "synapse_logger_hook" not in flags_obj.hooks and "range" == os.environ.get("HABANA_SYNAPSE_LOGGER", "False").lower():
      from TensorFlow.common.horovod_helpers import SynapseLoggerHook
      begin = (imagenet_main.NUM_IMAGES["train"] // (flags_obj.batch_size * hvd.size()) + 100)
      end = begin + 100
      print("Begin: {}".format(begin))
      print("End: {}".format(end))
      train_hooks.append(SynapseLoggerHook(list(range(begin, end)), False))
    train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))


  def input_fn_train(num_epochs, input_context=None):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_dl_type(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        input_context=input_context, experimental_preloading=experimental_preloading)

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            eval_batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_dl_type(flags_obj), experimental_preloading=experimental_preloading)

  train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
                  flags_obj.train_epochs)

  max_train_steps = flags_obj.max_train_steps
  global_batch_size = flags_obj.batch_size * (hvd.size() if horovod_enabled() else 1)
  steps_per_epoch = (imagenet_main.NUM_IMAGES['train'] // global_batch_size)
  if max_train_steps is None:
    max_train_steps = steps_per_epoch * (train_epochs + flags_obj.train_offset)

  max_eval_steps = flags_obj.max_eval_steps
  if max_eval_steps is None:
    max_eval_steps = (imagenet_main.NUM_IMAGES['validation'] + eval_batch_size - 1) // eval_batch_size

  # MLPerf RUN_STOP status. Defined on every path so the final RUN_STOP event
  # cannot hit an unbound name (e.g. when train_and_evaluate is used).
  success = False

  use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
  if use_train_and_evaluate:
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda input_context=None: input_fn_train(
            train_epochs, input_context=input_context),
        hooks=train_hooks,
        max_steps=max_train_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
    tf.compat.v1.logging.info('Starting to train and evaluate.')
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
    # tf.estimator.train_and_evalute doesn't return anything in multi-worker
    # case.
    eval_results = {}
  else:
    if train_epochs == 0:
      # If --eval_only is set, perform a single loop with zero train epochs.
      schedule, n_loops = [0], 1
    else:
      # Compute the number of times to loop while training. All but the last
      # pass will train for `epochs_between_evals` epochs, while the last will
      # train for the number needed to reach `training_epochs`. For instance if
      #   train_epochs = 25 and epochs_between_evals = 10
      # schedule will be set to [10, 10, 5]. That is to say, the loop will:
      #   Train for 10 epochs and then evaluate.
      #   Train for another 10 epochs and then evaluate.
      #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
      n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
      schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
      schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting.

    if flags.FLAGS.is_mlperf_enabled:
      mllogger.event(key=mllog.constants.CACHE_CLEAR)
      mllogger.start(key=mllog.constants.RUN_START)
      mllogger.event(key=mllog.constants.GLOBAL_BATCH_SIZE,
                     value=global_batch_size)

    final_step = 0

    if flags.FLAGS.is_mlperf_enabled and flags_obj.train_offset > 0:
      final_step += flags_obj.train_offset * steps_per_epoch
      mllogger.event(key=mllog.constants.FIRST_EPOCH_NUM, value=1, metadata={'number of epochs before main loop: ': flags_obj.train_offset})
      for i in range(flags_obj.train_offset):
        mllogger.event(key=mllog.constants.EPOCH_NUM, value=i+1)
      classifier.train(
          input_fn=lambda input_context=None: input_fn_train(
              flags_obj.train_offset, input_context=input_context),
          hooks=train_hooks + [compliance_hook],
          max_steps=min(max_train_steps, final_step))

    for cycle_index, num_train_epochs in enumerate(schedule):
      tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index,
                                int(n_loops))
      if flags.FLAGS.is_mlperf_enabled:
        mllogger.start(key=mllog.constants.BLOCK_START, value=cycle_index+1)
        mllogger.event(key=mllog.constants.FIRST_EPOCH_NUM, value=cycle_index*flags_obj.epochs_between_evals + flags_obj.train_offset + 1)
        mllogger.event(key=mllog.constants.EPOCH_COUNT, value=flags_obj.epochs_between_evals)

        for j in range(flags_obj.epochs_between_evals):
          mllogger.event(key=mllog.constants.EPOCH_NUM,
                         value=cycle_index * flags_obj.epochs_between_evals + j + flags_obj.train_offset + 1)

      if num_train_epochs:
        # Since we are calling classifier.train immediately in each loop, the
        # value of num_train_epochs in the lambda function will not be changed
        # before it is used. So it is safe to ignore the pylint error here
        # pylint: disable=cell-var-from-loop
        final_step += num_train_epochs * steps_per_epoch
        cycle_hooks = (train_hooks + [compliance_hook]
                       if compliance_hook is not None else train_hooks)
        classifier.train(
            input_fn=lambda input_context=None: input_fn_train(
                num_train_epochs, input_context=input_context),
            hooks=cycle_hooks,
            max_steps=min(max_train_steps, final_step))
        if flags.FLAGS.is_mlperf_enabled:
          mllogger.end(key=mllog.constants.BLOCK_STOP, value=cycle_index+1)

      if flags.FLAGS.is_mlperf_enabled:
        mllogger.start(key=mllog.constants.EVAL_START)
      # max_eval_steps is associated with testing and profiling.
      # As a result it is frequently called with synthetic data,
      # which will iterate forever. Passing steps=max_eval_steps
      # allows the eval (which is generally unimportant in those circumstances)
      # to terminate. Note that eval will run for max_eval_steps each loop,
      # regardless of the global_step count.
      if flags_obj.get_flag_value("return_before_eval", False):
        return {}
      if flags_obj.get_flag_value("disable_eval", False):
        eval_results = None
        continue
      tf.compat.v1.logging.info('Starting to evaluate.')
      eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                         steps=max_eval_steps)

      if flags.FLAGS.is_mlperf_enabled:
        mllogger.event(key=mllog.constants.EVAL_SAMPLES, value=int(eval_results[_NUM_EXAMPLES_NAME]))
        validation_epoch = (cycle_index + 1) * flags_obj.epochs_between_evals + flags_obj.train_offset
        mllogger.event(key=mllog.constants.EVAL_ACCURACY, value=float(eval_results['accuracy']), metadata={'epoch_num: ': validation_epoch})
        mllogger.end(key=mllog.constants.EVAL_STOP, metadata={'epoch_num: ' : validation_epoch})
        if flags_obj.stop_threshold:
          success = bool(eval_results['accuracy'] >= flags_obj.stop_threshold)

      benchmark_logger.log_evaluation_result(eval_results)

      if flags_obj.stop_threshold:
        if horovod_enabled():
          past_threshold = tf.cast(model_helpers.past_stop_threshold(
              flags_obj.stop_threshold, eval_results['accuracy']), tf.float32)
          global_past_threshold = tf.math.greater(
              hvd.allreduce(past_threshold, op=hvd.Sum), tf.zeros(1, tf.float32))
          # Close the session after use instead of leaking one per cycle.
          with tf.compat.v1.Session() as sess:
            stop_training = global_past_threshold.eval(session=sess)
          if stop_training:
            break
        else:
          if model_helpers.past_stop_threshold(
              flags_obj.stop_threshold, eval_results['accuracy']):
            break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)

  stats = {}
  stats['eval_results'] = eval_results
  stats['train_hooks'] = train_hooks

  if flags.FLAGS.is_mlperf_enabled:
    mllogger.event(key=mllog.constants.RUN_STOP, value={"success": success})
    mllogger.end(key=mllog.constants.RUN_STOP)

  return stats
# Esempio n. 8 (Example 8)
# 0
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, manual_fp16=False, use_fp16=False, num_accumulation_steps=1,
                     optimizer_type="adam", allreduce_post_accumulation=False, init_loss_scale=2**32, weight_decay_rate=0.01,beta_1=0.9, beta_2=0.999, epsilon=1e-6,power = 0.5,use_tpu=False):
  """Creates an optimizer training op.

  The learning rate follows a polynomial decay to 0 over `num_train_steps`,
  preceded by a linear warmup over `num_warmup_steps` steps. Gradients are
  clipped to a global norm of 1.0 before being applied (per micro-batch,
  before accumulation, when `num_accumulation_steps > 1`).

  Args:
    loss: Loss tensor to minimize.
    init_lr: Base learning rate; the warmup ramps linearly towards it.
    num_train_steps: Decay horizon of the polynomial schedule.
    num_warmup_steps: Number of linear warmup steps (0 disables warmup).
    manual_fp16: Unsupported in TF2; raises NotImplementedError when True.
    use_fp16: Enables the mixed-precision graph rewrite with a dynamic loss
      scale starting at `init_loss_scale`.
    num_accumulation_steps: Number of micro-batches whose clipped gradients
      are accumulated before an update is applied.
    optimizer_type: "lamb" selects LAMBOptimizer; any other value selects
      AdamWeightDecayOptimizer. "adam" also forces a decay power of 1.0.
    allreduce_post_accumulation: With Horovod and gradient accumulation,
      allreduce the accumulated gradients once per update instead of
      wrapping the optimizer in hvd.DistributedOptimizer.
    init_loss_scale: Initial dynamic loss scale (used with `use_fp16`).
    weight_decay_rate, beta_1, beta_2, epsilon: LAMB hyper-parameters. The
      Adam branch uses fixed values (0.01, 0.9, 0.999, 1e-6).
    power: Polynomial decay power for non-"adam" optimizer types.
    use_tpu: Wraps the optimizer in tf.compat.v1.tpu.CrossShardOptimizer.

  Returns:
    The training op.

  Raises:
    NotImplementedError: If `manual_fp16` is True.
  """
  global_step = tf.compat.v1.train.get_or_create_global_step()
  # avoid step change in learning rate at end of warmup phase
  if optimizer_type == "adam":
    power = 1.0
    decayed_learning_rate_at_crossover_point = init_lr * (
        (1.0 - float(num_warmup_steps) / float(num_train_steps)) ** power)
  else:
    # Other optimizer types keep the caller-supplied `power` and skip the
    # crossover adjustment.
    decayed_learning_rate_at_crossover_point = init_lr
  adjusted_init_lr = init_lr * (init_lr / decayed_learning_rate_at_crossover_point)
  print('decayed_learning_rate_at_crossover_point = %e, adjusted_init_lr = %e' %
        (decayed_learning_rate_at_crossover_point, adjusted_init_lr))
  learning_rate = tf.constant(value=adjusted_init_lr, shape=[], dtype=tf.float32)
  # Implements linear decay of the learning rate.
  learning_rate = tf.compat.v1.train.polynomial_decay(
      learning_rate,
      global_step - 1,  ## We first update global_step, then apply_grad and thus we use global_step-1.
      num_train_steps,
      end_learning_rate=0.0,
      power=power,
      cycle=False)
  # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
  # learning rate will be `global_step/num_warmup_steps * init_lr`.
  if num_warmup_steps:
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = init_lr * warmup_percent_done
    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = (
        (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
  if optimizer_type == "lamb":
    print("Initializing LAMB Optimizer")
    optimizer = LAMBOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=weight_decay_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  else:
    print("Initializing ADAM Weight Decay Optimizer")
    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    # NOTE: these hyper-parameters are fixed; the `weight_decay_rate`,
    # `beta_1`, `beta_2` and `epsilon` arguments only affect the LAMB branch.
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  if horovod_enabled() and (num_accumulation_steps == 1 or (not allreduce_post_accumulation)):
    optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True)
  if use_fp16:
    loss_scaler = tf.train.experimental.DynamicLossScale(
        initial_loss_scale=init_loss_scale, increment_period=1000, multiplier=2.0)
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer, loss_scaler)
    # Named identity op; presumably fetched by the name "loss_scale"
    # elsewhere, so it is created even though no local uses it.
    tf.identity(loss_scaler(), name="loss_scale")
  if manual_fp16:
    # ExponentialUpdateLossScaleManager / LossScaleOptimizer live in
    # tf.contrib, which does not exist in TF2. An explicit raise (unlike
    # `assert`) is not stripped when Python runs with -O.
    raise NotImplementedError(
        "No support for ExponentialUpdateLossScaleManager and LossScaleOptimizer in TF2.0")
  if use_tpu:
    optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)
  tvars = tf.compat.v1.trainable_variables()
  if num_accumulation_steps > 1:
    #grads_and_vars = optimizer.compute_gradients(loss * 1.0 / num_accumulation_steps, tvars)
    ## to match mlcomm ref we need to clip before scaling
    grads_and_vars = optimizer.compute_gradients(loss , tvars, gate_gradients=tf.compat.v1.train.Optimizer.GATE_NONE)
    local_step = tf.compat.v1.get_variable(name="local_step", shape=[], dtype=tf.int32, trainable=False,
                                           initializer=tf.compat.v1.zeros_initializer)
    batch_finite = tf.compat.v1.get_variable(name="batch_finite", shape=[], dtype=tf.bool, trainable=False,
                                             initializer=tf.compat.v1.ones_initializer)
    accum_vars = [tf.compat.v1.get_variable(
        name=tvar.name.split(":")[0] + "/accum",
        shape=tvar.shape.as_list(),
        dtype=tf.float32,
        trainable=False,
        initializer=tf.compat.v1.zeros_initializer()) for tvar in tf.compat.v1.trainable_variables()]
    reset_step = tf.cast(tf.math.equal(local_step % num_accumulation_steps, 0), dtype=tf.bool)
    local_step = tf.cond(pred=reset_step, true_fn=lambda: local_step.assign(
        tf.ones_like(local_step)), false_fn=lambda: local_step.assign_add(1))
    grads_and_vars_and_accums = [(gv[0], gv[1], accum_vars[i])
                                 for i, gv in enumerate(grads_and_vars) if gv[0] is not None]
    grads, tvars, accum_vars = list(zip(*grads_and_vars_and_accums))
    all_are_finite = tf.reduce_all(input_tensor=[tf.reduce_all(input_tensor=tf.math.is_finite(
        g)) for g in grads]) if manual_fp16 or use_fp16 else tf.constant(True, dtype=tf.bool)
    batch_finite = tf.cond(pred=reset_step,
                           true_fn=lambda: batch_finite.assign(tf.math.logical_and(
                               tf.constant(True, dtype=tf.bool), all_are_finite)),
                           false_fn=lambda: batch_finite.assign(tf.math.logical_and(batch_finite, all_are_finite)))
    # This is how the model was pre-trained.
    # ensure global norm is a finite number
    # to prevent clip_by_global_norm from having a hizzy fit.
    (clipped_grads, _) = tf.clip_by_global_norm(
        grads, clip_norm=1.0,
        use_norm=tf.cond(
            pred=all_are_finite,
            true_fn=lambda: tf.linalg.global_norm(grads),
            false_fn=lambda: tf.constant(1.0)))
    ## divide grad by acc_steps before accumulating
    accum_vars = tf.cond(pred=reset_step,
                         true_fn=lambda: [accum_vars[i].assign(grad) for i, grad in enumerate(clipped_grads)],
                         false_fn=lambda:  [accum_vars[i].assign_add(grad) for i, grad in enumerate(clipped_grads)])
    update_step = tf.identity(tf.cast(tf.math.equal(local_step % num_accumulation_steps, 0),
                                      dtype=tf.bool), name="update_step")
    def allreduce_of_batch_finite_required():
      # In case of bf16 and fp32 batch finite is tf.constant(True, dtype=tf.bool)
      return horovod_enabled() and manual_fp16 and use_fp16
    # TODO: in future if we want to enable infinite batch iter skiping we will need to change this allreduce.
    new_global_step = tf.cond(pred=tf.math.logical_and(update_step,
                                                       tf.cast(hvd.allreduce(tf.cast(batch_finite, tf.int32)), tf.bool) if allreduce_of_batch_finite_required() else batch_finite),
                              true_fn=lambda: global_step + 1,
                              false_fn=lambda: global_step)
    new_global_step = tf.identity(new_global_step, name='step_update')
    def update(accum_vars):
      with tf.control_dependencies([global_step.assign(new_global_step)]):
        if allreduce_post_accumulation and horovod_enabled():
          accum_vars = [hvd.allreduce(tf.convert_to_tensor(value=accum_var)* 1.0 / num_accumulation_steps, op=hvd.Sum) if isinstance(accum_var, tf.IndexedSlices)
                        else hvd.allreduce(accum_var * 1.0 / num_accumulation_steps, op=hvd.Sum) for accum_var in accum_vars]
        return optimizer.apply_gradients(list(zip(accum_vars, tvars)), global_step=global_step)
    train_op = tf.cond(pred=update_step,
                       true_fn=lambda: update(accum_vars), false_fn=lambda: tf.no_op())
  else:
    grads_and_vars = optimizer.compute_gradients(loss, tvars, gate_gradients=tf.compat.v1.train.Optimizer.GATE_NONE)
    grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
    grads, tvars = list(zip(*grads_and_vars))
    all_are_finite = tf.reduce_all(
        input_tensor=[tf.reduce_all(input_tensor=tf.math.is_finite(g)) for g in grads]) if use_fp16 or manual_fp16 else tf.constant(True, dtype=tf.bool)

    # This is how the model was pre-trained.
    # ensure global norm is a finite number
    # to prevent clip_by_global_norm from having a hizzy fit.
    (clipped_grads, _) = tf.clip_by_global_norm(
        grads, clip_norm=1.0,
        use_norm=tf.cond(
            pred=all_are_finite,
            true_fn=lambda: tf.linalg.global_norm(grads),
            false_fn=lambda: tf.constant(1.0)))
    new_global_step = tf.cond(pred=all_are_finite, true_fn=lambda: global_step + 1, false_fn=lambda: global_step)
    new_global_step = tf.identity(new_global_step, name='step_update')
    with tf.control_dependencies([global_step.assign(new_global_step)]):
      train_op = optimizer.apply_gradients(
          list(zip(clipped_grads, tvars)), global_step=global_step)
  return train_op
    def eval_end(self):
        """See base class.

        Finishes an evaluation pass:
        - when Horovod is enabled and `self.dist_eval` is set, sums the
          accuracy hit/count totals across workers;
        - writes the resulting totals back into `self.test_accuracy`;
        - emits the MLPerf EVAL_STOP, EVAL_ACCURACY and BLOCK_STOP events;
        - decides whether training should continue, emitting BLOCK_START
          when another block will run.

        Returns:
          Dict with 'test_accuracy' (float), 'continue_training' and, when
          `self.test_loss` is truthy, 'test_loss'.
        """
        epoch_num = int(self.epoch_helper.current_epoch)
        mlperf_logger = self.mlperf_mlloger
        constants = self.mlperf_mllog.constants
        flags_obj = self.flags_obj

        mlperf_logger.end(key=constants.EVAL_STOP,
                          value=None,
                          metadata={'epoch_num': epoch_num + 1})

        hits = self.test_accuracy.total
        count = self.test_accuracy.count
        if horovod_enabled() and self.dist_eval:
            # Sum raw totals rather than per-worker accuracies.
            hits = hvd.allreduce(hits, op=hvd.Sum)
            count = hvd.allreduce(count, op=hvd.Sum)
        accuracy = float(hits / count)

        # Store the (possibly cross-worker) totals back into the metric.
        self.test_accuracy.total.assign(hits)
        self.test_accuracy.count.assign(count)

        self.eval_accuracy = accuracy
        mlperf_logger.event(key=constants.EVAL_ACCURACY,
                            value=accuracy,
                            metadata={'epoch_num': epoch_num + 1})

        between_evals = flags_obj.epochs_between_evals
        first_epoch_num = max(epoch_num - between_evals + 1, 0)
        # The first block may span `eval_offset_epochs` epochs, unless that
        # offset is zero.
        if first_epoch_num == 0 and flags_obj.eval_offset_epochs != 0:
            epoch_count = flags_obj.eval_offset_epochs
        else:
            epoch_count = between_evals
        mlperf_logger.end(key=constants.BLOCK_STOP,
                          value=None,
                          metadata={
                              'first_epoch_num': first_epoch_num + 1,
                              'epoch_count': epoch_count
                          })

        past_threshold = False
        target_accuracy = flags_obj.target_accuracy
        if target_accuracy is not None:
            past_threshold = accuracy >= target_accuracy
            if horovod_enabled() and not self.dist_eval:
                # Each worker evaluated locally: stop if any worker reached
                # the target.
                past_threshold = hvd.allreduce(
                    tf.cast(past_threshold, tf.float32), op=hvd.Sum) > 0

        if past_threshold:
            continue_training = False
        elif not self.profile and accuracy <= 0.002:
            # Presumably treats near-zero accuracy as a diverged run.
            continue_training = False
        else:
            continue_training = True
            if self.global_step.numpy() < self.train_steps:
                mlperf_logger.start(
                    key=constants.BLOCK_START,
                    value=None,
                    metadata={
                        'first_epoch_num': epoch_num + 2,
                        'epoch_count': between_evals
                    })

        metrics = {
            'test_accuracy': accuracy,
            'continue_training': continue_training,
        }
        if self.test_loss:
            metrics['test_loss'] = self.test_loss.result()
        return metrics