Example #1
def learning_rate_schedule(current_epoch):
    """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then it increases linearly per step.
  After 5 epochs we reach the base learning rate (scaled to account
    for batch size).
  After 30, 60 and 80 epochs the learning rate is divided by 10.
  After 90 epochs training stops and the LR is set to 0. This ensures
    that we train for exactly 90 epochs for reproducibility.

  Args:
    current_epoch: `Tensor` for current epoch.

  Returns:
    A scaled `Tensor` for current learning rate.
  """
    mlp_log.mlperf_print('lars_opt_base_learning_rate',
                         FLAGS.base_learning_rate)
    scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)

    decay_rate = (scaled_lr * LR_SCHEDULE[0][0] * current_epoch /
                  LR_SCHEDULE[0][1])
    for mult, start_epoch in LR_SCHEDULE:
        decay_rate = tf.where(current_epoch < start_epoch, decay_rate,
                              scaled_lr * mult)
    return decay_rate
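The LR_SCHEDULE constant referenced above is not shown in this snippet; a plausible definition consistent with the docstring (warmup to the base rate over 5 epochs, then divide by 10 at epochs 30, 60 and 80) is sketched below as an assumption.

# Assumed definition of LR_SCHEDULE: (multiplier, epoch_to_start) pairs.
# The first entry ends the linear warmup; later entries divide the LR by 10.
LR_SCHEDULE = [
    (1.0, 5),     # reach the full scaled learning rate after 5 epochs
    (0.1, 30),    # divide by 10 at epoch 30
    (0.01, 60),   # and again at epoch 60
    (0.001, 80),  # and again at epoch 80
]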
Example #2
 def _default_run_finish_fn(success_status):
     if not success_status:
         mlp_log.mlperf_print("run_stop",
                              None,
                              metadata={"status": "failure"})
     tf.logging.info("Retrieving embedding vars and writing stats.")
     runner.retrieve_embedding_vars()
Example #3
  def eval_finish_fn(cur_step, eval_output, _):
    """Callback function that's executed after each eval."""
    if eval_steps == 0:
      return False
    # Concatenate eval_output, since each value is a list of per-host arrays.
    for key in eval_output:
      eval_output[key] = np.concatenate(eval_output[key], axis=0)
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
    mlp_log.mlperf_print(
        'block_stop',
        None,
        metadata={
            'first_epoch_num': cur_epoch,
            'epoch_count': 1
        })
    eval_multiprocess.eval_multiprocessing(eval_output, eval_metric,
                                           mask_rcnn_params.EVAL_WORKER_COUNT)

    mlp_log.mlperf_print(
        'eval_start', None, metadata={'epoch_num': cur_epoch + 1})
    _, eval_results = eval_metric.evaluate()
    mlp_log.mlperf_print(
        'eval_accuracy',
        {'BBOX': float(eval_results['AP']),
         'SEGM': float(eval_results['mask_AP'])},
        metadata={'epoch_num': cur_epoch + 1})
    mlp_log.mlperf_print(
        'eval_stop', None, metadata={'epoch_num': cur_epoch + 1})
    if (eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
        eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET):
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
      return True
    return False
Example #4
def learning_rate_schedule(params, global_step):
    """Handles learning rate scaling, linear warmup, and learning rate decay.

  Args:
    params: A dictionary that defines hyperparameters of model.
    global_step: A tensor representing current global step.

  Returns:
    A tensor representing current learning rate.
  """
    base_learning_rate = params['base_learning_rate']
    lr_warmup_step = params['lr_warmup_step']
    first_lr_drop_step = params['first_lr_drop_step']
    second_lr_drop_step = params['second_lr_drop_step']
    batch_size = params['batch_size'] * params['num_shards']
    scaling_factor = batch_size / ssd_constants.DEFAULT_BATCH_SIZE
    mlp_log.mlperf_print('opt_learning_rate_warmup_factor', scaling_factor)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps', lr_warmup_step)
    adjusted_learning_rate = base_learning_rate * scaling_factor
    learning_rate = (tf.cast(global_step, dtype=tf.float32) /
                     lr_warmup_step) * adjusted_learning_rate
    lr_schedule = [[1.0, lr_warmup_step], [0.1, first_lr_drop_step],
                   [0.01, second_lr_drop_step]]
    for mult, start_global_step in lr_schedule:
        learning_rate = tf.where(global_step < start_global_step,
                                 learning_rate, adjusted_learning_rate * mult)
    return learning_rate
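For reference, a minimal NumPy sketch of the same warmup-then-step-decay logic, useful for sanity-checking the schedule off-device; ssd_constants.DEFAULT_BATCH_SIZE and the step boundaries below are illustrative assumptions, not values taken from this snippet.

import numpy as np

def ssd_lr_reference(global_step, base_learning_rate=1e-3, batch_size=1024,
                     default_batch_size=32, lr_warmup_step=1000,
                     first_lr_drop_step=40000, second_lr_drop_step=50000):
    """Pure-NumPy mirror of learning_rate_schedule() for quick checks."""
    adjusted_lr = base_learning_rate * batch_size / default_batch_size
    lr = (global_step / lr_warmup_step) * adjusted_lr  # linear warmup
    for mult, start in [(1.0, lr_warmup_step), (0.1, first_lr_drop_step),
                        (0.01, second_lr_drop_step)]:
        lr = np.where(global_step < start, lr, adjusted_lr * mult)
    return float(lr)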
Example #5
def poly_rate_schedule(current_epoch, poly_rate=0.0):
    """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then it increases linearly per step.  After
  FLAGS.poly_warmup_epochs, we reach the base learning rate (scaled to account
  for batch size). The learning rate is then decayed using a polynomial rate
  decay schedule with power 2.0.

  Args:
    current_epoch: `Tensor` for current epoch.
    poly_rate: Polynomial decay rate.

  Returns:
    A scaled `Tensor` for current learning rate.
  """

    batch_size = FLAGS.train_batch_size
    if batch_size < 16384:
        plr = 10.0
        w_epochs = 5
    elif batch_size < 32768:
        plr = 25.0
        w_epochs = 5
    else:
        plr = 31.2
        w_epochs = 25

    # Override default poly learning rate and warmup epochs
    if poly_rate > 0.0:
        plr = poly_rate

    if FLAGS.lars_base_learning_rate > 0.0:
        plr = FLAGS.lars_base_learning_rate

    if FLAGS.lars_warmup_epochs > 0:
        w_epochs = FLAGS.lars_warmup_epochs

    mlp_log.mlperf_print('lars_opt_base_learning_rate', plr)
    mlp_log.mlperf_print('lars_opt_learning_rate_warmup_epochs', w_epochs)
    end_lr = 0.0001
    mlp_log.mlperf_print('lars_opt_end_learning_rate', end_lr)

    wrate = (plr * current_epoch / w_epochs)
    w_steps = (w_epochs * FLAGS.num_train_images // batch_size)
    min_step = tf.constant(1, dtype=tf.int64)
    global_step = tf.train.get_or_create_global_step()
    decay_steps = tf.maximum(min_step, tf.subtract(global_step, w_steps))

    mlp_log.mlperf_print('lars_opt_learning_rate_decay_steps',
                         FLAGS.train_steps - w_steps + 1)
    mlp_log.mlperf_print('lars_opt_learning_rate_decay_poly_power', 2.0)

    poly_rate = tf.train.polynomial_decay(plr,
                                          decay_steps,
                                          FLAGS.train_steps - w_steps + 1,
                                          end_lr,
                                          power=2.0)
    decay_rate = tf.where(current_epoch <= w_epochs, wrate, poly_rate)
    return decay_rate
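After warmup, the schedule above relies on tf.train.polynomial_decay with power 2.0; a small reference sketch of the resulting curve (with the step counts clamped the same way the code clamps them) is shown below.

def poly_lr_after_warmup(global_step, plr, w_steps, train_steps, end_lr=0.0001):
    """Plain-Python mirror of the post-warmup branch: quadratic decay to end_lr."""
    decay_step = max(1, global_step - w_steps)       # steps since warmup ended
    total_decay_steps = train_steps - w_steps + 1    # length of the decay phase
    frac = min(decay_step / total_decay_steps, 1.0)  # polynomial_decay clamps here
    return (plr - end_lr) * (1.0 - frac) ** 2.0 + end_lr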
Example #6
 def eval_init_fn(cur_step):
     """Executed before every eval."""
     steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
     epoch = cur_step // steps_per_epoch
     mlp_log.mlperf_print('block_start',
                          None,
                          metadata={
                              'first_epoch_num': epoch,
                              'epoch_count': 4
                          })
Example #7
 def eval_init_fn(cur_step):
   """Executed before every eval."""
   steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
   cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
   mlp_log.mlperf_print(
       'block_start',
       None,
       metadata={
           'first_epoch_num': cur_epoch,
           'epoch_count': 1
       })
Example #8
 def eval_init_fn(cur_step):
   """Executed before every eval."""
   # While BERT pretraining does not have epochs,
   # to make the logging consistent with other mlperf models,
   # in all the mlp_log, epochs are steps, and examples are sequences.
   mlp_log.mlperf_print(
       "block_start",
       None,
       metadata={
           "first_epoch_num": cur_step + FLAGS.iterations_per_loop,
           "epoch_count": FLAGS.iterations_per_loop
       })
Example #9
 def eval_finish_fn(cur_step, eval_output, _):
     steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
     epoch = cur_step // steps_per_epoch
     mlp_log.mlperf_print('block_stop',
                          None,
                          metadata={
                              'first_epoch_num':
                              epoch,
                              'epoch_count':
                              FLAGS.iterations_per_loop // steps_per_epoch
                          })
     if FLAGS.run_cocoeval:
         q_in.put((cur_step, eval_output['detections']))
Example #10
def init_lars_optimizer(current_epoch):
    """Initialize the LARS Optimizer."""

    lars_epsilon = FLAGS.lars_epsilon
    mlp_log.mlperf_print('lars_epsilon', lars_epsilon)

    learning_rate = poly_rate_schedule(current_epoch, FLAGS.poly_rate)
    optimizer = contrib_opt.LARSOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        skip_list=['batch_normalization', 'bias'],
        epsilon=lars_epsilon)
    return optimizer
Example #11
 def log_eval_results_fn():
     """Print out MLPerf log."""
     result = q_out.get()
     success = False
     while result[0] != _STOP:
         if not success:
             steps_per_epoch = (FLAGS.num_examples_per_epoch //
                                FLAGS.train_batch_size)
             epoch = (result[0] +
                      FLAGS.iterations_per_loop) // steps_per_epoch
             mlp_log.mlperf_print('eval_accuracy',
                                  result[1]['COCO/AP'],
                                  metadata={'epoch_num': epoch})
             mlp_log.mlperf_print('eval_stop',
                                  None,
                                  metadata={'epoch_num': epoch})
             if result[1]['COCO/AP'] > ssd_constants.EVAL_TARGET:
                 success = True
                 mlp_log.mlperf_print('run_stop',
                                      None,
                                      metadata={'status': 'success'})
         result = q_out.get()
     if not success:
         mlp_log.mlperf_print('run_stop',
                              None,
                              metadata={'status': 'abort'})
Example #12
    def _default_eval_finish_fn(cur_step, eval_output, summary_writer=None):
        eval_num = cur_step // FLAGS.steps_between_evals
        mlp_log.mlperf_print("eval_stop",
                             None,
                             metadata={"epoch_num": eval_num + 1})
        mlp_log.mlperf_print("block_stop",
                             None,
                             metadata={"first_epoch_num": eval_num + 1})
        tf.logging.info(
            "== Eval finished (step {}). Computing metric..".format(cur_step))

        results_np = np.array(eval_output["results"])
        results_np = np.reshape(results_np, (-1, 2))
        predictions_np = results_np[:, 0].astype(np.float32)
        targets_np = results_np[:, 1].astype(np.int32)
        roc_obj = roc_metrics.RocMetrics(predictions_np, targets_np)
        roc_auc = roc_obj.ComputeRocAuc()
        tf.logging.info("== Eval shape: {}.  AUC = {:.4f}".format(
            predictions_np.shape, roc_auc))
        success = roc_auc >= _ACCURACY_THRESH
        mlp_log.mlperf_print("eval_accuracy",
                             roc_auc,
                             metadata={"epoch_num": eval_num + 1})
        if success:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "success"})
        if summary_writer:
            summary_writer.add_summary(
                utils.create_scalar_summary("auc", roc_auc),
                global_step=cur_step + FLAGS.steps_between_evals)
        eval_metrics.append((cur_step + FLAGS.steps_between_evals, roc_auc))
        return success
Example #13
 def _default_eval_init_fn(cur_step):
     """Logging statements executed before every eval."""
     eval_num = cur_step // FLAGS.steps_between_evals
     tf.logging.info("== Block {}. Step {} of {}".format(
         eval_num + 1, cur_step, FLAGS.train_steps))
     mlp_log.mlperf_print("block_start",
                          None,
                          metadata={
                              "first_epoch_num": eval_num + 1,
                              "epoch_count": 1
                          })
     mlp_log.mlperf_print("eval_start",
                          None,
                          metadata={"epoch_num": eval_num + 1})
Example #14
 def eval_init_fn(cur_step):
     """Executed before every eval."""
     steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
     epoch = cur_step // steps_per_epoch
     mlp_log.mlperf_print('block_start',
                          None,
                          metadata={
                              'first_epoch_num':
                              epoch,
                              'epoch_count':
                              FLAGS.iterations_per_loop // steps_per_epoch
                          })
     mlp_log.mlperf_print('eval_start',
                          None,
                          metadata={
                              'epoch_num':
                              epoch +
                              FLAGS.iterations_per_loop // steps_per_epoch
                          })
Example #15
def _write_metrics(eval_metrics, train_metrics, host_step, total_training_steps,
                   host_id):
  """Logs the accuracy metrics."""
  del host_id
  global RUN_STOP
  global TOTAL_STEPS
  if RUN_STOP:
    return

  eval_metrics = jax.tree_map(jax.device_get, eval_metrics)
  train_metrics = jax.tree_map(jax.device_get, train_metrics)

  masked_lm_accuracy = (
      np.sum(eval_metrics['masked_lm_weighted_correct']) /
      np.sum(eval_metrics['masked_lm_weighted_count']))
  total_loss = np.mean(train_metrics['total_loss'])
  lm_loss = np.mean(train_metrics['lm_loss'])
  sentence_loss = np.mean(train_metrics['sentence_loss'])

  mlp_log.mlperf_print('eval_accuracy', float(masked_lm_accuracy),
                       metadata={'epoch_num': host_step})

  logging.info('(Step %s / %s), masked_lm_accuracy: %s', host_step,
               total_training_steps, masked_lm_accuracy)
  logging.info(
      '(----Step %s / %s) Total loss: %s | LM loss: %s | Sentence loss: %s',
      host_step, total_training_steps, total_loss, lm_loss, sentence_loss)
  mlp_log.mlperf_print('eval_stop', None, metadata={'epoch_num': host_step})

  if masked_lm_accuracy >= FLAGS.target_accuracy:
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
    RUN_STOP = time.time()
    TOTAL_STEPS = host_step
Example #16
    def eval_finish_fn(cur_step, eval_output, summary_writer):
        """Executed after every eval."""
        steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
        epoch = cur_step // steps_per_epoch
        eval_accuracy = float(np.sum(
            eval_output['total_correct'])) / FLAGS.num_eval_images

        if summary_writer:
            with tf.Graph().as_default():
                summary_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag='accuracy',
                                         simple_value=eval_accuracy)
                    ]), cur_step)
        mlp_log.mlperf_print('eval_accuracy',
                             eval_accuracy,
                             metadata={
                                 'epoch_num':
                                 epoch +
                                 FLAGS.iterations_per_loop // steps_per_epoch
                             })
        mlp_log.mlperf_print('block_stop',
                             None,
                             metadata={
                                 'first_epoch_num': epoch,
                                 'epoch_count': 4
                             })
        if eval_accuracy >= FLAGS.stop_threshold:
            mlp_log.mlperf_print('run_stop',
                                 None,
                                 metadata={'status': 'success'})
            return True
        else:
            return False
Example #17
 def get_learning_rate(self, global_step):
     """Sets up learning rate schedule."""
     learning_rate = lr_policy.learning_rate_schedule(
         self.params['learning_rate'], self.params['lr_warmup_init'],
         self.params['lr_warmup_step'], self.params['first_lr_drop_step'],
         self.params['second_lr_drop_step'], global_step)
     mlp_log.mlperf_print(key='opt_base_learning_rate',
                          value=self.params['learning_rate'])
     mlp_log.mlperf_print(key='opt_learning_rate_warmup_steps',
                          value=self.params['lr_warmup_step'])
     mlp_log.mlperf_print(key='opt_learning_rate_warmup_factor',
                          value=self.params['learning_rate'] /
                          self.params['lr_warmup_step'])
     return learning_rate
Example #18
        def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops,
                             eval_init):
            """Start the infeed."""
            time.sleep(150)

            mlp_log.mlperf_print("init_stop", None)
            mlp_log.mlperf_print("run_start", None)
            mlp_log.mlperf_print("block_start",
                                 None,
                                 metadata={
                                     "first_epoch_num": 1,
                                     "epoch_count": 1
                                 })

            for i in range(self.hparams.max_train_epochs):
                tf.logging.info("Infeed for epoch: %d", i + 1)
                sess.run(eval_init)
                sess.run([train_enqueue_ops])
                sess.run([eval_enqueue_ops])
Example #19
  def eval_finish_fn(cur_step, eval_output, summary_writer):
    """Executed after every eval."""
    global run_steps
    global masked_lm_accuracy
    cur_step_corrected = cur_step + FLAGS.iterations_per_loop
    run_steps = cur_step_corrected
    masked_lm_weighted_correct = eval_output["masked_lm_weighted_correct"]
    masked_lm_weighted_count = eval_output["masked_lm_weighted_count"]

    masked_lm_accuracy = np.sum(masked_lm_weighted_correct) / np.sum(
        masked_lm_weighted_count)
    # eval_output may mix up the order of the two arrays;
    # swap them back if that happened.
    if masked_lm_accuracy > 1:
      masked_lm_accuracy = 1 / masked_lm_accuracy

    if summary_writer:
      with tf.Graph().as_default():
        summary_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="masked_lm_accuracy",
                                 simple_value=masked_lm_accuracy)
            ]), cur_step_corrected)

    mlp_log.mlperf_print(
        "block_stop",
        None,
        metadata={
            "first_epoch_num": cur_step_corrected,
        })
    # While BERT pretraining does not have epochs,
    # to make the logging consistent with other mlperf models,
    # in all the mlp_log, epochs are steps, and examples are sequences.
    mlp_log.mlperf_print(
        "eval_accuracy",
        float(masked_lm_accuracy),
        metadata={"epoch_num": cur_step_corrected})
    if (masked_lm_accuracy >= FLAGS.stop_threshold and
        cur_step_corrected >= FLAGS.iterations_per_loop * 6):
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
      return True
    else:
      return False
Example #20
    def train_and_predict(self):
        """Run the predict loop on the TPU device."""
        self.sess.run([self.compile_op])

        # Train and eval thread.
        def train_eval_thread_fn(sess, train_eval_op):
            tf.logging.info("train_eval_op start")
            sess.run([train_eval_op])

        train_eval_thread = threading.Thread(target=train_eval_thread_fn,
                                             args=(self.sess,
                                                   self.train_eval_op))
        train_eval_thread.start()

        # Infeed thread.
        def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops,
                             eval_init):
            """Start the infeed."""
            time.sleep(150)

            mlp_log.mlperf_print("init_stop", None)
            mlp_log.mlperf_print("run_start", None)
            mlp_log.mlperf_print("block_start",
                                 None,
                                 metadata={
                                     "first_epoch_num": 1,
                                     "epoch_count": 1
                                 })

            for i in range(self.hparams.max_train_epochs):
                tf.logging.info("Infeed for epoch: %d", i + 1)
                sess.run(eval_init)
                sess.run([train_enqueue_ops])
                sess.run([eval_enqueue_ops])

        infeed_thread = threading.Thread(target=infeed_thread_fn,
                                         args=(self.sess, self.enqueue_ops,
                                               self.eval_enqueue_ops,
                                               self.eval_dataset_initializer))
        infeed_thread.start()

        if self.eval_steps > 0:
            eval_state = {"run_success": False, "score": 0.0}

            for epoch in range(self.hparams.max_train_epochs):
                predictions = list(self.predict())
                mlp_log.mlperf_print("eval_start",
                                     None,
                                     metadata={"epoch_num": epoch + 1})
                current_step = epoch * self.iterations

                eval_state["score"] = metric.get_metric(
                    self.hparams, predictions, current_step)
                tf.logging.info("Score after epoch %d: %f", epoch,
                                eval_state["score"])
                mlp_log.mlperf_print("eval_accuracy",
                                     eval_state["score"] / 100.0,
                                     metadata={"epoch_num": epoch + 1})
                mlp_log.mlperf_print("eval_stop",
                                     None,
                                     metadata={"epoch_num": epoch + 1})
                mlp_log.mlperf_print("block_stop",
                                     None,
                                     metadata={
                                         "first_epoch_num": epoch + 1,
                                         "epoch_count": 1
                                     })
                if eval_state["score"] >= self.hparams.target_bleu:
                    eval_state["run_success"] = True
                    mlp_log.mlperf_print("run_stop",
                                         None,
                                         metadata={"status": "success"})
                    break
                mlp_log.mlperf_print("block_start",
                                     None,
                                     metadata={
                                         "first_epoch_num": epoch + 2,
                                         "epoch_count": 1
                                     })

            if not eval_state["run_success"]:
                mlp_log.mlperf_print("run_stop",
                                     None,
                                     metadata={"status": "abort"})

        infeed_thread.join()
        train_eval_thread.join()

        if self.eval_steps > 0:
            return eval_state["score"], current_step
        else:
            return None, None
Example #21
def main(argv):
  del argv  # Unused.

  # TODO(b/132208296): remove this workaround that uses control flow v2.
  control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

  # Parse hparams
  hparams = mask_rcnn_params.default_hparams()
  hparams.parse(FLAGS.hparams)

  params = dict(
      hparams.values(),
      transpose_input=False if FLAGS.input_partition_dims is not None else True,
      resnet_checkpoint=FLAGS.resnet_checkpoint,
      val_json_file=FLAGS.val_json_file,
      num_cores_per_replica=int(np.prod(FLAGS.input_partition_dims))
      if FLAGS.input_partition_dims else 1,
      replicas_per_host=FLAGS.replicas_per_host)

  # MLPerf logging.
  mlp_log.mlperf_print(key='cache_clear', value=True)
  mlp_log.mlperf_print(key='init_start', value=None)
  mlp_log.mlperf_print(key='global_batch_size', value=FLAGS.train_batch_size)
  mlp_log.mlperf_print(key='train_samples', value=FLAGS.num_examples_per_epoch)
  mlp_log.mlperf_print(key='eval_samples', value=FLAGS.eval_samples)
  mlp_log.mlperf_print(
      key='min_image_size', value=params['short_side_image_size'])
  mlp_log.mlperf_print(
      key='max_image_size', value=params['long_side_max_image_size'])
  mlp_log.mlperf_print(key='num_image_candidates',
                       value=params['rpn_post_nms_topn'])

  train_steps = (
      FLAGS.num_epochs * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size)
  eval_steps = int(math.ceil(float(FLAGS.eval_samples) / FLAGS.eval_batch_size))
  if eval_steps > 0:
    # The eval dataset is not evenly divided. Adding one extra step makes sure
    # all eval samples are covered.
    # TODO(b/151732586): regenerate the eval dataset to make all hosts get the
    #                    same amount of work.
    eval_steps += 1
  runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.num_examples_per_epoch // FLAGS.train_batch_size, train_steps,
      eval_steps, FLAGS.num_shards)
  train_input_fn = dataloader.InputReader(
      FLAGS.training_file_pattern,
      mode=tf.estimator.ModeKeys.TRAIN,
      use_fake_data=FLAGS.use_fake_data)
  eval_input_fn = functools.partial(
      dataloader.InputReader(
          FLAGS.validation_file_pattern,
          mode=tf.estimator.ModeKeys.PREDICT,
          distributed_eval=True),
      num_examples=eval_steps * FLAGS.eval_batch_size)
  eval_metric = coco_metric.EvaluationMetric(
      FLAGS.val_json_file, use_cpp_extension=True)

  def init_fn():
    if FLAGS.resnet_checkpoint:
      tf.train.init_from_checkpoint(FLAGS.resnet_checkpoint,
                                    {'resnet/': 'resnet50/'})

  runner.initialize(train_input_fn, eval_input_fn,
                    mask_rcnn_model.MaskRcnnModelFn(params),
                    FLAGS.train_batch_size, FLAGS.eval_batch_size,
                    FLAGS.input_partition_dims, init_fn, params=params)
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
    mlp_log.mlperf_print(
        'block_start',
        None,
        metadata={
            'first_epoch_num': cur_epoch,
            'epoch_count': 1
        })

  def eval_finish_fn(cur_step, eval_output, _):
    """Callback function that's executed after each eval."""
    if eval_steps == 0:
      return False
    # Concatenate eval_output, since each value is a list of per-host arrays.
    for key in eval_output:
      eval_output[key] = np.concatenate(eval_output[key], axis=0)
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
    mlp_log.mlperf_print(
        'block_stop',
        None,
        metadata={
            'first_epoch_num': cur_epoch,
            'epoch_count': 1
        })
    eval_multiprocess.eval_multiprocessing(eval_output, eval_metric,
                                           mask_rcnn_params.EVAL_WORKER_COUNT)

    mlp_log.mlperf_print(
        'eval_start', None, metadata={'epoch_num': cur_epoch + 1})
    _, eval_results = eval_metric.evaluate()
    mlp_log.mlperf_print(
        'eval_accuracy',
        {'BBOX': float(eval_results['AP']),
         'SEGM': float(eval_results['mask_AP'])},
        metadata={'epoch_num': cur_epoch + 1})
    mlp_log.mlperf_print(
        'eval_stop', None, metadata={'epoch_num': cur_epoch + 1})
    if (eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
        eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET):
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
      return True
    return False

  def run_finish_fn(success):
    if not success:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})

  runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
Example #22
 def run_finish_fn(success):
   if not success:
     mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
Example #23
def resnet_model_fn(features, labels, is_training):
    """The model_fn for ResNet to be used with TPU.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    is_training: whether this is training

  Returns:
    train_op, logits
  """
    if isinstance(features, dict):
        features = features['feature']

    if FLAGS.use_space_to_depth:
        if FLAGS.train_batch_size // FLAGS.num_replicas > 8:
            features = tf.reshape(
                features,
                [FLAGS.image_size // 2, FLAGS.image_size // 2, 12, -1])
            features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC
        else:
            features = tf.reshape(
                features,
                [FLAGS.image_size // 2, FLAGS.image_size // 2, -1, 12])
            features = tf.transpose(features, [2, 0, 1, 3])  # HWNC to NHWC
    else:
        if FLAGS.train_batch_size // FLAGS.num_replicas > 8:
            features = tf.reshape(features,
                                  [FLAGS.image_size, FLAGS.image_size, 3, -1])
            features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC
        else:
            features = tf.reshape(features,
                                  [FLAGS.image_size, FLAGS.image_size, -1, 3])
            features = tf.transpose(features, [2, 0, 1, 3])  # HWNC to NHWC

    # Normalize the image to zero mean and unit variance.
    if FLAGS.use_space_to_depth:
        features -= tf.constant(MEAN_RGB,
                                shape=[1, 1, 12],
                                dtype=features.dtype)
        features /= tf.constant(STDDEV_RGB,
                                shape=[1, 1, 12],
                                dtype=features.dtype)
    else:
        features -= tf.constant(MEAN_RGB,
                                shape=[1, 1, 3],
                                dtype=features.dtype)
        features /= tf.constant(STDDEV_RGB,
                                shape=[1, 1, 3],
                                dtype=features.dtype)

    # This nested function lets us avoid duplicating the network-building
    # logic for different values of --precision.
    def build_network():
        with tf.variable_scope('resnet', reuse=tf.AUTO_REUSE):
            network = resnet_model.resnet_v1(
                resnet_depth=FLAGS.resnet_depth,
                num_classes=FLAGS.num_label_classes,
                use_space_to_depth=FLAGS.use_space_to_depth,
                num_replicas=FLAGS.num_replicas,
                distributed_group_size=FLAGS.distributed_group_size)
            return network(inputs=features, is_training=is_training)

    if FLAGS.precision == 'bfloat16':
        with tf.tpu.bfloat16_scope():
            logits = build_network()
        logits = tf.cast(logits, tf.float32)
    elif FLAGS.precision == 'float32':
        logits = build_network()

    if not is_training:
        total_correct = tf.reduce_sum(
            tf.cast(
                tf.equal(tf.cast(tf.argmax(logits, axis=1), labels.dtype),
                         labels), tf.int32))
        return None, {'total_correct': tf.reshape(total_correct, [-1])}

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=FLAGS.label_smoothing)

    # Add weight decay to the loss for non-batch-normalization variables.
    if FLAGS.enable_lars:
        loss = cross_entropy
    else:
        loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
            tf.nn.l2_loss(v) for v in tf.trainable_variables()
            if 'batch_normalization' not in v.name
        ])

    global_step = tf.train.get_or_create_global_step()
    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)

    mlp_log.mlperf_print(
        'model_bn_span',
        FLAGS.distributed_group_size *
        (FLAGS.train_batch_size // FLAGS.num_replicas))

    if FLAGS.enable_lars:
        learning_rate = 0.0
        mlp_log.mlperf_print('opt_name', 'lars')
        optimizer = lars_util.init_lars_optimizer(current_epoch)
    else:
        mlp_log.mlperf_print('opt_name', 'sgd')
        learning_rate = learning_rate_schedule(current_epoch)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=FLAGS.momentum,
                                               use_nesterov=True)
    optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss, global_step)
    return train_op, None
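The 12-channel branch above assumes the input pipeline has already applied a 2x2 space-to-depth transform to the images (3 RGB channels x 4 spatial positions = 12 channels); a small NumPy illustration of that assumed host-side transform:

import numpy as np

def space_to_depth_2x2(image):
    """Rearranges an HxWx3 image into (H/2)x(W/2)x12, as the host-side input
    pipeline is assumed to do before feeding the 12-channel path above."""
    h, w, c = image.shape
    blocks = image.reshape(h // 2, 2, w // 2, 2, c)   # split into 2x2 blocks
    blocks = blocks.transpose(0, 2, 1, 3, 4)          # group each block's pixels
    return blocks.reshape(h // 2, w // 2, 4 * c)      # 4 * 3 = 12 channels

# Example: a 224x224x3 image becomes 112x112x12.
assert space_to_depth_2x2(np.zeros((224, 224, 3))).shape == (112, 112, 12)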
Example #24
def run_model(params,
              eval_init_fn=None,
              eval_finish_fn=None,
              run_finish_fn=None):
    """Run the DLRM model, using a pre-defined configuration.

  Args:
    params: HPTuner object that provides new params for the trial.
    eval_init_fn: Lambda to run at start of eval. None means use the default.
    eval_finish_fn: Lambda for end of eval. None means use the default.
    run_finish_fn: Lambda for end of execution. None means use the default.

  Returns:
    A list of tuples, each entry describing the eval metric for one eval. Each
    tuple entry is (global_step, metric_value).
  """
    mlp_log.mlperf_print(key="cache_clear", value=True)
    mlp_log.mlperf_print(key="init_start", value=None)
    mlp_log.mlperf_print("global_batch_size", params["batch_size"])
    mlp_log.mlperf_print("train_samples", _NUM_TRAIN_EXAMPLES)
    mlp_log.mlperf_print("eval_samples", _NUM_EVAL_EXAMPLES)
    adjusted_lr = params["learning_rate"] * (params["batch_size"] / 2048.0)
    mlp_log.mlperf_print("opt_base_learning_rate", adjusted_lr)
    mlp_log.mlperf_print("sgd_opt_base_learning_rate", adjusted_lr)
    mlp_log.mlperf_print("sgd_opt_learning_rate_decay_poly_power", 2)
    mlp_log.mlperf_print("sgd_opt_learning_rate_decay_steps",
                         params["decay_steps"])
    mlp_log.mlperf_print("lr_decay_start_steps", params["decay_start_step"])
    mlp_log.mlperf_print("opt_learning_rate_warmup_steps",
                         params["lr_warmup_steps"])

    # Used for vizier. List of tuples. Each entry is (global_step, auc_metric).
    eval_metrics = [(0, 0.0)]

    feature_config = fc.FeatureConfig(params)
    (feature_to_config_dict,
     table_to_config_dict) = feature_config.get_feature_tbl_config()
    opt_params = {
        "sgd":
        tpu_embedding.StochasticGradientDescentParameters(
            learning_rate=params["learning_rate"]),
        "adagrad":
        tpu_embedding.AdagradParameters(
            learning_rate=params["learning_rate"],
            initial_accumulator=params["adagrad_init_accum"])
    }
    embedding = tpu_embedding.TPUEmbedding(
        table_to_config_dict,
        feature_to_config_dict,
        params["batch_size"],
        mode=tpu_embedding.TRAINING,
        optimization_parameters=opt_params[params["optimizer"]],
        partition_strategy="mod",
        pipeline_execution_with_tensor_core=FLAGS.pipeline_execution,
        master=FLAGS.master)

    runner = dlrm_embedding_runner.DLRMEmbeddingRunner(
        iterations_per_loop=FLAGS.steps_between_evals,
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        num_replicas=FLAGS.num_tpu_shards,
        sparse_features_key="cat-features",
        embedding=embedding)

    train_input_fn, eval_input_fn = get_input_fns(params, feature_config)

    runner.initialize(train_input_fn,
                      eval_input_fn,
                      functools.partial(dlrm.dlrm_llr_model_fn, params,
                                        feature_config),
                      params["batch_size"],
                      params["eval_batch_size"],
                      train_has_labels=False,
                      eval_has_labels=False)

    mlp_log.mlperf_print("init_stop", None)
    mlp_log.mlperf_print("run_start", None)

    def _default_eval_init_fn(cur_step):
        """Logging statements executed before every eval."""
        eval_num = cur_step // FLAGS.steps_between_evals
        tf.logging.info("== Block {}. Step {} of {}".format(
            eval_num + 1, cur_step, FLAGS.train_steps))
        mlp_log.mlperf_print("block_start",
                             None,
                             metadata={
                                 "first_epoch_num": eval_num + 1,
                                 "epoch_count": 1
                             })
        mlp_log.mlperf_print("eval_start",
                             None,
                             metadata={"epoch_num": eval_num + 1})

    def _default_eval_finish_fn(cur_step, eval_output, summary_writer=None):
        eval_num = cur_step // FLAGS.steps_between_evals
        mlp_log.mlperf_print("eval_stop",
                             None,
                             metadata={"epoch_num": eval_num + 1})
        mlp_log.mlperf_print("block_stop",
                             None,
                             metadata={"first_epoch_num": eval_num + 1})
        tf.logging.info(
            "== Eval finished (step {}). Computing metric..".format(cur_step))

        results_np = np.array(eval_output["results"])
        results_np = np.reshape(results_np, (-1, 2))
        predictions_np = results_np[:, 0].astype(np.float32)
        targets_np = results_np[:, 1].astype(np.int32)
        roc_obj = roc_metrics.RocMetrics(predictions_np, targets_np)
        roc_auc = roc_obj.ComputeRocAuc()
        tf.logging.info("== Eval shape: {}.  AUC = {:.4f}".format(
            predictions_np.shape, roc_auc))
        success = roc_auc >= _ACCURACY_THRESH
        mlp_log.mlperf_print("eval_accuracy",
                             roc_auc,
                             metadata={"epoch_num": eval_num + 1})
        if success:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "success"})
        if summary_writer:
            summary_writer.add_summary(
                utils.create_scalar_summary("auc", roc_auc),
                global_step=cur_step + FLAGS.steps_between_evals)
        eval_metrics.append((cur_step + FLAGS.steps_between_evals, roc_auc))
        return success

    def _default_run_finish_fn(success_status):
        if not success_status:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "failure"})
        tf.logging.info("Retrieving embedding vars and writing stats.")
        runner.retrieve_embedding_vars()

    runner.train_and_eval(eval_init_fn=eval_init_fn or _default_eval_init_fn,
                          eval_finish_fn=eval_finish_fn
                          or _default_eval_finish_fn,
                          run_finish_fn=run_finish_fn
                          or _default_run_finish_fn)

    return eval_metrics
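A small usage sketch for the returned eval_metrics list of (global_step, auc) tuples described in the docstring; the values below are placeholders, not real output.

# Hypothetical post-processing of run_model()'s return value.
eval_metrics = [(0, 0.0), (256000, 0.7911), (512000, 0.8027)]
best_step, best_auc = max(eval_metrics, key=lambda kv: kv[1])
print('Best AUC %.4f at step %d' % (best_auc, best_step))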
Example #25
def run_pretrain(optimizer):
  """Run bert pretraining.

  Args:
    optimizer: BERT model with pretraining layer

  Returns:
    optimizer: trained model
  """
  result_stats = {}
  def get_input_context():

    class InputContext():

      def __init__(self):
        self.input_pipeline_id = jax.host_id()
        self.num_input_pipelines = jax.host_count()
    return InputContext()

  summary_thread = thread.ThreadPoolExecutor(1, 'summary')
  host_id = jax.host_id()
  # Get input dataset
  input_files = []
  for input_pattern in FLAGS.input_files.split(','):
    input_files.extend(tf.io.gfile.glob(input_pattern))
  logging.info('*** Input Files ***')
  for input_file in input_files:
    logging.info('  %s', input_file)

  eval_input_files = []
  for input_pattern in FLAGS.eval_input_files.split(','):
    eval_input_files.extend(tf.io.gfile.glob(input_pattern))
  logging.info('*** Eval Input Files ***')
  for input_file in eval_input_files:
    logging.info('  %s', input_file)

  train_input_fn = input_pipeline.input_fn_builder(
      input_files=input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=True,
      num_cpu_threads=8)

  host_train_batch_size = FLAGS.train_batch_size // jax.host_count()
  host_eval_batch_size = FLAGS.eval_batch_size // jax.host_count()

  params = {'batch_size': host_train_batch_size}
  input_context = get_input_context()
  train_dataset = train_input_fn(params, input_context)
  train_iterator = iter(train_dataset)

  eval_input_fn = input_pipeline.input_fn_builder(
      input_files=eval_input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=False,
      num_cpu_threads=8,
      global_input_size=FLAGS.eval_sample_size)
  eval_params = {'batch_size': host_eval_batch_size}
  eval_dataset = eval_input_fn(eval_params, input_context)
  eval_iterator = iter(eval_dataset)

  # train step
  total_training_steps = FLAGS.total_training_steps
  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate,
      warmup_steps=FLAGS.warmup_steps,
      total_training_steps=FLAGS.total_training_steps,
      poly_power=FLAGS.poly_power,
      start_warmup_step=FLAGS.start_warmup_step)

  # Device training loop cond.
  def device_train_loop_cond(args):
    _, _, _, _, _, _, step, epoch, num_steps_per_epoch = args
    return step // num_steps_per_epoch == epoch

  # Device training loop body.
  def device_train_loop_body(args):
    """Device training loop body."""
    (optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng, token,
     step, epoch, num_steps_per_epoch) = args
    device_batch_size = FLAGS.train_batch_size // jax.device_count()
    input_shape = [device_batch_size, FLAGS.max_seq_length]
    input_shape_pred = [device_batch_size, FLAGS.max_predictions_per_seq]
    (input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids,
     masked_lm_weights, next_sentence_labels), token = lax.infeed(
         token,
         shape=(jax.ShapedArray(input_shape, jnp.int32),
                jax.ShapedArray(input_shape, jnp.int32),
                jax.ShapedArray(input_shape, jnp.int32),
                jax.ShapedArray(input_shape_pred, jnp.int32),
                jax.ShapedArray(input_shape_pred, jnp.int32),
                jax.ShapedArray(input_shape_pred, jnp.float32),
                jax.ShapedArray([device_batch_size, 1], jnp.int32)))
    inputs = [input_ids, input_mask, segment_ids, masked_lm_positions]
    labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels]
    optimizer, total_loss, lm_loss, sentence_loss, new_dropout_rng = train_step(
        optimizer,
        inputs,
        labels,
        learning_rate_fn,
        dropout_rng=new_dropout_rng)
    step += 1
    return (optimizer, total_loss, lm_loss, sentence_loss,
            new_dropout_rng, token, step, epoch, num_steps_per_epoch)

  # Device training loop.
  def device_train_loop(optimizer, dropout_rng, total_loss, lm_loss,
                        sentence_loss, step, epoch, num_steps_per_epoch):
    """Device training loop."""
    token = lax.create_token(step)
    (optimizer, total_loss, lm_loss, sentence_loss, dropout_rng,
     _, step, epoch, num_steps_per_epoch) = lax.while_loop(
         device_train_loop_cond, device_train_loop_body,
         (optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, token,
          step, epoch, num_steps_per_epoch))
    return optimizer, total_loss, lm_loss, sentence_loss, dropout_rng, step

  if FLAGS.infeed:
    pmap_fn = jax.pmap
    if FLAGS.enable_buffer_donation:
      pmap_fn = functools.partial(pmap_fn, donate_argnums=(0, 1))
    if FLAGS.enable_wus:
      pmap_fn = functools.partial(
          pmap_fn, in_axes=(None, 0, None, None, None, None, None, None))

    p_train_epoch = pmap_fn(device_train_loop, axis_name='batch')
  else:
    # without infeed.
    p_train_step = jax.pmap(
        functools.partial(train_step, learning_rate_fn=learning_rate_fn),
        axis_name='batch')

  if FLAGS.infeed:
    # Infeed is currently synchronous, so do it in a background thread too
    infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed')

  pmap_fn = jax.pmap
  # Weight update sharding is not implemented yet for host train loop.
  # Enable wus on eval only if device loop is used.
  if FLAGS.enable_wus and FLAGS.infeed:
    pmap_fn = functools.partial(pmap_fn, in_axes=(None, 0, 0))
  p_eval_step = pmap_fn(eval_step, axis_name='batch')

  rng = random.PRNGKey(0)
  device_count = jax.local_device_count()
  dropout_rngs = random.split(rng, device_count)
  num_steps_per_epoch = np.int32(FLAGS.num_steps_per_epoch)
  if FLAGS.precompile:
    if FLAGS.infeed:
      if FLAGS.enable_wus:
        total_loss = np.float32(0.0)
        lm_loss = np.float32(0.0)
        sentence_loss = np.float32(0.0)
        host_step = 0
        host_epoch = 1
        optimizer = unbroadcast(optimizer)
        # the device training loop condition will immediately be false
        optimizer, total_loss, lm_loss, sentence_loss, _, _ = p_train_epoch(
            optimizer, dropout_rngs, total_loss, lm_loss, sentence_loss,
            host_step, host_epoch, num_steps_per_epoch)
      else:
        total_loss = jax_utils.replicate(np.float32(0.0))
        lm_loss = jax_utils.replicate(np.float32(0.0))
        sentence_loss = jax_utils.replicate(np.float32(0.0))
        device_step = jax_utils.replicate(0)
        device_epoch = jax_utils.replicate(1)
        # the device training loop condition will immediately be false
        optimizer, total_loss, lm_loss, sentence_loss, _, _ = p_train_epoch(
            optimizer, dropout_rngs, total_loss, lm_loss, sentence_loss,
            device_step, device_epoch, jax_utils.replicate(num_steps_per_epoch))

    else:
      train_input_shape = (host_train_batch_size, FLAGS.max_seq_length)
      train_input_shape_pred = (host_train_batch_size,
                                FLAGS.max_predictions_per_seq)
      word_id_data = jax.random.randint(rng, train_input_shape, 0, 10)
      mask_data = jax.random.randint(rng, train_input_shape, 0, 1)
      type_id_data = jax.random.randint(rng, train_input_shape, 0, 3)
      lm_mask = jax.random.randint(rng, train_input_shape_pred, 0, 5)
      masked_lm_ids = jax.random.randint(rng, train_input_shape_pred, 0, 2)
      masked_lm_weights = jax.random.randint(rng, train_input_shape_pred, 1,
                                             1).astype(np.float32)
      next_sentence_labels = jax.random.randint(rng, (host_train_batch_size, 1),
                                                0, 1)

      labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels]
      train_inputs = [word_id_data, mask_data, type_id_data, lm_mask]
      train_inputs = common_utils.shard(train_inputs)
      labels = common_utils.shard(labels)
      p_train_step(optimizer, train_inputs, labels, dropout_rng=dropout_rngs)

    eval_input_shape = (host_eval_batch_size, FLAGS.max_seq_length)
    eval_input_shape_pred = (host_eval_batch_size,
                             FLAGS.max_predictions_per_seq)
    word_id_data = jax.random.randint(rng, eval_input_shape, 0, 10)
    mask_data = jax.random.randint(rng, eval_input_shape, 0, 1)
    type_id_data = jax.random.randint(rng, eval_input_shape, 0, 3)
    lm_mask = jax.random.randint(rng, eval_input_shape_pred, 0, 5)
    masked_lm_ids = jax.random.randint(rng, eval_input_shape_pred, 0, 2)
    masked_lm_weights = jax.random.randint(
        rng, eval_input_shape_pred, 1, 1).astype(np.float32)
    next_sentence_labels = jax.random.randint(rng, (host_eval_batch_size, 1), 0,
                                              1)

    eval_inputs = {
        'input_ids': word_id_data,
        'input_mask': mask_data,
        'segment_ids': type_id_data,
        'masked_lm_positions': lm_mask,
        'masked_lm_ids': masked_lm_ids,
        'masked_lm_weights': masked_lm_weights,
        'next_sentence_labels': next_sentence_labels
    }

    eval_inputs = common_utils.shard(eval_inputs)
    metrics = empty_metrics()
    optimizer_target = optimizer.target
    # Weight update sharding is not implemented yet for host train loop.
    # Enable wus on eval only if device loop is used.
    if FLAGS.enable_wus and FLAGS.infeed:
      optimizer_target = unbroadcast(optimizer_target)
    metrics = p_eval_step(optimizer_target, eval_inputs, metrics)
    metrics = allreduce_metrics(metrics)
  metrics = empty_metrics()
  time.sleep(FLAGS.init_sleep)
  allreduce_metrics(metrics)['masked_lm_weighted_correct'].block_until_ready()
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)
  # To make the logging consistent with other mlperf models,
  # in all the mlp_log, epochs are steps, and examples are sequences.
  mlp_log.mlperf_print('train_samples',
                       FLAGS.total_training_steps * FLAGS.train_batch_size)
  mlp_log.mlperf_print('eval_samples', FLAGS.eval_sample_size)
  xprof = None
  run_start = time.time()
  global RUN_STOP
  global TOTAL_STEPS
  RUN_STOP = False
  TOTAL_STEPS = False

  if host_id == 0:
    if FLAGS.end_to_end_profile:
      xprof = xprof_session.XprofSession()
      xprof.start_session(device_name='REDACTED',
                          enable_python_tracer=True,
                          host_trace_level=2)
    elif FLAGS.profile:
      profile_with_xprof_on_background(start_after_sec=FLAGS.profile_latency,
                                       profile_time_sec=FLAGS.profile_duration)

  if FLAGS.infeed:
    h_total_loss = np.float32(0.0)
    h_lm_loss = np.float32(0.0)
    h_sentence_loss = np.float32(0.0)

    d_total_loss = jax_utils.replicate(np.float32(0.0))
    d_lm_loss = jax_utils.replicate(np.float32(0.0))
    d_sentence_loss = jax_utils.replicate(np.float32(0.0))

  host_step, device_step = 0, jax_utils.replicate(0)
  device_epoch = jax_utils.replicate(0)
  num_train_epochs = FLAGS.total_training_steps // FLAGS.num_steps_per_epoch
  steps_per_epoch = num_steps_per_epoch
  if num_train_epochs >= 6:
    # Merge the first 6 epochs, as we do not have to do eval.
    steps_per_epoch = np.int32(num_steps_per_epoch * 6)
  for host_epoch in range(num_train_epochs):
    block_step = host_step
    # While BERT pretraining does not have epochs,
    # to make the logging consistent with other mlperf models,
    # in all the mlp_log, epochs are steps, and examples are sequences.
    mlp_log.mlperf_print(
        'block_start',
        None,
        metadata={
            'first_epoch_num': block_step,
            'epoch_count': FLAGS.num_steps_per_epoch
        })

    if not (num_train_epochs >= 6 and
            host_epoch in (1, 2, 3, 4, 5)) and FLAGS.infeed:
      if FLAGS.enable_wus:
        optimizer = unbroadcast(optimizer)
        (optimizer, total_loss, lm_loss, sentence_loss, dropout_rngs,
         device_step) = p_train_epoch(optimizer, dropout_rngs,
                                      h_total_loss, h_lm_loss, h_sentence_loss,
                                      host_step, host_epoch, steps_per_epoch)
      else:
        device_epoch = jax_utils.replicate(host_epoch)
        device_steps_per_epoch = jax_utils.replicate(steps_per_epoch)

        (optimizer, total_loss, lm_loss, sentence_loss, dropout_rngs,
         device_step) = p_train_epoch(optimizer, dropout_rngs,
                                      d_total_loss, d_lm_loss, d_sentence_loss,
                                      device_step, device_epoch,
                                      device_steps_per_epoch)
    # After the first epoch, revert steps_per_epoch to the normal number.
    steps_per_epoch = num_steps_per_epoch

    # Training for one epoch.
    while int(host_step // FLAGS.num_steps_per_epoch) == host_epoch:
      input_data = next(train_iterator)
      input_data = jax.tree_map(lambda x: x.numpy(), input_data)
      input_data = jax.tree_map(common_utils.shard, input_data)
      input_ids = input_data['input_ids']
      input_mask = input_data['input_mask']
      segment_ids = input_data['segment_ids']
      masked_lm_positions = input_data['masked_lm_positions']
      masked_lm_ids = input_data['masked_lm_ids']
      masked_lm_weights = input_data['masked_lm_weights']
      next_sentence_labels = input_data['next_sentence_labels']

      # Infeed data to infeed queue.
      if FLAGS.infeed:
        for i, device in enumerate(jax.local_devices()):
          infeed_pool.submit(
              partial(device.transfer_to_infeed,
                      (input_ids[i], input_mask[i], segment_ids[i],
                       masked_lm_positions[i], masked_lm_ids[i],
                       masked_lm_weights[i], next_sentence_labels[i])))
      else:
        inputs = [input_ids, input_mask, segment_ids, masked_lm_positions]
        labels = [masked_lm_ids, masked_lm_weights, next_sentence_labels]
        (optimizer, total_loss, lm_loss, sentence_loss, dropout_rngs
         ) = p_train_step(optimizer, inputs, labels, dropout_rng=dropout_rngs)
      host_step += 1

    mlp_log.mlperf_print('block_stop', None, metadata={
        'first_epoch_num': block_step,
        'epoch_count': FLAGS.num_steps_per_epoch
    })
    # No need to eval in the first 5 epochs, since the run has to traverse at
    # least 3M samples first.
    if host_epoch < 5:
      continue
    if host_step % FLAGS.num_steps_per_epoch == 0:
      mlp_log.mlperf_print(
          'eval_start', None, metadata={'epoch_num': host_step})
      optimizer_target = optimizer.target
      if FLAGS.enable_wus and FLAGS.infeed:
        optimizer_target = unbroadcast(optimizer_target)
      metrics = empty_metrics()
      for _ in range(FLAGS.max_eval_steps):
        inputs = jax.tree_map(lambda x: x.numpy(), next(eval_iterator))
        inputs = jax.tree_map(common_utils.shard, inputs)
        # Weight update sharding is not implemented yet for host train loop.
        # Enable wus on eval only if device loop is used.
        metrics = p_eval_step(optimizer_target, inputs, metrics)
      metrics = allreduce_metrics(metrics)
      train_metrics = {'total_loss': total_loss, 'lm_loss': lm_loss,
                       'sentence_loss': sentence_loss}
      # masked_lm_accuracy = get_masked_lm_accuracy(metrics)
      summary_thread.submit(partial(
          _write_metrics, metrics, train_metrics,
          host_step, total_training_steps, host_id))
    if host_step % FLAGS.num_steps_per_epoch == 0 and FLAGS.save_checkpoint:
      if host_id == 0:
        checkpoints.save_checkpoint(
            FLAGS.model_dir, optimizer, host_step, prefix='checkpoint', keep=1)
  allreduce_metrics(metrics)['masked_lm_weighted_correct'].block_until_ready()
  summary_thread.shutdown()
  if not RUN_STOP:
    mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
  mlp_log.mlperf_print('run_final', None)

  if host_id == 0:
    if FLAGS.end_to_end_profile:
      xprof_url = xprof.end_session_and_get_url(tag='')
      logging.info('Xprof profile is at %s', xprof_url)


  if RUN_STOP:
    result_stats['total_time'] = RUN_STOP - run_start
    result_stats['total_steps'] = TOTAL_STEPS
  return optimizer, result_stats
Example #26
    def run_exp():
      mlp_log.mlperf_print('cache_clear', None)
      mlp_log.mlperf_print('init_start', None)
      mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
      mlp_log.mlperf_print('opt_learning_rate_warmup_steps', FLAGS.warmup_steps)
      mlp_log.mlperf_print('num_warmup_steps', FLAGS.warmup_steps)
      mlp_log.mlperf_print('start_warmup_step', FLAGS.start_warmup_step)
      mlp_log.mlperf_print('opt_lamb_weight_decay_rate', FLAGS.lamb_weight_decay)

      mlp_log.mlperf_print('max_sequence_length', FLAGS.max_seq_length)
      mlp_log.mlperf_print('opt_base_learning_rate', FLAGS.learning_rate)
      mlp_log.mlperf_print('opt_lamb_beta_1', FLAGS.lamb_beta_1)
      mlp_log.mlperf_print('opt_lamb_beta_2', FLAGS.lamb_beta_2)
      mlp_log.mlperf_print('opt_lamb_learning_rate_decay_poly_power', 1)
      mlp_log.mlperf_print('opt_gradient_accumulation_steps', 0)
      mlp_log.mlperf_print('max_predictions_per_seq',
                           FLAGS.max_predictions_per_seq)
      mlp_log.mlperf_print('opt_epsilon', 10**FLAGS.log_epsilon)
      mlp_log.mlperf_print('opt_learning_rate_training_steps',
                           FLAGS.total_training_steps)
      mlp_log.mlperf_print('submission_benchmark', 'bert')
      mlp_log.mlperf_print('submission_division', 'closed')
      mlp_log.mlperf_print('submission_org', 'google')
      mlp_log.mlperf_print('submission_platform',
                           'tpu-v3-%d' % jax.device_count())
      mlp_log.mlperf_print('submission_status', 'research')

      jax_model, model_kwargs = get_pretrain_model()
      optimizer = create_optimizer(jax_model, model_kwargs, learning_rate=None)
      _, result_stats = run_pretrain(optimizer)
      return result_stats
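One detail worth noting: the LAMB epsilon is passed around as a base-10 exponent (FLAGS.log_epsilon) and only converted when it is logged. A quick sketch of the conversion; the value below is made up for illustration.

# If log_epsilon were -6, the logged opt_epsilon would be 1e-6.
log_epsilon = -6   # illustrative; the real value comes from FLAGS.log_epsilon
opt_epsilon = 10 ** log_epsilon
assert opt_epsilon == 1e-6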
Example #27
def main(argv):
    del argv  # Unused.

    params = construct_run_config(FLAGS.iterations_per_loop)
    mlp_log.mlperf_print(key='cache_clear', value=True)
    mlp_log.mlperf_print(key='init_start', value=None)
    mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         params['base_learning_rate'])
    mlp_log.mlperf_print(
        'opt_learning_rate_decay_boundary_epochs',
        [params['first_lr_drop_epoch'], params['second_lr_drop_epoch']])
    mlp_log.mlperf_print('opt_weight_decay', params['weight_decay'])
    mlp_log.mlperf_print(
        'model_bn_span', FLAGS.train_batch_size // FLAGS.num_shards *
        params['distributed_group_size'])
    mlp_log.mlperf_print('max_samples', ssd_constants.NUM_CROP_PASSES)
    mlp_log.mlperf_print('train_samples', FLAGS.num_examples_per_epoch)
    mlp_log.mlperf_print('eval_samples', FLAGS.eval_samples)

    params['batch_size'] = FLAGS.train_batch_size // FLAGS.num_shards
    input_partition_dims = FLAGS.input_partition_dims
    train_steps = FLAGS.num_epochs * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    eval_steps = int(math.ceil(FLAGS.eval_samples / FLAGS.eval_batch_size))
    runner = train_and_eval_runner.TrainAndEvalRunner(
        FLAGS.iterations_per_loop, train_steps, eval_steps, FLAGS.num_shards)

    train_input_fn = dataloader.SSDInputReader(
        FLAGS.training_file_pattern,
        params['transpose_input'],
        is_training=True,
        use_fake_data=FLAGS.use_fake_data,
        params=params)
    eval_input_fn = dataloader.SSDInputReader(
        FLAGS.validation_file_pattern,
        is_training=False,
        use_fake_data=FLAGS.use_fake_data,
        distributed_eval=True,
        count=eval_steps * FLAGS.eval_batch_size,
        params=params)

    def init_fn():
        tf.train.init_from_checkpoint(
            params['resnet_checkpoint'], {
                'resnet/': 'resnet%s/' % ssd_constants.RESNET_DEPTH,
            })

    runner.initialize(train_input_fn, eval_input_fn,
                      functools.partial(ssd_model.ssd_model_fn,
                                        params), FLAGS.train_batch_size,
                      FLAGS.eval_batch_size, input_partition_dims, init_fn)
    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)

    if FLAGS.run_cocoeval:
        # copybara:strip_begin
        q_in, q_out = REDACTEDprocess.get_user_data()
        processes = [
            REDACTEDprocess.Process(target=REDACTED_predict_post_processing)
            for _ in range(4)
        ]
        # copybara:strip_end_and_replace_begin
        # q_in = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
        # q_out = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
        # processes = [
        #     multiprocessing.Process(
        #         target=predict_post_processing, args=(q_in, q_out))
        #     for _ in range(self.num_multiprocessing_workers)
        # ]
        # copybara:replace_end
        for p in processes:
            p.start()

        def log_eval_results_fn():
            """Print out MLPerf log."""
            result = q_out.get()
            success = False
            while result[0] != _STOP:
                if not success:
                    steps_per_epoch = (FLAGS.num_examples_per_epoch //
                                       FLAGS.train_batch_size)
                    epoch = (result[0] +
                             FLAGS.iterations_per_loop) // steps_per_epoch
                    mlp_log.mlperf_print('eval_accuracy',
                                         result[1]['COCO/AP'],
                                         metadata={'epoch_num': epoch})
                    mlp_log.mlperf_print('eval_stop',
                                         None,
                                         metadata={'epoch_num': epoch})
                    if result[1]['COCO/AP'] > ssd_constants.EVAL_TARGET:
                        success = True
                        mlp_log.mlperf_print('run_stop',
                                             None,
                                             metadata={'status': 'success'})
                result = q_out.get()
            if not success:
                mlp_log.mlperf_print('run_stop',
                                     None,
                                     metadata={'status': 'abort'})

        log_eval_result_thread = threading.Thread(target=log_eval_results_fn)
        log_eval_result_thread.start()

    def eval_init_fn(cur_step):
        """Executed before every eval."""
        steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
        epoch = cur_step // steps_per_epoch
        mlp_log.mlperf_print(
            'block_start',
            None,
            metadata={
                'first_epoch_num': epoch,
                'epoch_count': FLAGS.iterations_per_loop // steps_per_epoch
            })
        mlp_log.mlperf_print(
            'eval_start',
            None,
            metadata={
                'epoch_num': epoch + FLAGS.iterations_per_loop // steps_per_epoch
            })

    def eval_finish_fn(cur_step, eval_output, _):
        steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
        epoch = cur_step // steps_per_epoch
        mlp_log.mlperf_print(
            'block_stop',
            None,
            metadata={
                'first_epoch_num': epoch,
                'epoch_count': FLAGS.iterations_per_loop // steps_per_epoch
            })
        if FLAGS.run_cocoeval:
            q_in.put((cur_step, eval_output['detections']))

    runner.train_and_eval(eval_init_fn, eval_finish_fn)

    if FLAGS.run_cocoeval:
        for _ in processes:
            q_in.put((_STOP, None))

        for p in processes:
            try:
                p.join(timeout=10)
            except Exception:  #  pylint: disable=broad-except
                pass

        q_out.put((_STOP, None))
        log_eval_result_thread.join()

        # Clear out all the queues to avoid deadlock.
        while not q_out.empty():
            q_out.get()
        while not q_in.empty():
            q_in.get()
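The COCO-eval plumbing above relies on a sentinel value (_STOP) pushed into each queue, plus a final drain so that no process blocks on an unread queue. Below is a self-contained sketch of that producer/consumer shutdown pattern, assuming nothing from the SSD code: the worker body, queue contents, and worker count are all placeholders.

import multiprocessing

_STOP = -1  # sentinel marking end of work; any unambiguous value works


def _worker(q_in, q_out):
  """Consumes (step, payload) items until the sentinel arrives."""
  while True:
    step, payload = q_in.get()
    if step == _STOP:
      break
    q_out.put((step, payload * 2))  # stand-in for real post-processing


if __name__ == '__main__':
  q_in = multiprocessing.Queue()
  q_out = multiprocessing.Queue()
  workers = [
      multiprocessing.Process(target=_worker, args=(q_in, q_out))
      for _ in range(2)
  ]
  for w in workers:
    w.start()
  for step in range(4):
    q_in.put((step, step))
  for _ in workers:  # one sentinel per worker so every worker exits
    q_in.put((_STOP, None))
  results = [q_out.get() for _ in range(4)]  # consume results before joining
                                             # to avoid the Queue flush deadlock
  for w in workers:
    w.join()
  print(sorted(results))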
Example #28
def run_finish_fn(success):
  if not success:
    mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
  mlp_log.mlperf_print("run_final", None)
Example #29
def main(unused_argv):
    def eval_init_fn(cur_step):
        """Executed before every eval."""
        steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
        epoch = cur_step // steps_per_epoch
        mlp_log.mlperf_print('block_start',
                             None,
                             metadata={
                                 'first_epoch_num': epoch,
                                 'epoch_count': 4
                             })

    def eval_finish_fn(cur_step, eval_output, summary_writer):
        """Executed after every eval."""
        steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
        epoch = cur_step // steps_per_epoch
        eval_accuracy = float(np.sum(
            eval_output['total_correct'])) / FLAGS.num_eval_images

        if summary_writer:
            with tf.Graph().as_default():
                summary_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag='accuracy',
                                         simple_value=eval_accuracy)
                    ]), cur_step)
        mlp_log.mlperf_print(
            'eval_accuracy',
            eval_accuracy,
            metadata={
                'epoch_num': epoch + FLAGS.iterations_per_loop // steps_per_epoch
            })
        mlp_log.mlperf_print('block_stop',
                             None,
                             metadata={
                                 'first_epoch_num': epoch,
                                 'epoch_count': 4
                             })
        if eval_accuracy >= FLAGS.stop_threshold:
            mlp_log.mlperf_print('run_stop',
                                 None,
                                 metadata={'status': 'success'})
            return True
        else:
            return False

    def run_finish_fn(success):
        if not success:
            mlp_log.mlperf_print('run_stop',
                                 None,
                                 metadata={'status': 'abort'})
        mlp_log.mlperf_print('run_final', None)

    low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
        FLAGS.iterations_per_loop, FLAGS.train_steps,
        int(math.ceil(FLAGS.num_eval_images / FLAGS.eval_batch_size)),
        FLAGS.num_replicas)

    mlp_log.mlperf_print('cache_clear', True)
    mlp_log.mlperf_print('init_start', None)
    mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
    mlp_log.mlperf_print('lars_opt_weight_decay', FLAGS.weight_decay)
    mlp_log.mlperf_print('lars_opt_momentum', FLAGS.momentum)
    mlp_log.mlperf_print('submission_benchmark', 'resnet')
    mlp_log.mlperf_print('submission_division', 'closed')
    mlp_log.mlperf_print('submission_org', 'google')
    mlp_log.mlperf_print('submission_platform',
                         'tpu-v3-%d' % FLAGS.num_replicas)
    mlp_log.mlperf_print('submission_status', 'research')

    assert FLAGS.precision == 'bfloat16' or FLAGS.precision == 'float32', (
        'Invalid value for --precision flag; must be bfloat16 or float32.')
    input_dtype = tf.bfloat16 if FLAGS.precision == 'bfloat16' else tf.float32
    cache_decoded_image = FLAGS.num_replicas > 2048
    imagenet_train, imagenet_eval = [
        imagenet_input.get_input_fn(  # pylint: disable=g-complex-comprehension
            FLAGS.data_dir,
            is_training,
            input_dtype,
            FLAGS.image_size,
            FLAGS.input_partition_dims is None,
            cache_decoded_image=cache_decoded_image)
        for is_training in [True, False]
    ]

    low_level_runner.initialize(imagenet_train, imagenet_eval, resnet_model_fn,
                                FLAGS.train_batch_size, FLAGS.eval_batch_size,
                                FLAGS.input_partition_dims)

    mlp_log.mlperf_print('train_samples', FLAGS.num_train_images)
    mlp_log.mlperf_print('eval_samples', FLAGS.num_eval_images)
    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)
    low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn,
                                    run_finish_fn)
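Examples #27 and #29 both hand the same three callbacks to TrainAndEvalRunner.train_and_eval: eval_init_fn runs before every eval, eval_finish_fn runs after it and returns True to end the run early, and run_finish_fn receives the final success flag. Here is a bare-bones skeleton of that contract, assuming the same runner interface and that mlp_log is imported as in the examples above; the accuracy key and threshold handling are placeholders.

def make_callbacks(stop_threshold):
  """Builds the three callbacks train_and_eval expects."""

  def eval_init_fn(cur_step):
    # Called right before an eval block; typically logs 'block_start'.
    del cur_step  # unused in this sketch

  def eval_finish_fn(cur_step, eval_output, summary_writer):
    # Called after each eval; returning True stops training early.
    del cur_step, summary_writer  # unused in this sketch
    return eval_output.get('accuracy', 0.0) >= stop_threshold

  def run_finish_fn(success):
    # Called once at the end with the final success flag.
    if not success:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})
    mlp_log.mlperf_print('run_final', None)

  return eval_init_fn, eval_finish_fn, run_finish_fn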
Example #30
def run_pretraining(hparams):
  """Run pretraining with given hyperparameters."""

  global masked_lm_accuracy
  global run_steps

  masked_lm_accuracy = 0
  run_steps = 0

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    # BERT pretraining has no real epochs; to keep the logging consistent
    # with other MLPerf models, mlp_log reports steps as epochs and
    # sequences as examples.
    mlp_log.mlperf_print(
        "block_start",
        None,
        metadata={
            "first_epoch_num": cur_step + FLAGS.iterations_per_loop,
            "epoch_count": FLAGS.iterations_per_loop
        })

  def eval_finish_fn(cur_step, eval_output, summary_writer):
    """Executed after every eval."""
    global run_steps
    global masked_lm_accuracy
    cur_step_corrected = cur_step + FLAGS.iterations_per_loop
    run_steps = cur_step_corrected
    masked_lm_weighted_correct = eval_output["masked_lm_weighted_correct"]
    masked_lm_weighted_count = eval_output["masked_lm_weighted_count"]

    masked_lm_accuracy = np.sum(masked_lm_weighted_correct) / np.sum(
        masked_lm_weighted_count)
    # eval_output may mix up the order of the two arrays; if it did, the
    # ratio exceeds 1, so invert it to recover the correct accuracy.
    if masked_lm_accuracy > 1:
      masked_lm_accuracy = 1 / masked_lm_accuracy

    if summary_writer:
      with tf.Graph().as_default():
        summary_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="masked_lm_accuracy",
                                 simple_value=masked_lm_accuracy)
            ]), cur_step_corrected)

    mlp_log.mlperf_print(
        "block_stop",
        None,
        metadata={
            "first_epoch_num": cur_step_corrected,
        })
    # BERT pretraining has no real epochs; to keep the logging consistent
    # with other MLPerf models, mlp_log reports steps as epochs and
    # sequences as examples.
    mlp_log.mlperf_print(
        "eval_accuracy",
        float(masked_lm_accuracy),
        metadata={"epoch_num": cur_step_corrected})
    if (masked_lm_accuracy >= FLAGS.stop_threshold and
        cur_step_corrected >= FLAGS.iterations_per_loop * 6):
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
      return True
    else:
      return False

  def run_finish_fn(success):
    if not success:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
    mlp_log.mlperf_print("run_final", None)

  def init_fn():
    if FLAGS.init_checkpoint:
      tf.train.init_from_checkpoint(FLAGS.init_checkpoint, {
          "bert/": "bert/",
          "cls/": "cls/",
      })

  # Passing the hyperparameters
  if "learning_rate" in hparams:
    FLAGS.learning_rate = hparams.learning_rate
  if "lamb_weight_decay_rate" in hparams:
    FLAGS.lamb_weight_decay_rate = hparams.lamb_weight_decay_rate
  if "lamb_beta_1" in hparams:
    FLAGS.lamb_beta_1 = hparams.lamb_beta_1
  if "lamb_beta_2" in hparams:
    FLAGS.lamb_beta_2 = hparams.lamb_beta_2
  if "epsilon" in hparams:
    FLAGS.epsilon = hparams.epsilon
  if "num_warmup_steps" in hparams:
    FLAGS.num_warmup_steps = hparams.num_warmup_steps
  if "num_train_steps" in hparams:
    FLAGS.num_train_steps = hparams.num_train_steps

  # Input handling
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.repeatable:
    tf.set_random_seed(123)

  if not FLAGS.do_train and not FLAGS.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  train_input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    train_input_files.extend(tf.gfile.Glob(input_pattern))

  eval_input_file = "/REDACTED/je-d/home/staging-REDACTED-gpu-dedicated/bert/eval_original_dataset/part-*"
  eval_input_files = []
  for input_pattern in eval_input_file.split(","):
    eval_input_files.extend(tf.gfile.Glob(input_pattern))

  tf.logging.info("*** Input Files ***")
  tf.logging.info("%s Files." % len(train_input_files))

  dataset_train = dataset_input.input_fn_builder(
      input_files=train_input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=True,
      num_cpu_threads=8)

  dataset_eval = dataset_input.input_fn_builder(
      input_files=eval_input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=False,
      num_cpu_threads=8,
      num_eval_samples=FLAGS.num_eval_samples)

  # Create the low level runner
  low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.iterations_per_loop, FLAGS.stop_steps + 1, FLAGS.max_eval_steps,
      FLAGS.num_tpu_cores // FLAGS.num_partitions)

  mlp_log.mlperf_print("cache_clear", True)
  mlp_log.mlperf_print("init_start", None)
  mlp_log.mlperf_print("global_batch_size", FLAGS.train_batch_size)
  mlp_log.mlperf_print("opt_learning_rate_warmup_steps", FLAGS.num_warmup_steps)
  mlp_log.mlperf_print("num_warmup_steps", FLAGS.num_warmup_steps)
  mlp_log.mlperf_print("start_warmup_step", FLAGS.start_warmup_step)
  mlp_log.mlperf_print("max_sequence_length", FLAGS.max_seq_length)
  mlp_log.mlperf_print("opt_base_learning_rate", FLAGS.learning_rate)
  mlp_log.mlperf_print("opt_lamb_beta_1", FLAGS.lamb_beta_1)
  mlp_log.mlperf_print("opt_lamb_beta_2", FLAGS.lamb_beta_2)
  mlp_log.mlperf_print("opt_epsilon", 10 ** FLAGS.log_epsilon)
  mlp_log.mlperf_print("opt_learning_rate_training_steps",
                       FLAGS.num_train_steps)
  mlp_log.mlperf_print("opt_lamb_weight_decay_rate",
                       FLAGS.lamb_weight_decay_rate)
  mlp_log.mlperf_print("opt_lamb_learning_rate_decay_poly_power", 1)
  mlp_log.mlperf_print("opt_gradient_accumulation_steps", 0)
  mlp_log.mlperf_print("max_predictions_per_seq", FLAGS.max_predictions_per_seq)

  low_level_runner.initialize(
      dataset_train,
      dataset_eval,
      bert_model_fn,
      FLAGS.train_batch_size,
      FLAGS.eval_batch_size,
      input_partition_dims=None,
      init_fn=init_fn,
      train_has_labels=False,
      eval_has_labels=False,
      num_partitions=FLAGS.num_partitions)

  mlp_log.mlperf_print("init_stop", None)

  mlp_log.mlperf_print("run_start", None)

  # To keep the logging consistent with other MLPerf models, mlp_log reports
  # steps as epochs and sequences as examples.
  mlp_log.mlperf_print("train_samples",
                       FLAGS.num_train_steps * FLAGS.train_batch_size)
  mlp_log.mlperf_print("eval_samples", FLAGS.num_eval_samples)
  low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
  return masked_lm_accuracy, run_steps
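The masked-LM accuracy computed in eval_finish_fn is a ratio of weighted sums, guarded by a reciprocal in case the two output arrays arrive swapped. A small numpy sketch of the same computation; the array values are invented for illustration.

import numpy as np


def masked_lm_accuracy_from_sums(weighted_correct, weighted_count):
  """Weighted accuracy with the reciprocal guard used above."""
  accuracy = np.sum(weighted_correct) / np.sum(weighted_count)
  # If the arrays were swapped, the ratio exceeds 1; invert to recover it.
  if accuracy > 1:
    accuracy = 1 / accuracy
  return accuracy


# Per-host sums of correct masked-LM predictions and prediction counts.
correct = np.array([180.0, 175.0])
count = np.array([256.0, 256.0])
print(masked_lm_accuracy_from_sums(correct, count))  # ~0.693
print(masked_lm_accuracy_from_sums(count, correct))  # same value after the guard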