Example #1
def _default_run_finish_fn(success_status):
    if not success_status:
        mlp_log.mlperf_print("run_stop",
                             None,
                             metadata={"status": "failure"})
    tf.logging.info("Retrieving embedding vars and writing stats.")
    runner.retrieve_embedding_vars()
Example #2
def learning_rate_schedule(current_epoch):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then it increases linearly per step.
  After 5 epochs we reach the base learning rate (scaled to account
    for batch size).
  After 30, 60 and 80 epochs the learning rate is divided by 10.
  After 90 epochs training stops and the LR is set to 0. This ensures
    that we train for exactly 90 epochs for reproducibility.

  Args:
    current_epoch: `Tensor` for current epoch.

  Returns:
    A scaled `Tensor` for current learning rate.
  """
  mlp_log.mlperf_print('lars_opt_base_learning_rate', FLAGS.base_learning_rate)
  scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)

  decay_rate = (scaled_lr * LR_SCHEDULE[0][0] *
                current_epoch / LR_SCHEDULE[0][1])
  for mult, start_epoch in LR_SCHEDULE:
    decay_rate = tf.where(current_epoch < start_epoch,
                          decay_rate, scaled_lr * mult)
  return decay_rate
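
For intuition, a minimal sketch of how this schedule behaves; the LR_SCHEDULE constant and the flag values below are assumptions for illustration, not the values used by the original module:

# Hypothetical check of the warmup/step-decay shape above; LR_SCHEDULE and the
# flag values are assumptions, not the real constants.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]  # (multiplier, epoch)
base_learning_rate, train_batch_size = 0.1, 1024
scaled_lr = base_learning_rate * (train_batch_size / 256.0)  # 0.4

def schedule(current_epoch):
  # Linear warmup over the first LR_SCHEDULE[0][1] epochs, then step decay.
  decay_rate = scaled_lr * LR_SCHEDULE[0][0] * current_epoch / LR_SCHEDULE[0][1]
  for mult, start_epoch in LR_SCHEDULE:
    decay_rate = tf.where(current_epoch < start_epoch,
                          decay_rate, scaled_lr * mult)
  return decay_rate

with tf.Session() as sess:
  for epoch in [0.0, 2.5, 5.0, 30.0, 60.0, 80.0]:
    # Prints approximately 0.0, 0.2, 0.4, 0.04, 0.004, 0.0004.
    print(epoch, sess.run(schedule(tf.constant(epoch))))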
Example #3
def poly_rate_schedule(current_epoch, poly_rate=0.0):
    """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then it increases linearly per step.  After
  FLAGS.poly_warmup_epochs, we reach the base learning rate (scaled to account
  for batch size). The learning rate is then decayed using a polynomial rate
  decay schedule with power 2.0.

  Args:
    current_epoch: `Tensor` for current epoch.
    poly_rate: Polynomial decay rate.

  Returns:
    A scaled `Tensor` for current learning rate.
  """

    batch_size = FLAGS.train_batch_size
    if batch_size < 16384:
        plr = 10.0
        w_epochs = 5
    elif batch_size < 32768:
        plr = 25.0
        w_epochs = 5
    else:
        plr = 31.2
        w_epochs = 25

    # Override default poly learning rate and warmup epochs
    if poly_rate > 0.0:
        plr = poly_rate

    if FLAGS.lars_base_learning_rate > 0.0:
        plr = FLAGS.lars_base_learning_rate

    if FLAGS.lars_warmup_epochs > 0:
        w_epochs = FLAGS.lars_warmup_epochs

    mlp_log.mlperf_print('lars_opt_base_learning_rate', plr)
    mlp_log.mlperf_print('lars_opt_learning_rate_warmup_epochs', w_epochs)
    end_lr = 0.0001
    mlp_log.mlperf_print('lars_opt_end_learning_rate', end_lr)

    wrate = (plr * current_epoch / w_epochs)
    w_steps = (w_epochs * FLAGS.num_train_images // batch_size)
    min_step = tf.constant(1, dtype=tf.int64)
    global_step = tf.train.get_or_create_global_step()
    decay_steps = tf.maximum(min_step, tf.subtract(global_step, w_steps))

    mlp_log.mlperf_print('lars_opt_learning_rate_decay_steps',
                         FLAGS.train_steps - w_steps + 1)
    mlp_log.mlperf_print('lars_opt_learning_rate_decay_poly_power', 2.0)

    poly_rate = tf.train.polynomial_decay(plr,
                                          decay_steps,
                                          FLAGS.train_steps - w_steps + 1,
                                          end_lr,
                                          power=2.0)
    decay_rate = tf.where(current_epoch <= w_epochs, wrate, poly_rate)
    return decay_rate
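
For intuition, a pure-Python sketch of the same piecewise shape (linear warmup, then power-2.0 polynomial decay to end_lr), mirroring the documented polynomial_decay formula; the step counts and rates below are toy assumptions:

# Hypothetical illustration of the warmup + power-2.0 polynomial decay above.
plr, w_epochs, end_lr = 10.0, 5, 0.0001      # assumed: batch_size < 16384 branch
train_steps, steps_per_epoch = 1000, 10      # toy values, not the real flags
w_steps = w_epochs * steps_per_epoch
decay_total = train_steps - w_steps + 1

def lr_at(step):
  epoch = step / steps_per_epoch
  if epoch <= w_epochs:                       # linear warmup
    return plr * epoch / w_epochs
  decay_step = min(max(step - w_steps, 1), decay_total)
  frac = 1.0 - decay_step / decay_total       # polynomial_decay, power=2.0
  return (plr - end_lr) * frac ** 2 + end_lr

for step in [0, 25, 50, 500, 1000]:
  print(step, round(lr_at(step), 4))          # 0.0, 5.0, 10.0, ~2.78, ~0.0001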
Example #4
def eval_init_fn(cur_step):
    """Executed before every eval."""
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
    mlp_log.mlperf_print('block_start',
                         None,
                         metadata={
                             'first_epoch_num': cur_epoch,
                             'epoch_count': 1
                         })
Example #5
def eval_init_fn(cur_step):
  """Executed before every eval."""
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  epoch = cur_step // steps_per_epoch
  mlp_log.mlperf_print(
      'block_start',
      None,
      metadata={
          'first_epoch_num': epoch,
          'epoch_count': 4
      })
Example #6
def eval_init_fn(cur_step):
    """Executed before every eval."""
    # While BERT pretraining does not have epochs,
    # to make the logging consistent with other mlperf models,
    # in all the mlp_log, epochs are steps, and examples are sequences.
    mlp_log.mlperf_print("block_start",
                         None,
                         metadata={
                             "first_epoch_num":
                             cur_step + FLAGS.iterations_per_loop,
                             "epoch_count": FLAGS.iterations_per_loop
                         })
Example #7
    def _default_eval_finish_fn(cur_step, eval_output, summary_writer=None):
        eval_num = cur_step // FLAGS.steps_between_evals
        mlp_log.mlperf_print("eval_stop",
                             None,
                             metadata={"epoch_num": eval_num + 1})
        mlp_log.mlperf_print("block_stop",
                             None,
                             metadata={"first_epoch_num": eval_num + 1})
        tf.logging.info(
            "== Eval finished (step {}). Computing metric..".format(cur_step))

        results_np = np.array(eval_output["results"])
        results_np = np.reshape(results_np, (-1, 2))
        predictions_np = results_np[:, 0].astype(np.float32)
        targets_np = results_np[:, 1].astype(np.int32)
        roc_obj = roc_metrics.RocMetrics(predictions_np, targets_np)
        roc_auc = roc_obj.ComputeRocAuc()
        tf.logging.info("== Eval shape: {}.  AUC = {:.4f}".format(
            predictions_np.shape, roc_auc))
        success = roc_auc >= _ACCURACY_THRESH
        mlp_log.mlperf_print("eval_accuracy",
                             roc_auc,
                             metadata={"epoch_num": eval_num + 1})
        if success:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "success"})
        if summary_writer:
            summary_writer.add_summary(
                utils.create_scalar_summary("auc", roc_auc),
                global_step=cur_step + FLAGS.steps_between_evals)
        eval_metrics.append((cur_step + FLAGS.steps_between_evals, roc_auc))
        return success
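
If roc_metrics is unavailable, the AUC above can be cross-checked with a small rank-based NumPy computation; this is a sketch of the standard Mann-Whitney identity, not the RocMetrics implementation itself, and it ignores ties:

# Hypothetical NumPy cross-check for the ROC AUC computed by RocMetrics above.
import numpy as np

def roc_auc_numpy(predictions, targets):
  # AUC is the probability that a random positive outranks a random negative,
  # computed via the rank-sum (Mann-Whitney U) identity; ties are ignored.
  order = np.argsort(predictions)
  ranks = np.empty(len(predictions), dtype=np.float64)
  ranks[order] = np.arange(1, len(predictions) + 1)
  pos = targets == 1
  n_pos, n_neg = pos.sum(), (~pos).sum()
  return (ranks[pos].sum() - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)

preds = np.array([0.9, 0.2, 0.7, 0.4], dtype=np.float32)   # toy data
labels = np.array([1, 0, 1, 0], dtype=np.int32)
print(roc_auc_numpy(preds, labels))  # 1.0: positives are perfectly separated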
Example #8
def _default_eval_init_fn(cur_step):
    """Logging statements executed before every eval."""
    eval_num = cur_step // FLAGS.steps_between_evals
    tf.logging.info("== Block {}. Step {} of {}".format(
        eval_num + 1, cur_step, FLAGS.train_steps))
    mlp_log.mlperf_print("block_start",
                         None,
                         metadata={
                             "first_epoch_num": eval_num + 1,
                             "epoch_count": 1
                         })
    mlp_log.mlperf_print("eval_start",
                         None,
                         metadata={"epoch_num": eval_num + 1})
Example #9
def init_lars_optimizer(current_epoch):
    """Initialize the LARS Optimizer."""

    lars_epsilon = FLAGS.lars_epsilon
    mlp_log.mlperf_print('lars_epsilon', lars_epsilon)

    learning_rate = poly_rate_schedule(current_epoch, FLAGS.poly_rate)
    optimizer = contrib_opt.LARSOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        skip_list=['batch_normalization', 'bias'],
        epsilon=lars_epsilon)
    return optimizer
Example #10
def learning_rate_schedule(peak_learning_rate, lr_warmup_init, lr_warmup_step,
                           first_lr_drop_step, second_lr_drop_step,
                           global_step):
    """Handles linear scaling rule, gradual warmup, and LR decay."""
    # lr_warmup_init is the starting learning rate; the learning rate is linearly
    # scaled up to the full learning rate after `lr_warmup_step` before decaying.
    mlp_log.mlperf_print(key='opt_learning_rate_decay_factor', value=0.1)
    mlp_log.mlperf_print('opt_learning_rate_decay_steps',
                         (first_lr_drop_step, second_lr_drop_step))
    linear_warmup = (lr_warmup_init +
                     (tf.cast(global_step, dtype=tf.float32) / lr_warmup_step *
                      (peak_learning_rate - lr_warmup_init)))
    learning_rate = tf.where(global_step < lr_warmup_step, linear_warmup,
                             peak_learning_rate)
    lr_schedule = [[1.0, lr_warmup_step], [0.1, first_lr_drop_step],
                   [0.01, second_lr_drop_step]]
    for mult, start_global_step in lr_schedule:
        learning_rate = tf.where(global_step < start_global_step,
                                 learning_rate, peak_learning_rate * mult)
    return learning_rate
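
For intuition, a pure-Python sketch of the resulting piecewise schedule (linear warmup, then two step drops); the boundaries and rates below are assumptions for illustration:

# Hypothetical illustration of the warmup + two-drop schedule above.
peak_lr, warmup_init = 0.4, 0.0013                 # assumed values
warmup_step, first_drop, second_drop = 500, 7000, 9000

def lr_at(step):
  if step < warmup_step:                           # linear warmup
    return warmup_init + step / warmup_step * (peak_lr - warmup_init)
  if step < first_drop:
    return peak_lr
  if step < second_drop:
    return peak_lr * 0.1                           # first 10x drop
  return peak_lr * 0.01                            # second 10x drop

for step in [0, 250, 500, 7000, 9000]:
  print(step, round(lr_at(step), 4))               # 0.0013, 0.2007, 0.4, 0.04, 0.004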
Example #11
  def eval_finish_fn(cur_step, eval_output, summary_writer):
    """Executed after every eval."""
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    epoch = cur_step // steps_per_epoch
    eval_accuracy = float(np.sum(
        eval_output['total_correct'])) / FLAGS.num_eval_images

    if summary_writer:
      with tf.Graph().as_default():
        summary_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag='accuracy', simple_value=eval_accuracy)
            ]), cur_step)
    mlp_log.mlperf_print(
        'eval_accuracy',
        eval_accuracy,
        metadata={
            'epoch_num': epoch + FLAGS.iterations_per_loop // steps_per_epoch
        })
    mlp_log.mlperf_print(
        'block_stop',
        None,
        metadata={
            'first_epoch_num': epoch,
            'epoch_count': 4
        })
    if eval_accuracy >= FLAGS.stop_threshold:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
      return True
    else:
      return False
Example #12
    def eval_finish_fn(cur_step, eval_output, _):
        """Callback function that's executed after each eval."""
        if eval_steps == 0:
            return False
        # Concat eval_output as eval_output is a list from each host.
        for key in eval_output:
            eval_output[key] = np.concatenate(eval_output[key], axis=0)
        steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
        cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
        mlp_log.mlperf_print('block_stop',
                             None,
                             metadata={
                                 'first_epoch_num': cur_epoch,
                                 'epoch_count': 1
                             })
        eval_multiprocess.eval_multiprocessing(
            eval_output, eval_metric, mask_rcnn_params.EVAL_WORKER_COUNT)

        mlp_log.mlperf_print('eval_start',
                             None,
                             metadata={'epoch_num': cur_epoch + 1})
        _, eval_results = eval_metric.evaluate()
        mlp_log.mlperf_print('eval_accuracy', {
            'BBOX': float(eval_results['AP']),
            'SEGM': float(eval_results['mask_AP'])
        },
                             metadata={'epoch_num': cur_epoch + 1})
        mlp_log.mlperf_print('eval_stop',
                             None,
                             metadata={'epoch_num': cur_epoch + 1})
        if (eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
                eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET):
            mlp_log.mlperf_print('run_stop',
                                 None,
                                 metadata={'status': 'success'})
            return True
        return False
Example #13
 def get_learning_rate(self, global_step):
     """Sets up learning rate schedule."""
     learning_rate = lr_policy.learning_rate_schedule(
         self.params['learning_rate'], self.params['lr_warmup_init'],
         self.params['lr_warmup_step'], self.params['first_lr_drop_step'],
         self.params['second_lr_drop_step'], global_step)
     mlp_log.mlperf_print(key='opt_base_learning_rate',
                          value=self.params['learning_rate'])
     mlp_log.mlperf_print(key='opt_learning_rate_warmup_steps',
                          value=self.params['lr_warmup_step'])
     mlp_log.mlperf_print(key='opt_learning_rate_warmup_factor',
                          value=self.params['learning_rate'] /
                          self.params['lr_warmup_step'])
     return learning_rate
Example #14
        def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops,
                             eval_init):
            """Start the infeed."""
            time.sleep(150)

            mlp_log.mlperf_print("init_stop", None)
            mlp_log.mlperf_print("run_start", None)
            mlp_log.mlperf_print("block_start",
                                 None,
                                 metadata={
                                     "first_epoch_num": 1,
                                     "epoch_count": 1
                                 })

            for i in range(self.hparams.max_train_epochs):
                tf.logging.info("Infeed for epoch: %d", i + 1)
                sess.run(eval_init)
                sess.run([train_enqueue_ops])
                sess.run([eval_enqueue_ops])
Example #15
    def eval_finish_fn(cur_step, eval_output, summary_writer):
        """Executed after every eval."""
        global run_steps
        global masked_lm_accuracy
        cur_step_corrected = cur_step + FLAGS.iterations_per_loop
        run_steps = cur_step_corrected
        masked_lm_weighted_correct = eval_output["masked_lm_weighted_correct"]
        masked_lm_weighted_count = eval_output["masked_lm_weighted_count"]

        masked_lm_accuracy = np.sum(masked_lm_weighted_correct) / np.sum(
            masked_lm_weighted_count)
        # The eval_output may mix up the order of the two arrays;
        # swap them back if that happened.
        if masked_lm_accuracy > 1:
            masked_lm_accuracy = 1 / masked_lm_accuracy

        if summary_writer:
            with tf.Graph().as_default():
                summary_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag="masked_lm_accuracy",
                                         simple_value=masked_lm_accuracy)
                    ]), cur_step_corrected)

        mlp_log.mlperf_print("block_stop",
                             None,
                             metadata={
                                 "first_epoch_num": cur_step_corrected,
                             })
        # While BERT pretraining does not have epochs,
        # to make the logging consistent with other mlperf models,
        # in all the mlp_log, epochs are steps, and examples are sequences.
        mlp_log.mlperf_print("eval_accuracy",
                             float(masked_lm_accuracy),
                             metadata={"epoch_num": cur_step_corrected})
        if (masked_lm_accuracy >= FLAGS.stop_threshold
                and cur_step_corrected >= FLAGS.iterations_per_loop * 6):
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "success"})
            return True
        else:
            return False
Example #16
def main(unused_argv):

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    epoch = cur_step // steps_per_epoch
    mlp_log.mlperf_print(
        'block_start',
        None,
        metadata={
            'first_epoch_num': epoch,
            'epoch_count': 4
        })

  def eval_finish_fn(cur_step, eval_output, summary_writer):
    """Executed after every eval."""
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    epoch = cur_step // steps_per_epoch
    eval_accuracy = float(np.sum(
        eval_output['total_correct'])) / FLAGS.num_eval_images

    if summary_writer:
      with tf.Graph().as_default():
        summary_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag='accuracy', simple_value=eval_accuracy)
            ]), cur_step)
    mlp_log.mlperf_print(
        'eval_accuracy',
        eval_accuracy,
        metadata={
            'epoch_num': epoch + FLAGS.iterations_per_loop // steps_per_epoch
        })
    mlp_log.mlperf_print(
        'block_stop',
        None,
        metadata={
            'first_epoch_num': epoch,
            'epoch_count': 4
        })
    if eval_accuracy >= FLAGS.stop_threshold:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
      return True
    else:
      return False

  def run_finish_fn(success):
    if not success:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
    mlp_log.mlperf_print('run_final', None)

  low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.iterations_per_loop, FLAGS.train_steps,
      int(math.ceil(FLAGS.num_eval_images / FLAGS.eval_batch_size)),
      FLAGS.num_replicas)

  mlp_log.mlperf_print('cache_clear', True)
  mlp_log.mlperf_print('init_start', None)
  mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
  mlp_log.mlperf_print('lars_opt_weight_decay', FLAGS.weight_decay)
  mlp_log.mlperf_print('lars_opt_momentum', FLAGS.momentum)
  mlp_log.mlperf_print('submission_benchmark', 'resnet')
  mlp_log.mlperf_print('submission_division', 'closed')
  mlp_log.mlperf_print('submission_org', 'google')
  mlp_log.mlperf_print('submission_platform', 'tpu-v3-%d' % FLAGS.num_replicas)
  mlp_log.mlperf_print('submission_status', 'research')

  assert FLAGS.precision == 'bfloat16' or FLAGS.precision == 'float32', (
      'Invalid value for --precision flag; must be bfloat16 or float32.')
  input_dtype = tf.bfloat16 if FLAGS.precision == 'bfloat16' else tf.float32
  cache_decoded_image = True if FLAGS.num_replicas > 2048 else False
  imagenet_train, imagenet_eval = [
      imagenet_input.get_input_fn(  # pylint: disable=g-complex-comprehension
          FLAGS.data_dir,
          is_training,
          input_dtype,
          FLAGS.image_size,
          FLAGS.input_partition_dims is None,
          cache_decoded_image=cache_decoded_image)
      for is_training in [True, False]
  ]

  low_level_runner.initialize(imagenet_train, imagenet_eval, resnet_model_fn,
                              FLAGS.train_batch_size, FLAGS.eval_batch_size,
                              FLAGS.input_partition_dims)

  mlp_log.mlperf_print('train_samples', FLAGS.num_train_images)
  mlp_log.mlperf_print('eval_samples', FLAGS.num_eval_images)
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)
  low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
Example #17
    def train_and_predict(self):
        """Run the predict loop on the TPU device."""
        self.sess.run([self.compile_op])

        # Train and eval thread.
        def train_eval_thread_fn(sess, train_eval_op):
            tf.logging.info("train_eval_op start")
            sess.run([train_eval_op])

        train_eval_thread = threading.Thread(target=train_eval_thread_fn,
                                             args=(self.sess,
                                                   self.train_eval_op))
        train_eval_thread.start()

        # Infeed thread.
        def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops,
                             eval_init):
            """Start the infeed."""
            time.sleep(150)

            mlp_log.mlperf_print("init_stop", None)
            mlp_log.mlperf_print("run_start", None)
            mlp_log.mlperf_print("block_start",
                                 None,
                                 metadata={
                                     "first_epoch_num": 1,
                                     "epoch_count": 1
                                 })

            for i in range(self.hparams.max_train_epochs):
                tf.logging.info("Infeed for epoch: %d", i + 1)
                sess.run(eval_init)
                sess.run([train_enqueue_ops])
                sess.run([eval_enqueue_ops])

        infeed_thread = threading.Thread(target=infeed_thread_fn,
                                         args=(self.sess, self.enqueue_ops,
                                               self.eval_enqueue_ops,
                                               self.eval_dataset_initializer))
        infeed_thread.start()

        if self.eval_steps > 0:
            eval_state = {"run_success": False, "score": 0.0}

            for epoch in range(self.hparams.max_train_epochs):
                predictions = list(self.predict())
                mlp_log.mlperf_print("eval_start",
                                     None,
                                     metadata={"epoch_num": epoch + 1})
                current_step = epoch * self.iterations

                eval_state["score"] = metric.get_metric(
                    self.hparams, predictions, current_step)
                tf.logging.info("Score after epoch %d: %f", epoch,
                                eval_state["score"])
                mlp_log.mlperf_print("eval_accuracy",
                                     eval_state["score"] / 100.0,
                                     metadata={"epoch_num": epoch + 1})
                mlp_log.mlperf_print("eval_stop",
                                     None,
                                     metadata={"epoch_num": epoch + 1})
                mlp_log.mlperf_print("block_stop",
                                     None,
                                     metadata={
                                         "first_epoch_num": epoch + 1,
                                         "epoch_count": 1
                                     })
                if eval_state["score"] >= self.hparams.target_bleu:
                    eval_state["run_success"] = True
                    mlp_log.mlperf_print("run_stop",
                                         None,
                                         metadata={"status": "success"})
                    break
                mlp_log.mlperf_print("block_start",
                                     None,
                                     metadata={
                                         "first_epoch_num": epoch + 2,
                                         "epoch_count": 1
                                     })

            if not eval_state["run_success"]:
                mlp_log.mlperf_print("run_stop",
                                     None,
                                     metadata={"status": "abort"})

        infeed_thread.join()
        train_eval_thread.join()

        if self.eval_steps > 0:
            return eval_state["score"], current_step
        else:
            return None, None
Example #18
def main(argv):
    del argv  # Unused.

    # TODO(b/132208296): remove this workaround that uses control flow v2.
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

    # Parse hparams
    hparams = mask_rcnn_params.default_hparams()
    hparams.parse(FLAGS.hparams)

    params = dict(
        hparams.values(),
        transpose_input=False
        if FLAGS.input_partition_dims is not None else True,
        resnet_checkpoint=FLAGS.resnet_checkpoint,
        val_json_file=FLAGS.val_json_file,
        num_cores_per_replica=int(np.prod(FLAGS.input_partition_dims))
        if FLAGS.input_partition_dims else 1,
        replicas_per_host=FLAGS.replicas_per_host)

    # MLPerf logging.
    mlp_log.mlperf_print(key='cache_clear', value=True)
    mlp_log.mlperf_print(key='init_start', value=None)
    mlp_log.mlperf_print(key='global_batch_size', value=FLAGS.train_batch_size)
    mlp_log.mlperf_print(key='train_samples',
                         value=FLAGS.num_examples_per_epoch)
    mlp_log.mlperf_print(key='eval_samples', value=FLAGS.eval_samples)
    mlp_log.mlperf_print(key='min_image_size',
                         value=params['short_side_image_size'])
    mlp_log.mlperf_print(key='max_image_size',
                         value=params['long_side_max_image_size'])
    mlp_log.mlperf_print(key='num_image_candidates',
                         value=params['rpn_post_nms_topn'])

    train_steps = (FLAGS.num_epochs * FLAGS.num_examples_per_epoch //
                   FLAGS.train_batch_size)
    eval_steps = int(
        math.ceil(float(FLAGS.eval_samples) / FLAGS.eval_batch_size))
    if eval_steps > 0:
        # The eval dataset is not evenly divided. Adding one extra step makes
        # sure all eval samples are covered.
        # TODO(b/151732586): regenerate the eval dataset to make all hosts get the
        #                    same amount of work.
        eval_steps += 1
    runner = train_and_eval_runner.TrainAndEvalRunner(
        FLAGS.num_examples_per_epoch // FLAGS.train_batch_size, train_steps,
        eval_steps, FLAGS.num_shards)
    train_input_fn = dataloader.InputReader(FLAGS.training_file_pattern,
                                            mode=tf.estimator.ModeKeys.TRAIN,
                                            use_fake_data=FLAGS.use_fake_data)
    eval_input_fn = functools.partial(
        dataloader.InputReader(FLAGS.validation_file_pattern,
                               mode=tf.estimator.ModeKeys.PREDICT,
                               distributed_eval=True),
        num_examples=eval_steps * FLAGS.eval_batch_size)
    eval_metric = coco_metric.EvaluationMetric(FLAGS.val_json_file,
                                               use_cpp_extension=True)

    def init_fn():
        if FLAGS.resnet_checkpoint:
            tf.train.init_from_checkpoint(FLAGS.resnet_checkpoint,
                                          {'resnet/': 'resnet50/'})

    runner.initialize(train_input_fn,
                      eval_input_fn,
                      mask_rcnn_model.MaskRcnnModelFn(params),
                      FLAGS.train_batch_size,
                      FLAGS.eval_batch_size,
                      FLAGS.input_partition_dims,
                      init_fn,
                      params=params)
    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)

    def eval_init_fn(cur_step):
        """Executed before every eval."""
        steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
        cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
        mlp_log.mlperf_print('block_start',
                             None,
                             metadata={
                                 'first_epoch_num': cur_epoch,
                                 'epoch_count': 1
                             })

    def eval_finish_fn(cur_step, eval_output, _):
        """Callback function that's executed after each eval."""
        if eval_steps == 0:
            return False
        # Concat eval_output as eval_output is a list from each host.
        for key in eval_output:
            eval_output[key] = np.concatenate(eval_output[key], axis=0)
        steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
        cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
        mlp_log.mlperf_print('block_stop',
                             None,
                             metadata={
                                 'first_epoch_num': cur_epoch,
                                 'epoch_count': 1
                             })
        eval_multiprocess.eval_multiprocessing(
            eval_output, eval_metric, mask_rcnn_params.EVAL_WORKER_COUNT)

        mlp_log.mlperf_print('eval_start',
                             None,
                             metadata={'epoch_num': cur_epoch + 1})
        _, eval_results = eval_metric.evaluate()
        mlp_log.mlperf_print('eval_accuracy', {
            'BBOX': float(eval_results['AP']),
            'SEGM': float(eval_results['mask_AP'])
        },
                             metadata={'epoch_num': cur_epoch + 1})
        mlp_log.mlperf_print('eval_stop',
                             None,
                             metadata={'epoch_num': cur_epoch + 1})
        if (eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
                eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET):
            mlp_log.mlperf_print('run_stop',
                                 None,
                                 metadata={'status': 'success'})
            return True
        return False

    def run_finish_fn(success):
        if not success:
            mlp_log.mlperf_print('run_stop',
                                 None,
                                 metadata={'status': 'abort'})

    runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
Example #19
def run_main(flags, default_hparams, estimator_fn):
  """Run main."""
  # Job
  jobid = flags.jobid
  utils.print_out("# Job id %d" % jobid)

  # Random
  random_seed = flags.random_seed
  if random_seed is not None and random_seed > 0:
    utils.print_out("# Set random seed to %d" % random_seed)
    random.seed(random_seed + jobid)
    np.random.seed(random_seed + jobid)
    tf.set_random_seed(random_seed)

  mlp_log.mlperf_print("cache_clear", True)
  mlp_log.mlperf_print("init_start", None)
  mlp_log.mlperf_print("submission_benchmark", "resnet")
  mlp_log.mlperf_print("submission_division", "closed")
  mlp_log.mlperf_print("submission_org", "google")
  mlp_log.mlperf_print("submission_platform", "tpu-v3-%d" % FLAGS.num_shards)
  mlp_log.mlperf_print("submission_status", "research")

  mlp_log.mlperf_print("global_batch_size", FLAGS.batch_size)
  mlp_log.mlperf_print("opt_learning_rate_alt_decay_func", "True")
  mlp_log.mlperf_print("opt_base_learning_rate", FLAGS.learning_rate)
  mlp_log.mlperf_print("opt_learning_rate_decay_interval", FLAGS.decay_interval)
  mlp_log.mlperf_print("opt_learning_rate_decay_factor", FLAGS.decay_factor)
  mlp_log.mlperf_print("opt_learning_rate_decay_steps", FLAGS.decay_steps)
  mlp_log.mlperf_print("opt_learning_rate_remain_steps", FLAGS.decay_start)
  mlp_log.mlperf_print("opt_learning_rate_alt_warmup_func", FLAGS.warmup_scheme)
  mlp_log.mlperf_print("opt_learning_rate_warmup_steps", FLAGS.warmup_steps)
  mlp_log.mlperf_print(
      "max_sequence_length", FLAGS.src_max_len, metadata={"method": "discard"})
  mlp_log.mlperf_print("train_samples", FLAGS.num_examples_per_epoch)
  mlp_log.mlperf_print("eval_samples", FLAGS.examples_to_infer)

  # Model output directory
  out_dir = flags.out_dir
  if out_dir and not tf.gfile.Exists(out_dir):
    utils.print_out("# Creating output directory %s ..." % out_dir)
    tf.gfile.MakeDirs(out_dir)

  # Load hparams.
  hparams = create_or_load_hparams(default_hparams, flags.hparams_path)

  # Train or Evaluation
  return estimator_fn(hparams)
Example #20
def main(unused_argv):
    """Run the reinforcement learning loop."""
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('[%(asctime)s] %(message)s',
                                  '%Y-%m-%d %H:%M:%S')

    # ML Perf Logging.

    mlp_log.mlperf_print('cache_clear', True)
    mlp_log.mlperf_print('init_start', None)

    mlp_log.mlperf_print(key='train_batch_size',
                         value=FLAGS.training_batch_size)
    mlp_log.mlperf_print(key='filter_amount', value=FLAGS.filter_amount)
    mlp_log.mlperf_print(key='window_size', value=FLAGS.window_size)
    mlp_log.mlperf_print(key='lr_boundaries',
                         value=str(FLAGS.lr_boundaries).strip('[]'))
    mlp_log.mlperf_print(key='lr_rates', value=str(FLAGS.lr_rates).strip('[]'))

    mlp_log.mlperf_print(key='opt_weight_decay', value=FLAGS.l2_strength)
    mlp_log.mlperf_print(key='min_selfplay_games_per_generation',
                         value=FLAGS.mlperf_num_games)
    mlp_log.mlperf_print(key='train_samples', value=FLAGS.mlperf_num_games)
    mlp_log.mlperf_print(key='eval_samples', value=FLAGS.mlperf_num_games)
    mlp_log.mlperf_print(key='num_readouts', value=FLAGS.mlperf_num_readouts)
    mlp_log.mlperf_print(key='value_init_penalty',
                         value=FLAGS.mlperf_value_init_penalty)
    mlp_log.mlperf_print(key='holdout_pct', value=FLAGS.mlperf_holdout_pct)
    mlp_log.mlperf_print(key='disable_resign_pct',
                         value=FLAGS.mlperf_disable_resign_pct)
    mlp_log.mlperf_print(key='resign_threshold',
                         value=(sum(FLAGS.mlperf_resign_threshold) /
                                len(FLAGS.mlperf_resign_threshold)))
    mlp_log.mlperf_print(key='parallel_games',
                         value=FLAGS.mlperf_parallel_games)
    mlp_log.mlperf_print(key='virtual_losses',
                         value=FLAGS.mlperf_virtual_losses)
    mlp_log.mlperf_print(key='gating_win_rate',
                         value=FLAGS.mlperf_gating_win_rate)
    mlp_log.mlperf_print(key='eval_games', value=FLAGS.mlperf_eval_games)

    for handler in logger.handlers:
        handler.setFormatter(formatter)

    # The training loop must be bootstrapped, either by running bootstrap.sh
    # to generate training data from random games, or by running
    # copy_checkpoint.sh to copy an already generated checkpoint.
    model_dirs = list_selfplay_dirs(FLAGS.selfplay_dir)

    iteration_model_names = []
    if not model_dirs:
        raise RuntimeError(
            'Couldn\'t find any selfplay games under %s. Either bootstrap.sh '
            'or init_from_checkpoint.sh must be run before the train loop is '
            'started' % FLAGS.selfplay_dir)
    model_num = int(os.path.basename(model_dirs[0]))
    tpu_name = FLAGS.tpu_name.split(':')[0]
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=True)
    timeout_run_options = tf.RunOptions(
        timeout_in_ms=FLAGS.worker_reset_timeout_ms)

    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)
    with minigo_utils.logged_timer('Total time'):
        state = State(model_num)
        while state.iter_num < FLAGS.iterations:
            state.iter_num += 1
            iteration_model_names.append(state.train_model_name)
            mlp_log.mlperf_print(key='epoch_start',
                                 value=None,
                                 metadata={'epoch_num': state.iter_num})
            train_once(state)
            mlp_log.mlperf_print(key='epoch_stop',
                                 value=None,
                                 metadata={'epoch_num': state.iter_num})
            mlp_log.mlperf_print(key='save_model',
                                 value='{iteration_num: ' +
                                 str(state.iter_num) + ' }')

            # In the case where iterations are fast, TPUEstimator can deadlock
            # between iterations on TPU Init. We attempt to manually make sure
            # the worker can Init with deadlines so we don't get stuck.
            while True:
                try:
                    tf.logging.info('Attempting to shutdown worker.')
                    gc.collect()
                    with tf.Graph().as_default():
                        with tf.Session(tpu_name,
                                        config=session_config) as sess:
                            sess.run(tf.tpu.shutdown_system(job='tpu_worker'),
                                     options=timeout_run_options)
                    tf.logging.info('Attempting to initialize worker.')
                    with tf.Graph().as_default():
                        with tf.Session(tpu_name,
                                        config=session_config) as sess:
                            init_result = sess.run(
                                tf.tpu.initialize_system(job='tpu_worker'),
                                options=timeout_run_options)
                    if init_result:
                        tf.logging.info('Worker reset.')
                        break
                except tf.errors.DeadlineExceededError:
                    pass
    with tf.gfile.GFile(FLAGS.abort_file_path, 'w') as f:
        f.write('abort')

    total_file_count = 0
    for iteration_model_name in iteration_model_names:
        total_file_count = total_file_count + len(
            tf.io.gfile.glob(FLAGS.selfplay_dir + '/' + iteration_model_name +
                             '/*/*/*'))

    mlp_log.mlperf_print(key='actual_selfplay_games_per_generation',
                         value=int(total_file_count /
                                   len(iteration_model_names)))
Example #21
def run_finish_fn(success):
    if not success:
        mlp_log.mlperf_print("run_stop",
                             None,
                             metadata={"status": "abort"})
    mlp_log.mlperf_print("run_final", None)
Example #22
def run_model(params,
              eval_init_fn=None,
              eval_finish_fn=None,
              run_finish_fn=None):
    """Run the DLRM model, using a pre-defined configuration.

  Args:
    params: HPTuner object that provides new params for the trial.
    eval_init_fn: Lambda to run at start of eval. None means use the default.
    eval_finish_fn: Lambda for end of eval. None means use the default.
    run_finish_fn: Lambda for end of execution. None means use the default.

  Returns:
    A list of tuples, each entry describing the eval metric for one eval. Each
    tuple entry is (global_step, metric_value).
  """
    mlp_log.mlperf_print(key="cache_clear", value=True)
    mlp_log.mlperf_print(key="init_start", value=None)
    mlp_log.mlperf_print("global_batch_size", params["batch_size"])
    mlp_log.mlperf_print("train_samples", _NUM_TRAIN_EXAMPLES)
    mlp_log.mlperf_print("eval_samples", _NUM_EVAL_EXAMPLES)
    mlp_log.mlperf_print("opt_base_learning_rate", params["learning_rate"])
    mlp_log.mlperf_print("sgd_opt_base_learning_rate", params["learning_rate"])
    mlp_log.mlperf_print("sgd_opt_learning_rate_decay_poly_power", 2)
    mlp_log.mlperf_print("sgd_opt_learning_rate_decay_steps",
                         params["decay_steps"])
    mlp_log.mlperf_print("lr_decay_start_steps", params["decay_start_step"])
    mlp_log.mlperf_print("opt_learning_rate_warmup_steps",
                         params["lr_warmup_steps"])

    # Used for vizier. List of tuples. Each entry is (global_step, auc_metric).
    eval_metrics = [(0, 0.0)]

    feature_config = fc.FeatureConfig(params)
    (feature_to_config_dict,
     table_to_config_dict) = feature_config.get_feature_tbl_config()
    opt_params = {
        "sgd":
        tpu_embedding.StochasticGradientDescentParameters(
            learning_rate=params["learning_rate"]),
        "adagrad":
        tpu_embedding.AdagradParameters(
            learning_rate=params["learning_rate"],
            initial_accumulator=params["adagrad_init_accum"])
    }
    embedding = tpu_embedding.TPUEmbedding(
        table_to_config_dict,
        feature_to_config_dict,
        params["batch_size"],
        mode=tpu_embedding.TRAINING,
        optimization_parameters=opt_params[params["optimizer"]],
        partition_strategy="mod",
        pipeline_execution_with_tensor_core=FLAGS.pipeline_execution,
        master=FLAGS.master)

    runner = dlrm_embedding_runner.DLRMEmbeddingRunner(
        iterations_per_loop=FLAGS.steps_between_evals,
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        num_replicas=FLAGS.num_tpu_shards,
        sparse_features_key="cat-features",
        embedding=embedding)

    train_input_fn, eval_input_fn = get_input_fns(params, feature_config)

    runner.initialize(train_input_fn,
                      eval_input_fn,
                      functools.partial(dlrm.dlrm_llr_model_fn, params,
                                        feature_config),
                      params["batch_size"],
                      params["eval_batch_size"],
                      train_has_labels=False,
                      eval_has_labels=False)

    mlp_log.mlperf_print("init_stop", None)
    mlp_log.mlperf_print("run_start", None)

    def _default_eval_init_fn(cur_step):
        """Logging statements executed before every eval."""
        eval_num = cur_step // FLAGS.steps_between_evals
        tf.logging.info("== Block {}. Step {} of {}".format(
            eval_num + 1, cur_step, FLAGS.train_steps))
        mlp_log.mlperf_print("block_start",
                             None,
                             metadata={
                                 "first_epoch_num": eval_num + 1,
                                 "epoch_count": 1
                             })
        mlp_log.mlperf_print("eval_start",
                             None,
                             metadata={"epoch_num": eval_num + 1})

    def _default_eval_finish_fn(cur_step, eval_output, summary_writer=None):
        eval_num = cur_step // FLAGS.steps_between_evals
        mlp_log.mlperf_print("eval_stop",
                             None,
                             metadata={"epoch_num": eval_num + 1})
        mlp_log.mlperf_print("block_stop",
                             None,
                             metadata={"first_epoch_num": eval_num + 1})
        tf.logging.info(
            "== Eval finished (step {}). Computing metric..".format(cur_step))

        results_np = np.array(eval_output["results"])
        results_np = np.reshape(results_np, (-1, 2))
        predictions_np = results_np[:, 0].astype(np.float32)
        targets_np = results_np[:, 1].astype(np.int32)
        roc_obj = roc_metrics.RocMetrics(predictions_np, targets_np)
        roc_auc = roc_obj.ComputeRocAuc()
        tf.logging.info("== Eval shape: {}.  AUC = {:.4f}".format(
            predictions_np.shape, roc_auc))
        success = roc_auc >= _ACCURACY_THRESH
        mlp_log.mlperf_print("eval_accuracy",
                             roc_auc,
                             metadata={"epoch_num": eval_num + 1})
        if success:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "success"})
        if summary_writer:
            summary_writer.add_summary(
                utils.create_scalar_summary("auc", roc_auc),
                global_step=cur_step + FLAGS.steps_between_evals)
        eval_metrics.append((cur_step + FLAGS.steps_between_evals, roc_auc))
        return success

    def _default_run_finish_fn(success_status):
        if not success_status:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "failure"})
        tf.logging.info("Retrieving embedding vars and writing stats.")
        runner.retrieve_embedding_vars()

    runner.train_and_eval(eval_init_fn=eval_init_fn or _default_eval_init_fn,
                          eval_finish_fn=eval_finish_fn
                          or _default_eval_finish_fn,
                          run_finish_fn=run_finish_fn
                          or _default_run_finish_fn)

    return eval_metrics
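
A sketch of how run_model might be driven directly; a plain dict stands in for the HPTuner trial object, the key set mirrors the look-ups above but is probably not exhaustive (fc.FeatureConfig reads further keys), and the values are illustrative only:

# Hypothetical driver for run_model; values are illustrative, not tuned.
params = {
    "batch_size": 65536,
    "eval_batch_size": 65536,
    "learning_rate": 0.02,
    "decay_steps": 30000,
    "decay_start_step": 70000,
    "lr_warmup_steps": 2500,
    "optimizer": "sgd",
    "adagrad_init_accum": 0.01,
}
eval_metrics = run_model(params)            # default callbacks defined above
print("final AUC:", eval_metrics[-1][1])    # entries are (global_step, auc)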
Example #23
def run_pretraining(hparams):
    """Run pretraining with given hyperparameters."""

    global masked_lm_accuracy
    global run_steps

    masked_lm_accuracy = 0
    run_steps = 0

    def eval_init_fn(cur_step):
        """Executed before every eval."""
        # While BERT pretraining does not have epochs,
        # to make the logging consistent with other mlperf models,
        # in all the mlp_log, epochs are steps, and examples are sequences.
        mlp_log.mlperf_print("block_start",
                             None,
                             metadata={
                                 "first_epoch_num":
                                 cur_step + FLAGS.iterations_per_loop,
                                 "epoch_count": FLAGS.iterations_per_loop
                             })

    def eval_finish_fn(cur_step, eval_output, summary_writer):
        """Executed after every eval."""
        global run_steps
        global masked_lm_accuracy
        cur_step_corrected = cur_step + FLAGS.iterations_per_loop
        run_steps = cur_step_corrected
        masked_lm_weighted_correct = eval_output["masked_lm_weighted_correct"]
        masked_lm_weighted_count = eval_output["masked_lm_weighted_count"]

        masked_lm_accuracy = np.sum(masked_lm_weighted_correct) / np.sum(
            masked_lm_weighted_count)
        # The eval_output may mix up the order of the two arrays;
        # swap them back if that happened.
        if masked_lm_accuracy > 1:
            masked_lm_accuracy = 1 / masked_lm_accuracy

        if summary_writer:
            with tf.Graph().as_default():
                summary_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag="masked_lm_accuracy",
                                         simple_value=masked_lm_accuracy)
                    ]), cur_step_corrected)

        mlp_log.mlperf_print("block_stop",
                             None,
                             metadata={
                                 "first_epoch_num": cur_step_corrected,
                             })
        # While BERT pretraining does not have epochs,
        # to make the logging consistent with other mlperf models,
        # in all the mlp_log, epochs are steps, and examples are sequences.
        mlp_log.mlperf_print("eval_accuracy",
                             float(masked_lm_accuracy),
                             metadata={"epoch_num": cur_step_corrected})
        if (masked_lm_accuracy >= FLAGS.stop_threshold
                and cur_step_corrected >= FLAGS.iterations_per_loop * 6):
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "success"})
            return True
        else:
            return False

    def run_finish_fn(success):
        if not success:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "abort"})
        mlp_log.mlperf_print("run_final", None)

    def init_fn():
        if FLAGS.init_checkpoint:
            tf.train.init_from_checkpoint(FLAGS.init_checkpoint, {
                "bert/": "bert/",
                "cls/": "cls/",
            })

    # Passing the hyperparameters
    if "learning_rate" in hparams:
        FLAGS.learning_rate = hparams.learning_rate
    if "lamb_weight_decay_rate" in hparams:
        FLAGS.lamb_weight_decay_rate = hparams.lamb_weight_decay_rate
    if "lamb_beta_1" in hparams:
        FLAGS.lamb_beta_1 = hparams.lamb_beta_1
    if "lamb_beta_2" in hparams:
        FLAGS.lamb_beta_2 = hparams.lamb_beta_2
    if "epsilon" in hparams:
        FLAGS.epsilon = hparams.epsilon
    if "num_warmup_steps" in hparams:
        FLAGS.num_warmup_steps = hparams.num_warmup_steps
    if "num_train_steps" in hparams:
        FLAGS.num_train_steps = hparams.num_train_steps

    # Input handling
    tf.logging.set_verbosity(tf.logging.INFO)
    if FLAGS.repeatable:
        tf.set_random_seed(123)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Input Files ***")
    tf.logging.info("%s Files." % len(input_files))

    dataset_train = dataset_input.input_fn_builder(
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=True)

    dataset_eval = dataset_input.input_fn_builder(
        input_files=input_files,
        max_seq_length=FLAGS.max_seq_length,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        is_training=False,
        num_eval_samples=FLAGS.num_eval_samples)

    # Create the low level runner
    low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
        FLAGS.iterations_per_loop, FLAGS.stop_steps + 1, FLAGS.max_eval_steps,
        FLAGS.num_tpu_cores // FLAGS.num_partitions)

    mlp_log.mlperf_print("cache_clear", True)
    mlp_log.mlperf_print("init_start", None)
    mlp_log.mlperf_print("global_batch_size", FLAGS.train_batch_size)
    mlp_log.mlperf_print("opt_learning_rate_warmup_steps",
                         FLAGS.num_warmup_steps)
    mlp_log.mlperf_print("num_warmup_steps", FLAGS.num_warmup_steps)
    mlp_log.mlperf_print("start_warmup_step", FLAGS.start_warmup_step)
    mlp_log.mlperf_print("max_sequence_length", FLAGS.max_seq_length)
    mlp_log.mlperf_print("opt_base_learning_rate", FLAGS.learning_rate)
    mlp_log.mlperf_print("opt_lamb_beta_1", FLAGS.lamb_beta_1)
    mlp_log.mlperf_print("opt_lamb_beta_2", FLAGS.lamb_beta_2)
    mlp_log.mlperf_print("opt_epsilon", 10**FLAGS.log_epsilon)
    mlp_log.mlperf_print("opt_learning_rate_training_steps",
                         FLAGS.num_train_steps)
    mlp_log.mlperf_print("opt_lamb_weight_decay_rate",
                         FLAGS.lamb_weight_decay_rate)
    mlp_log.mlperf_print("opt_lamb_learning_rate_decay_poly_power", 1)
    mlp_log.mlperf_print("opt_gradient_accumulation_steps", 0)
    mlp_log.mlperf_print("max_predictions_per_seq",
                         FLAGS.max_predictions_per_seq)

    low_level_runner.initialize(dataset_train,
                                dataset_eval,
                                bert_model_fn,
                                FLAGS.train_batch_size,
                                FLAGS.eval_batch_size,
                                input_partition_dims=None,
                                init_fn=init_fn,
                                train_has_labels=False,
                                eval_has_labels=False,
                                num_partitions=FLAGS.num_partitions)

    mlp_log.mlperf_print("init_stop", None)

    mlp_log.mlperf_print("run_start", None)

    # To make the logging consistent with other mlperf models,
    # in all the mlp_log, epochs are steps, and examples are sequences.
    mlp_log.mlperf_print("train_samples",
                         FLAGS.num_train_steps * FLAGS.train_batch_size)
    mlp_log.mlperf_print("eval_samples",
                         FLAGS.max_eval_steps * FLAGS.eval_batch_size)
    low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn,
                                    run_finish_fn)
    return masked_lm_accuracy, run_steps
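
A toy NumPy check of the weighted masked-LM accuracy reduction used in eval_finish_fn above, including the >1 swap heuristic; the per-host sums are made-up numbers:

# Hypothetical toy reduction matching the masked-LM accuracy logic above.
import numpy as np

masked_lm_weighted_correct = np.array([120.0, 95.0])   # per-host sums (made up)
masked_lm_weighted_count = np.array([160.0, 140.0])

accuracy = np.sum(masked_lm_weighted_correct) / np.sum(masked_lm_weighted_count)
# If the infeed swapped the two arrays, the ratio exceeds 1; invert to recover.
if accuracy > 1:
  accuracy = 1 / accuracy
print(accuracy)  # 215 / 300, approximately 0.7167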
Example #24
def run_finish_fn(success):
  if not success:
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
  mlp_log.mlperf_print('run_final', None)
Example #25
def resnet_model_fn(features, labels, is_training):
  """The model_fn for ResNet to be used with TPU.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    is_training: whether this is training

  Returns:
    train_op, logits
  """
  if isinstance(features, dict):
    features = features['feature']

  if FLAGS.use_space_to_depth:
    if FLAGS.train_batch_size // FLAGS.num_replicas > 8:
      features = tf.reshape(
          features, [FLAGS.image_size // 2, FLAGS.image_size // 2, 12, -1])
      features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC
    else:
      features = tf.reshape(
          features, [FLAGS.image_size // 2, FLAGS.image_size // 2, -1, 12])
      features = tf.transpose(features, [2, 0, 1, 3])  # HWNC to NHWC
  else:
    if FLAGS.train_batch_size // FLAGS.num_replicas > 8:
      features = tf.reshape(features,
                            [FLAGS.image_size, FLAGS.image_size, 3, -1])
      features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC
    else:
      features = tf.reshape(features,
                            [FLAGS.image_size, FLAGS.image_size, -1, 3])
      features = tf.transpose(features, [2, 0, 1, 3])  # HWNC to NHWC

  # Normalize the image to zero mean and unit variance.
  if FLAGS.use_space_to_depth:
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 12], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 12], dtype=features.dtype)
  else:
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    with tf.variable_scope('resnet', reuse=tf.AUTO_REUSE):
      network = resnet_model.resnet_v1(
          resnet_depth=FLAGS.resnet_depth,
          num_classes=FLAGS.num_label_classes,
          use_space_to_depth=FLAGS.use_space_to_depth,
          num_replicas=FLAGS.num_replicas,
          distributed_group_size=FLAGS.distributed_group_size)
      return network(inputs=features, is_training=is_training)

  if FLAGS.precision == 'bfloat16':
    with tf.tpu.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()

  if not is_training:
    total_correct = tf.reduce_sum(
        tf.cast(
            tf.equal(tf.cast(tf.argmax(logits, axis=1), labels.dtype), labels),
            tf.int32))
    return None, {'total_correct': tf.reshape(total_correct, [-1])}

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=FLAGS.label_smoothing)

  # Add weight decay to the loss for non-batch-normalization variables.
  if FLAGS.enable_lars:
    loss = cross_entropy
  else:
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

  global_step = tf.train.get_or_create_global_step()
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)

  mlp_log.mlperf_print(
      'model_bn_span',
      FLAGS.distributed_group_size *
      (FLAGS.train_batch_size // FLAGS.num_replicas))

  if FLAGS.enable_lars:
    learning_rate = 0.0
    mlp_log.mlperf_print('opt_name', 'lars')
    optimizer = lars_util.init_lars_optimizer(current_epoch)
  else:
    mlp_log.mlperf_print('opt_name', 'sgd')
    learning_rate = learning_rate_schedule(current_epoch)
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=FLAGS.momentum, use_nesterov=True)
  optimizer = tf.tpu.CrossShardOptimizer(optimizer)

  # Batch normalization requires UPDATE_OPS to be added as a dependency to
  # the train operation.
  with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_op = optimizer.minimize(loss, global_step)
  return train_op, None
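
A small shape check of the HWCN-to-NHWC un-transpose used above for the large-per-replica-batch, non-space-to-depth path; image_size and the per-replica batch are toy assumptions:

# Hypothetical NumPy shape check for the HWCN -> NHWC reshape/transpose above.
import numpy as np

image_size, per_replica_batch = 4, 16        # toy values; > 8 selects HWCN path
# The host feeds images in HWCN layout: (height, width, channels, batch).
hwcn = np.zeros([image_size, image_size, 3, per_replica_batch], np.float32)
flat = hwcn.reshape([-1])                    # what the infeed delivers

features = flat.reshape([image_size, image_size, 3, -1])
features = np.transpose(features, [3, 0, 1, 2])   # HWCN -> NHWC
print(features.shape)                        # (16, 4, 4, 3)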