Code example #1
    def __init__(self):
        fk = ssd_constants.IMAGE_SIZE / np.array(ssd_constants.STEPS)

        self.default_boxes = []
        # per-layer feature map size and number of default box shapes
        for idx, feature_size in enumerate(ssd_constants.FEATURE_SIZES):
            sk1 = ssd_constants.SCALES[idx] / ssd_constants.IMAGE_SIZE
            sk2 = ssd_constants.SCALES[idx + 1] / ssd_constants.IMAGE_SIZE
            sk3 = math.sqrt(sk1 * sk2)
            all_sizes = [(sk1, sk1), (sk3, sk3)]

            for alpha in ssd_constants.ASPECT_RATIOS[idx]:
                w, h = sk1 * math.sqrt(alpha), sk1 / math.sqrt(alpha)
                all_sizes.append((w, h))
                all_sizes.append((h, w))

            assert len(all_sizes) == ssd_constants.NUM_DEFAULTS[idx]

            for w, h in all_sizes:
                for i, j in it.product(range(feature_size), repeat=2):
                    cx, cy = (j + 0.5) / fk[idx], (i + 0.5) / fk[idx]
                    box = tuple(np.clip(k, 0, 1) for k in (cy, cx, h, w))
                    self.default_boxes.append(box)

        assert len(self.default_boxes) == ssd_constants.NUM_SSD_BOXES
        mlp_log.mlperf_print('max_samples', ssd_constants.NUM_SSD_BOXES)

        def to_ltrb(cy, cx, h, w):
            return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2

        # For IoU calculation
        self.default_boxes_ltrb = tuple(
            to_ltrb(*i) for i in self.default_boxes)
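The constructor above depends on constants from ssd_constants that are not shown in this excerpt. As a point of reference only, a minimal sketch of plausible values, assuming the standard SSD300 configuration (the actual module may differ), under which the NUM_SSD_BOXES assertion holds:

# Hypothetical ssd_constants values (assumed SSD300 defaults, not taken from
# this excerpt). With these, sum(f * f * n for f, n in
# zip(FEATURE_SIZES, NUM_DEFAULTS)) == 8732, which is what the assertion checks.
IMAGE_SIZE = 300
FEATURE_SIZES = (38, 19, 10, 5, 3, 1)     # spatial size of each feature map
STEPS = (8, 16, 32, 64, 100, 300)         # IMAGE_SIZE / STEPS ~= FEATURE_SIZES
SCALES = (21, 45, 99, 153, 207, 261, 315)
ASPECT_RATIOS = ((2,), (2, 3), (2, 3), (2, 3), (2,), (2,))
NUM_DEFAULTS = (4, 6, 6, 6, 4, 4)         # 2 square boxes + 2 per aspect ratio
NUM_SSD_BOXES = 8732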
Code example #2
  def train_and_eval(self,
                     train_steps,
                     local_batch_size,  # pylint: disable=unused-argument
                     num_threads=2):  # pylint: disable=unused-argument
    """Run the training loop on the TPU device."""
    tf.logging.info("LowLevelRunner: train for %d steps in total.",
                    train_steps)

    if train_steps % self.iterations != 0:
      tf.logging.warning(
          "train_steps %d is not divisible by iterations_per_loop %d",
          train_steps, self.iterations)
      train_steps = self.iterations * int(
          math.ceil(train_steps / self.iterations))

    # Train and eval/predict thread.
    def train_eval_thread_fn(sess, train_eval_op, steps):
      sess.run([train_eval_op],
               feed_dict={self.num_epochs_tensor: steps,
                          self.train_steps_tensor: self.iterations,
                          self.eval_steps_tensor: self.eval_steps})

    self.train_eval_thread = threading.Thread(
        target=train_eval_thread_fn,
        args=(self.sess, self.train_eval_op, train_steps // self.iterations))
    self.train_eval_thread.start()

    # Infeed thread.
    def infeed_thread_fn(sess,
                         eval_sess,
                         enqueue_ops,
                         eval_enqueue_ops,
                         eval_dataset_initializer):
      """Build and infeed session.run calls in a background thread."""
      for i in range(train_steps // self.iterations):
        mlp_log.mlperf_print(
            "block_start",
            None,
            metadata={
                "first_epoch_num": i + 1,
                "epoch_count": 1
            })
        tf.logging.info(
            "Start to infeed %d batches for training of epoch %d.",
            self.iterations, i)
        sess.run([enqueue_ops])
        eval_sess.run(eval_dataset_initializer)
        eval_sess.run([eval_enqueue_ops])

    self.infeed_thread = threading.Thread(
        target=infeed_thread_fn,
        args=(self.input_sess,
              self.eval_input_sess,
              self.enqueue_ops,
              self.eval_enqueue_ops,
              self.eval_dataset_initializer))
    time.sleep(240)
    mlp_log.mlperf_print(key="init_stop", value=None)
    mlp_log.mlperf_print(key="run_start", value=None)
    self.infeed_thread.start()
Code example #3
def run_main(flags, default_hparams, estimator_fn):
    """Run main."""
    # Job
    jobid = flags.jobid
    utils.print_out("# Job id %d" % jobid)

    # Random
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        utils.print_out("# Set random seed to %d" % random_seed)
        random.seed(random_seed + jobid)
        np.random.seed(random_seed + jobid)
        tf.set_random_seed(random_seed)

    # Model output directory
    out_dir = flags.out_dir
    if out_dir and not tf.gfile.Exists(out_dir):
        utils.print_out("# Creating output directory %s ..." % out_dir)
        tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    hparams = create_or_load_hparams(default_hparams, flags.hparams_path)

    # TODO(dehao) move init time closer to model construction if necessary.
    mlp_log.mlperf_print("init_start", None)

    # Train or Evaluation
    return estimator_fn(hparams)
Code example #4
def learning_rate_schedule(params, global_step):
    """Handles learning rate scaling, linear warmup, and learning rate decay.

  Args:
    params: A dictionary that defines hyperparameters of model.
    global_step: A tensor representing current global step.

  Returns:
    A tensor representing current learning rate.
  """
    base_learning_rate = params['base_learning_rate']
    lr_warmup_step = params['lr_warmup_step']
    first_lr_drop_step = params['first_lr_drop_step']
    second_lr_drop_step = params['second_lr_drop_step']
    batch_size = (params['batch_size'] * params['num_shards']
                  if params['use_tpu'] else params['batch_size'])
    scaling_factor = batch_size / ssd_constants.DEFAULT_BATCH_SIZE
    mlp_log.mlperf_print('opt_learning_rate_warmup_factor', scaling_factor)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps', lr_warmup_step)
    adjusted_learning_rate = base_learning_rate * scaling_factor
    learning_rate = (tf.cast(global_step, dtype=tf.float32) /
                     lr_warmup_step) * adjusted_learning_rate
    lr_schedule = [[1.0, lr_warmup_step], [0.1, first_lr_drop_step],
                   [0.01, second_lr_drop_step]]
    for mult, start_global_step in lr_schedule:
        learning_rate = tf.where(global_step < start_global_step,
                                 learning_rate, adjusted_learning_rate * mult)
    return learning_rate
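The schedule above is built as a TF graph computation, but its piecewise behavior is easy to check in plain Python. A minimal sketch of the same logic, with illustrative hyperparameter values that are assumptions rather than the reference configuration:

# Pure-Python sketch of the warmup-then-step-decay schedule above
# (hyperparameter values are illustrative assumptions).
def lr_at_step(step, base_lr=1e-3, scaling_factor=1.0,
               warmup=1000, drop1=40000, drop2=50000):
    adjusted = base_lr * scaling_factor
    lr = (step / warmup) * adjusted              # linear warmup from 0
    for mult, start in [(1.0, warmup), (0.1, drop1), (0.01, drop2)]:
        if step >= start:                        # mirrors tf.where(step < start, ...)
            lr = adjusted * mult
    return lr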
Code example #5
def learning_rate_schedule(current_epoch):
  """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then it increases linearly per step.
  After 5 epochs we reach the base learning rate (scaled to account
    for batch size).
  After 30, 60 and 80 epochs the learning rate is divided by 10.
  After 90 epochs training stops and the LR is set to 0. This ensures
    that we train for exactly 90 epochs for reproducibility.

  Args:
    current_epoch: `Tensor` for current epoch.

  Returns:
    A scaled `Tensor` for current learning rate.
  """
  mlp_log.mlperf_print('base_learning_rate', FLAGS.base_learning_rate)
  scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)

  decay_rate = (scaled_lr * LR_SCHEDULE[0][0] *
                current_epoch / LR_SCHEDULE[0][1])
  for mult, start_epoch in LR_SCHEDULE:
    decay_rate = tf.where(current_epoch < start_epoch,
                          decay_rate, scaled_lr * mult)
  return decay_rate
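LR_SCHEDULE itself is not part of this excerpt. Based on the docstring (base rate reached after 5 epochs, then divided by 10 after 30, 60 and 80 epochs), it is presumably a list of (multiplier, start_epoch) pairs along the following lines; treat the exact values as an assumption:

# Assumed shape of LR_SCHEDULE, reconstructed from the docstring above:
# warmup until epoch 5, then LR/10 at epochs 30, 60 and 80.
LR_SCHEDULE = [
    (1.0, 5),
    (0.1, 30),
    (0.01, 60),
    (0.001, 80),
]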
Code example #6
 def train(self):
     """Run the train loops and write a summary to directory."""
     mlp_log.mlperf_print(key='init_stop', value=None)
     mlp_log.mlperf_print(key='run_start', value=None)
     if self.use_tpu_estimator:
         self.runner.train(input_fn=self.input_fn,
                           max_steps=self.params['total_steps'])
     else:
         self.runner.train()
Code example #7
def poly_rate_schedule(current_epoch, poly_rate=0.0):
    """Handles linear scaling rule, gradual warmup, and LR decay.

  The learning rate starts at 0, then it increases linearly per step.  After
  FLAGS.poly_warmup_epochs, we reach the base learning rate (scaled to account
  for batch size). The learning rate is then decayed using a polynomial rate
  decay schedule with power 2.0.

  Args:
    current_epoch: `Tensor` for current epoch.
    poly_rate: Polynomial decay rate.

  Returns:
    A scaled `Tensor` for current learning rate.
  """

    batch_size = FLAGS.train_batch_size
    if batch_size < 16384:
        plr = 10.0
        w_epochs = 5
    elif batch_size < 32768:
        plr = 25.0
        w_epochs = 5
    else:
        plr = 33.0
        w_epochs = 25

    # Override default poly learning rate and warmup epochs
    if poly_rate > 0.0:
        plr = poly_rate

    if FLAGS.lars_base_learning_rate > 0.0:
        plr = FLAGS.lars_base_learning_rate

    if FLAGS.lars_warmup_epochs > 0:
        w_epochs = FLAGS.lars_warmup_epochs

    mlp_log.mlperf_print('opt_base_learning_rate', plr)
    mlp_log.mlperf_print('opt_learning_rate_warmup_epochs', w_epochs)
    mlp_log.mlperf_print('lars_opt_end_learning_rate', 0.0001)

    wrate = (plr * current_epoch / w_epochs)
    w_steps = (w_epochs * FLAGS.num_train_images // batch_size)
    min_step = tf.constant(1, dtype=tf.int64)
    global_step = tf.train.get_or_create_global_step()
    decay_steps = tf.maximum(min_step, tf.subtract(global_step, w_steps))

    mlp_log.mlperf_print('lars_opt_learning_rate_decay_steps',
                         FLAGS.train_steps - w_steps + 1)
    mlp_log.mlperf_print('lars_opt_learning_rate_decay_poly_power', 2.0)

    poly_rate = tf.train.polynomial_decay(plr,
                                          decay_steps,
                                          FLAGS.train_steps - w_steps + 1,
                                          power=2.0)
    decay_rate = tf.where(current_epoch <= w_epochs, wrate, poly_rate)
    return decay_rate
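For reference, the polynomial_decay call above passes the clipped step tensor as the "global_step" argument and FLAGS.train_steps - w_steps + 1 as the decay horizon; with tf.train.polynomial_decay's default end learning rate of 0.0001 (which matches the lars_opt_end_learning_rate value logged above), the post-warmup rate reduces to roughly the following sketch:

# Sketch of the post-warmup LARS learning rate (end_lr=1e-4 is the
# tf.train.polynomial_decay default; a reading aid, not the reference
# implementation).
def lars_poly_lr(step, plr, w_steps, train_steps, end_lr=1e-4, power=2.0):
    t = max(1, step - w_steps)                   # the clipped decay_steps tensor
    horizon = train_steps - w_steps + 1
    frac = min(t, horizon) / horizon             # polynomial_decay clips at the horizon
    return (plr - end_lr) * (1.0 - frac) ** power + end_lr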
Code example #8
def init_lars_optimizer(current_epoch):
    """Initialize the LARS Optimizer."""

    mlp_log.mlperf_print('lars_opt_weight_decay', FLAGS.weight_decay)
    mlp_log.mlperf_print('lars_epsilon', 0.0)

    learning_rate = poly_rate_schedule(current_epoch, FLAGS.poly_rate)
    optimizer = tf.contrib.opt.LARSOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        skip_list=['batch_normalization', 'bias'])
    return optimizer
Code example #9
def learning_rate_schedule(hparams):
    """Learning rate schedule based on hparams."""
    mlp_log.mlperf_print(key="opt_base_learning_rate",
                         value=hparams.learning_rate_constant)
    mlp_log.mlperf_print(key="opt_learning_rate_warmup_steps",
                         value=hparams.learning_rate_warmup_steps)
    step_num = _global_step(hparams)
    schedule_string = hparams.learning_rate_schedule
    names = schedule_string.split("*")
    names = [name.strip() for name in names if name.strip()]
    ret = tf.constant(1.0)
    for name in names:
        ret *= learning_rate_factor(name, step_num, hparams)
    return ret
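The schedule string is multiplicative: each "*"-separated name becomes a per-step factor, and the factors are multiplied together. A small illustration with an assumed Transformer-style schedule string (the actual hparams value may differ):

# Illustrative only: how an assumed schedule string is consumed.
schedule_string = "constant*linear_warmup*rsqrt_decay"   # assumption
names = [n.strip() for n in schedule_string.split("*") if n.strip()]
# names == ["constant", "linear_warmup", "rsqrt_decay"]; the learning rate is
# the product of the factors returned by learning_rate_factor() for each name.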
Code example #10
File: lars_util.py  Project: shawwn/daxx-lightning
def init_lars_optimizer(current_epoch):
    """Initialize the LARS Optimizer."""

    lars_epsilon = FLAGS.lars_epsilon
    mlp_log.mlperf_print('lars_opt_weight_decay', FLAGS.weight_decay)
    mlp_log.mlperf_print('lars_epsilon', lars_epsilon)

    learning_rate = get_lars_lr(current_epoch)
    optimizer = tf.contrib.opt.LARSOptimizer(
        learning_rate,
        momentum=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        skip_list=['batch_normalization', 'bias'],
        epsilon=lars_epsilon)
    return optimizer, learning_rate
Code example #11
 def infeed_thread_fn(sess, eval_sess, enqueue_ops, eval_enqueue_ops,
                      eval_dataset_initializer):
     """Build and infeed session.run calls in a background thread."""
     for i in range(train_steps // self.iterations):
         mlp_log.mlperf_print("block_start",
                              None,
                              metadata={
                                  "first_epoch_num": i + 1,
                                  "epoch_count": 1
                              })
         tf.logging.info(
             "Start to infeed %d batches for training of epoch %d.",
             self.iterations, i)
         sess.run([enqueue_ops])
         eval_sess.run(eval_dataset_initializer)
         eval_sess.run([eval_enqueue_ops])
Code example #12
        def log_eval_result_fn(results):
            """Log eval results."""
            cur_step, eval_results = results
            if cur_step == _STOP:
                return
            epoch = cur_step // self.params["steps_per_epoch"]
            with tf.Graph().as_default():
                summaries = []
                for metric in eval_results:
                    summaries.append(
                        tf.Summary.Value(tag=metric,
                                         simple_value=eval_results[metric]))
                    tf_summary = tf.Summary(value=list(summaries))
                    summary_writer.add_summary(tf_summary, cur_step)
                mlp_log.mlperf_print("eval_accuracy",
                                     eval_results["COCO/AP"],
                                     metadata={"epoch_num": epoch + 1})
                mlp_log.mlperf_print("eval_stop",
                                     None,
                                     metadata={"epoch_num": epoch + 1})

                if epoch in self.epoch_count:
                    epoch_count = self.epoch_count[epoch]
                else:
                    epoch_count = 1

                mlp_log.mlperf_print("block_stop",
                                     None,
                                     metadata={
                                         "first_epoch_num":
                                         epoch - epoch_count + 1,
                                         "epoch_count": epoch_count
                                     })

                self.log_epochs[epoch] = True
                if eval_results["COCO/AP"] >= ssd_constants.EVAL_TARGET:
                    self.run_success = True
                    if epoch < self.success_epoch:
                        self.success_epoch = epoch
                log_run_final = self.run_success
                for epoch in self.log_epochs:
                    if epoch < self.success_epoch and not self.log_epochs[
                            epoch]:
                        log_run_final = False
                        break
                # Log run_final when all the previous eval results are logged.
                if log_run_final and not self.log_run_success:
                    mlp_log.mlperf_print("run_stop",
                                         None,
                                         metadata={"status": "success"})
                    self.log_run_success = True
Code example #13
 def get_predict_results(self, cur_epoch):
     """Run the predict loop on the TPU device."""
     for step in range(self.eval_steps):
         tf.logging.info(
             "TrainAndEvalLowLevelRunner: reading eval step %d results",
             step)
         predictions = {name: [] for name in self.outfeed_names}
         for outfeed_dict in self.train_eval_sess.run(self.dequeue_ops):
             for name, tensors in six.iteritems(outfeed_dict):
                 predictions[name].extend(tensors)
         if step == self.eval_steps - 1:
             # All predictions have been read from the device; async eval
             # post-processing starts, and the next training on the device also starts.
             mlp_log.mlperf_print("block_stop",
                                  None,
                                  metadata={
                                      "first_epoch_num": cur_epoch,
                                      "epoch_count": 1
                                  })
             mlp_log.mlperf_print("eval_start",
                                  None,
                                  metadata={"epoch_num": cur_epoch})
             tf.logging.info(
                 "TrainAndEvalLowLevelRunner: start eval epoch %d.",
                 cur_epoch)
             mlp_log.mlperf_print("block_start",
                                  None,
                                  metadata={
                                      "first_epoch_num": cur_epoch + 1,
                                      "epoch_count": 1
                                  })
         yield predictions
Code example #14
    def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
        """Runs evaluator."""
        step = np.asscalar(run_context.session.run(self._global_step_tensor))

        if self._timer.should_trigger_for_step(step):
            logging.info('Starting eval.')
            eval_results = self._evaluate(run_context.session, step)
            mlp_log.mlperf_print('eval_accuracy',
                                 float(eval_results[_EVAL_METRIC]),
                                 metadata={
                                     'epoch_num':
                                     max(step // self._steps_per_epoch - 1, 0)
                                 })

            # The ImageNet eval size is hard coded.
            if eval_results[_EVAL_METRIC] >= self._stop_threshold:
                self._run_success = True
                mlp_log.mlperf_print('run_stop',
                                     None,
                                     metadata={'status': 'success'})
                mlp_log.mlperf_print('run_final', None)
                run_context.request_stop()

        if step // self._steps_per_epoch == self._eval_every_epoch_from:
            self._timer = training.SecondOrStepTimer(
                every_steps=self._steps_per_epoch)
            self._timer.reset()
Code example #15
    def post_processing_thread_fn():
      """Run post-processing on CPU for predictions."""
      mlp_log.mlperf_print(
          "block_start", None, metadata={"first_epoch_num": 0,
                                         "epoch_count": 1})
      for cur_epoch in range(self.total_epoch):

        eval_begin = time.time()
        # Enables multi-processing to accelerate post-processing.
        eval_multiprocess.eval_multiprocessing(
            self.eval_steps, self.get_predict_results(cur_epoch),
            self.eval_metric, self.eval_params["eval_worker_count"])

        pred_end = time.time()
        tf.logging.info("prediction takes %d seconds.", pred_end - eval_begin)

        num_eval_samples, eval_results = self.eval_metric.evaluate()
        eval_end = time.time()
        tf.logging.info("COCO evaluates %d samples", num_eval_samples)
        if num_eval_samples != self.eval_params["eval_samples"]:
          tf.logging.info("COCO fails to evaluate all %d samples, exit!" %
                          self.eval_params["eval_samples"])
          self.run_success = False
          self.continue_train = False
          return
        tf.logging.info("one evaluation takes %d seconds",
                        eval_end - eval_begin)
        self.write_eval_summary(self.eval_summary_writer, eval_results,
                                cur_epoch * self.iterations_per_loop)
        tf.logging.info("AP: %s" % eval_results["AP"])
        tf.logging.info("mask_AP: %s" % eval_results["mask_AP"])
        # Eval epoch is 0-indexed (for MLPerf log parsing).
        mlp_log.mlperf_print(
            "eval_stop", None, metadata={"epoch_num": cur_epoch})
        # TODO(b/127959551): use both metrics once the bug is resolved.
        mlp_log.mlperf_print(
            "eval_accuracy", (float(eval_results["AP"]),
                              float(eval_results["mask_AP"])),
            metadata={"epoch_num": cur_epoch})

        if (eval_results["AP"] >= mask_rcnn_params.BOX_EVAL_TARGET and
            eval_results["mask_AP"] >= mask_rcnn_params.MASK_EVAL_TARGET):
          mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
          self.run_success = True
          self.continue_train = False
          return
Code example #16
    def infeed_thread_fn(sess, train_enqueue_ops, eval_enqueue_ops, eval_init):
      """Start the infeed."""
      time.sleep(300)

      mlp_log.mlperf_print("init_stop", None)
      mlp_log.mlperf_print("run_start", None)
      for i in range(self.hparams.max_train_epochs):
        tf.logging.info("Infeed for epoch: %d", i + 1)
        mlp_log.mlperf_print(
            "block_start",
            None,
            metadata={
                "first_epoch_num": i + 1,
                "epoch_count": 1
            })
        mlp_log.mlperf_print("epoch_start", None, metadata={"epoch_num": i + 1})
        sess.run(eval_init)
        sess.run([train_enqueue_ops])
        sess.run([eval_enqueue_ops])
Code example #17
    def __init__(self, optimizer_name, lr, hparams, use_tpu=False):  # pylint: disable=super-init-not-called
        tf.logging.info("Using optimizer %s", optimizer_name)

        mlp_log.mlperf_print(key="opt_name", value=optimizer_name)
        mlp_log.mlperf_print(key="opt_adam_beta_1",
                             value=hparams.optimizer_adam_beta1)
        mlp_log.mlperf_print(key="opt_adam_beta_2",
                             value=hparams.optimizer_adam_beta2)
        mlp_log.mlperf_print(key="opt_adam_epsilon",
                             value=hparams.optimizer_adam_epsilon)

        self._bfloat16_grads_all_reduce = hparams.bfloat16_grads_all_reduce

        if optimizer_name == "Adam":
            # We change the default epsilon for Adam.
            # Using LazyAdam as it's much faster for large vocabulary embeddings.
            self._opt = tf.contrib.opt.LazyAdamOptimizer(
                lr,
                beta1=hparams.optimizer_adam_beta1,
                beta2=hparams.optimizer_adam_beta2,
                epsilon=hparams.optimizer_adam_epsilon)
        elif optimizer_name == "Momentum":
            self._opt = tf.train.MomentumOptimizer(
                lr,
                momentum=hparams.optimizer_momentum_momentum,
                use_nesterov=hparams.optimizer_momentum_nesterov)
        elif optimizer_name == "TrueAdam":
            self._opt = tf.train.AdamOptimizer(
                lr,
                beta1=hparams.optimizer_adam_beta1,
                beta2=hparams.optimizer_adam_beta2,
                epsilon=hparams.optimizer_adam_epsilon)
        elif optimizer_name == "Adafactor":
            self._opt = adafactor.adafactor_optimizer_from_hparams(hparams, lr)
        # BEGIN GOOGLE-INTERNAL
        elif optimizer_name == "SM3":
            self._opt = sm3.SM3Optimizer(
                lr, momentum=hparams.optimizer_momentum_momentum)
        # END GOOGLE-INTERNAL
        else:
            self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](
                lr)
Code example #18
 def get_learning_rate(self, params, global_step):
   """Sets up learning rate schedule."""
   learning_rate = lr_policy.learning_rate_schedule(
       params['learning_rate'], params['lr_warmup_init'],
       params['lr_warmup_step'], params['first_lr_drop_step'],
       params['second_lr_drop_step'], global_step)
   mlp_log.mlperf_print(key='opt_base_learning_rate',
                        value=params['learning_rate'])
   mlp_log.mlperf_print(key='opt_learning_rate_warmup_steps',
                        value=params['lr_warmup_step'])
   mlp_log.mlperf_print(key='opt_learning_rate_warmup_factor',
                        value=params['lr_warmup_init']/params['learning_rate'])
   return learning_rate
Code example #19
    def end(self, session):  # pylint: disable=unused-argument
        """Runs evaluator for final model."""
        # Only runs eval at the end if highest accuracy so far
        # is less than self._stop_threshold.
        if not self._run_success:
            step = np.asscalar(session.run(self._global_step_tensor))
            logging.info('Starting eval.')
            eval_results = self._evaluate(session, step)
            mlp_log.mlperf_print('eval_accuracy',
                                 float(eval_results[_EVAL_METRIC]),
                                 metadata={
                                     'epoch_num':
                                     max(step // self._steps_per_epoch - 1, 0)
                                 })
            if eval_results[_EVAL_METRIC] >= self._stop_threshold:
                mlp_log.mlperf_print('run_stop',
                                     None,
                                     metadata={'status': 'success'})
            else:
                mlp_log.mlperf_print('run_stop',
                                     None,
                                     metadata={'status': 'abort'})

            mlp_log.mlperf_print('run_final', None)
Code example #20
    def evaluate(self, ckpt):
        """Performs evaluation against `ckpt` and writes a summary to directory."""
        current_step, num_epochs = self.get_step_and_epoch_number(ckpt)
        mlp_log.mlperf_print('eval_start',
                             None,
                             metadata={'epoch_num': num_epochs})
        eval_begin = time.time()
        if self.use_tpu_estimator:
            input_fn = functools.partial(self.input_fn,
                                         num_examples=self.eval_steps *
                                         self.params['eval_batch_size'])
            predictor = self.runner.predict(input_fn=input_fn,
                                            checkpoint_path=ckpt,
                                            yield_single_examples=False)
        else:
            predictor = self.runner.predict(checkpoint_path=ckpt,
                                            eval_steps=self.eval_steps)

        # Enables multi-processing to accelerate post-processing.
        eval_multiprocess.eval_multiprocessing(
            self.eval_steps, predictor, self.eval_metric,
            self.params['eval_worker_count'])

        pred_end = time.time()
        tf.logging.info('prediction takes %d seconds.', pred_end - eval_begin)
        num_eval_samples, eval_results = self.eval_metric.evaluate()

        eval_end = time.time()
        tf.logging.info('COCO evaluates %d samples', num_eval_samples)
        assert num_eval_samples == self.params['eval_samples']
        tf.logging.info('one evaluation takes %d seconds',
                        eval_end - eval_begin)
        self.write_summary(eval_results, current_step)
        tf.logging.info('AP: %s' % eval_results['AP'])
        tf.logging.info('mask_AP: %s' % eval_results['mask_AP'])
        mlp_log.mlperf_print('eval_stop',
                             None,
                             metadata={'epoch_num': num_epochs})
        # TODO(b/127959551): use both metrics once the bug is resolved.
        mlp_log.mlperf_print(
            'eval_accuracy',
            (float(eval_results['AP']), float(eval_results['mask_AP'])),
            metadata={'epoch_num': num_epochs})

        return eval_results
Code example #21
def compute_bleu_summaries(hook_args):
    """Compute BLEU core summaries using the decoder output.

  Args:
    hook_args: DecodeHookArgs namedtuple
  Returns:
    A list of tf.Summary values if hook_args.hparams contains the
    reference file and the translated file.
  """
    outputs, references = [], []
    for output, reference in hook_args.predictions:
        outputs.append(output)
        references.append(reference)

    decode_hparams = hook_args.decode_hparams

    values = []
    bleu = 100 * bleu_hook.bleu_wrapper(references, outputs)
    values.append(tf.Summary.Value(tag="BLEU", simple_value=bleu))
    tf.logging.info("BLEU = %6.2f" % (bleu))
    if hook_args.hparams.mlperf_mode:
        current_step = decode_hparams.mlperf_decode_step
        mlp_log.mlperf_print(
            "eval_stop",
            None,
            metadata={
                "epoch_num":
                max(current_step // decode_hparams.iterations_per_loop, 1)
            })
        mlp_log.mlperf_print(
            "eval_accuracy",
            bleu,
            metadata={
                "epoch_num":
                max(current_step // decode_hparams.iterations_per_loop, 1)
            })

    if bleu >= decode_hparams.mlperf_threshold:
        mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
        decode_hparams.set_hparam("mlperf_success", True)

    return values
Code example #22
    def train(self, max_steps=None):
        """Train for max_steps."""
        mlp_log.mlperf_print(key="init_stop", value=None)
        mlp_log.mlperf_print(key="run_start", value=None)
        mlp_log.mlperf_print("block_start",
                             None,
                             metadata={
                                 "first_epoch_num": 1,
                                 "epoch_count": 1
                             })

        if self._hparams.train_with_low_level_api:
            self._trunner.train(self._hparams.train_steps,
                                self._hparams.batch_size)
            self._trunner.shutdown()
        else:
            self._estimator.train(self._train_spec.input_fn,
                                  hooks=self._train_spec.hooks,
                                  max_steps=max_steps
                                  or self._train_spec.max_steps)
Code example #23
 def infeed_thread_fn():
     """Build and infeed session.run calls in a background thread."""
     # Starts the clock.
     time.sleep(60)
     mlp_log.mlperf_print(key="init_stop", value=None)
     mlp_log.mlperf_print(key="run_start", value=None)
     mlp_log.mlperf_print("block_start",
                          None,
                          metadata={
                              "first_epoch_num": 0,
                              "epoch_count": 1
                          })
     for cur_epoch in range(self.total_epoch):
         tf.logging.info("Start to infeed train batches for epoch %d",
                         cur_epoch)
         self.input_sess.run([self.enqueue_ops])
         tf.logging.info("Start to infeed eval batches for epoch %d",
                         cur_epoch)
         self.input_sess.run([self.eval_enqueue_ops])
     tf.logging.info("infeed thread exited.")
Code example #24
    def continuous_decode_on_eval_data(self):
        """Decode from dataset on new checkpoint."""
        if self._hparams.mlperf_mode:
            ckpt_generator = next_undecoded_checkpoint(self._hparams.model_dir)
        else:
            ckpt_generator = next_checkpoint(self._hparams.model_dir)

        for ckpt in ckpt_generator:
            current_step = int(os.path.basename(ckpt).split("-")[1])
            tf.logging.info("Decoding step %d" % current_step)
            # Skip checkpoint 0.
            if current_step == 0:
                continue
            # Decode the latest checkpoint by default.
            checkpoint_path = None
            if self._hparams.mlperf_mode:
                self._decode_hparams.mlperf_decode_step = current_step
                checkpoint_path = ckpt

            mlp_log.mlperf_print(
                "eval_start",
                None,
                metadata={
                    "epoch_num":
                    max(
                        current_step //
                        self._decode_hparams.iterations_per_loop, 1)
                })
            self.decode(dataset_split=tf.estimator.ModeKeys.EVAL,
                        checkpoint_path=checkpoint_path)
            if self._hparams.mlperf_mode and self._decode_hparams.mlperf_success:
                mlp_log.mlperf_print("run_stop",
                                     None,
                                     metadata={"status": "success"})
                break

        if self._hparams.mlperf_mode and not self._decode_hparams.mlperf_success:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "abort"})
Code example #25
    def train_and_eval(self):
        """Performs distributed model eval and writes a summary to directory."""
        self.run_success = False
        self.continue_train = True

        # queues for predictions post-processing.
        def post_processing_thread_fn():
            """Run post-processing on CPU for predictions."""
            for cur_epoch in range(self.total_epoch):

                eval_begin = time.time()
                # Enables multi-processing to accelerate post-processing.
                eval_multiprocess.eval_multiprocessing(
                    self.eval_steps, self.get_predict_results(cur_epoch),
                    self.eval_metric, self.eval_params["eval_worker_count"])

                pred_end = time.time()
                tf.logging.info("prediction takes %d seconds.",
                                pred_end - eval_begin)

                num_eval_samples, eval_results = self.eval_metric.evaluate()
                eval_end = time.time()
                tf.logging.info("COCO evaluates %d samples", num_eval_samples)
                if num_eval_samples != self.eval_params["eval_samples"]:
                    tf.logging.info(
                        "COCO fails to evaluate all %d samples, exit!" %
                        self.eval_params["eval_samples"])
                    self.run_success = False
                    self.continue_train = False
                    return
                tf.logging.info("one evaluation takes %d seconds",
                                eval_end - eval_begin)
                self.write_eval_summary(self.eval_summary_writer, eval_results,
                                        cur_epoch * self.iterations_per_loop)
                tf.logging.info("AP: %s" % eval_results["AP"])
                tf.logging.info("mask_AP: %s" % eval_results["mask_AP"])
                # Eval epoch is 0-indexed (for MLPerf log parsing).
                mlp_log.mlperf_print("eval_stop",
                                     None,
                                     metadata={"epoch_num": cur_epoch})
                # TODO(b/127959551): use both metrics once the bug is resolved.
                mlp_log.mlperf_print("eval_accuracy", (float(
                    eval_results["AP"]), float(eval_results["mask_AP"])),
                                     metadata={"epoch_num": cur_epoch})

                if (eval_results["AP"] >= mask_rcnn_params.BOX_EVAL_TARGET
                        and eval_results["mask_AP"] >=
                        mask_rcnn_params.MASK_EVAL_TARGET):
                    mlp_log.mlperf_print("run_stop",
                                         None,
                                         metadata={"status": "success"})
                    self.run_success = True
                    self.continue_train = False
                    return

        # Run predict post processing thread on the background.
        post_processing_thread = threading.Thread(
            target=post_processing_thread_fn)
        post_processing_thread.start()
        if self.train_params["all_in_one_session"]:
            tf.logging.info(
                "TrainAndEvalLowLevelRunner: start train_eval sessions")
            self.train_eval_sess.run(self.train_eval_op)
        else:
            if self.train_params["train_and_eval_save_checkpoint"]:
                ckpt_saver = runner_utils.AsyncCheckpointSaver(
                    _MAX_NUM_CHECKPOINT_THREADS, self.saver, self.model_dir,
                    self.train_eval_sess)
            cur_epoch = 0
            while cur_epoch < self.total_epoch and self.continue_train:
                tf.logging.info(
                    "TrainAndEvalLowLevelRunner: start train epoch: %d",
                    cur_epoch)
                start = time.time()
                self.train_eval_sess.run(self.train_eval_op)
                end = time.time()
                self.write_summary(summary_writer=self.summary_writer,
                                   graph=self.train_eval_graph,
                                   global_step=cur_epoch *
                                   self.iterations_per_loop,
                                   elapsed_time=end - start,
                                   elapsed_steps=self.iterations_per_loop,
                                   trained_examples=self.
                                   train_params["num_examples_per_epoch"])
                if self.train_params["train_and_eval_save_checkpoint"]:
                    ckpt_saver.checkpoint(cur_epoch * self.iterations_per_loop)
                if self.run_success or not self.continue_train:
                    break
                cur_epoch += 1

        post_processing_thread.join()
        if not self.run_success:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "abort"})
Code example #26
    def train_and_eval(self, output_summaries=False, enable_tracing=True):
        """Run the Train steps on the TPU device."""
        if output_summaries:
            output_dir = os.path.join(FLAGS.model_dir, "eval", self.tpu_name)
            tf.gfile.MakeDirs(output_dir)
            # Summary writer writes out eval metrics.
            summary_writer = tf.summary.FileWriter(output_dir)
            if FLAGS.save_graphs:
                summary_writer.add_graph(self.graph)
                summary_writer.add_graph(self.input_graph)
                summary_writer.add_graph(self.eval_input_graph)
                summary_writer.add_graph(self.eval_output_graph)

        def infeed_thread_fn():
            """Build and infeed session.run calls in a background thread."""
            # Build infeed session
            # Run infeed session.run calls
            tf.logging.info("Start infeed thread")
            for _ in range(self.train_steps // self.iterations):
                self.input_sess.run([self.enqueue_ops])
                self.eval_input_sess.run([self.eval_enqueue_ops])

        if False:
            self.infeed_thread = threading.Thread(target=infeed_thread_fn)
            self.infeed_thread.start()

        # Gather trace for the first few steps.
        if enable_tracing:
            self.launch_profiler()

        self.cur_step = 0
        success = False

        def enq(self, run=True):
            if self.infeed_thread is None:
                tf.logging.info("TrainAndEvalRunner: input_sess enqueue")
                self.input_sess.run([self.enqueue_ops])
                self.eval_input_sess.run([self.eval_enqueue_ops])
                tf.logging.info("TrainAndEvalRunner: enqueue (done)")
                if run:
                    tf.logging.info("TrainAndEvalRunner: train_eval_op...")
                    result = self.sess.run([self.train_eval_op])
                    tf.logging.info(
                        "TrainAndEvalRunner: train_eval_op... (done)")
                    return result

        def checkpoint_thread_fn(tpu_name, saver, sess, step):
            name = ''.join(['_' if not c.isalnum() else c for c in tpu_name])
            if FLAGS.export_dir is None:
                tf.logging.info("Not model %d: %s (FLAGS.export_dir is unset)",
                                step, name)
            else:
                name = FLAGS.export_dir + "/model-%s.ckpt-%d" % (name, step)
                tf.logging.info("Saving model %d: %s", step, name)
                saver.save(sess, name)

        @tflex.register_command
        def save():
            checkpoint_thread_fn(self.tpu_name, self.saver, self.sess,
                                 self.cur_step)

        # take care of the first JIT
        enq(self, run=False)
        while self.cur_step < self.train_steps or True:
            tflex.check_commands()
            if tflex.should_quit():
                import pdb
                pdb.set_trace()
                break
            self.start = time.time()
            tf.logging.info("TrainAndEvalRunner: start next %d steps",
                            self.iterations)
            self.cur_step = self.coordinator.claim(self.iterations)
            self.sess.run(self.global_step_init,
                          {self.global_step_in: self.cur_step})
            epoch = self.cur_step // self.steps_per_epoch - 1
            mlp_log.mlperf_print("block_start",
                                 None,
                                 metadata={
                                     "first_epoch_num": epoch + 1,
                                     "epoch_count": 4
                                 })
            self.step_loss = enq(self)
            self.eval_results = self.eval(self.eval_steps)
            self.end = time.time()
            self.step_time = self.end - self.start
            self.examples_sec = self.iterations * self.cfg[
                'train_batch_size'] / self.step_time
            self.eval_results['examples_sec'] = self.examples_sec
            self.eval_results['step_time'] = self.step_time
            if self.step_loss is not None:
                self.eval_results['loss'] = self.step_loss[0]
            if 'global_step' in self.eval_results:
                self.eval_results[
                    'global_step_sec'] = self.iterations / self.step_time
            tf.logging.info(
                "TrainAndEvalRunner ({}): step {} step time {} sec {} examples/sec"
                .format(self.tpu_name, self.cur_step, self.step_time,
                        self.examples_sec))
            # Run eval.
            # Write out summary to tensorboard.
            if output_summaries:
                with tf.Graph().as_default():
                    summaries = []
                    for metric in self.eval_results:
                        summaries.append(
                            tf.Summary.Value(
                                tag=metric,
                                simple_value=self.eval_results[metric]))
                        tf_summary = tf.Summary(value=list(summaries))
                        summary_writer.add_summary(tf_summary, self.cur_step)

                def flush(i):
                    tf.logging.info("Flushing summaries...")
                    start = time.time()
                    summary_writer.flush()
                    end = time.time()
                    tf.logging.info("Flushing summaries (done in %.2fs)",
                                    (end - start))

                if self.flush_summaries_thread is not None and self.flush_summaries_thread.is_alive(
                ):
                    start = time.time()
                    self.flush_summaries_thread.join()
                    end = time.time()
                    tf.logging.info(
                        "Flushing summaries [BLOCKED] (done in %.2fs)",
                        (end - start))
                self.flush_summaries_thread = dispatch([0], flush)[0]
            # MLPerf logging for eval results.
            mlp_log.mlperf_print("eval_accuracy",
                                 float(self.eval_results["top_1_accuracy"]),
                                 metadata={"epoch_num": epoch + 1})

            mlp_log.mlperf_print("block_stop",
                                 None,
                                 metadata={"first_epoch_num": epoch + 1})
            tf.logging.info("Eval results at step %d: %s", self.cur_step,
                            self.eval_results)
            if self.eval_results["top_1_accuracy"] >= FLAGS.stop_threshold:
                success = True
                if FLAGS.export_dir is not None:
                    self.checkpoint_thread = threading.Thread(
                        target=checkpoint_thread_fn,
                        args=(self.tpu_name, self.saver, self.sess,
                              self.cur_step))
                    self.checkpoint_thread.start()
                mlp_log.mlperf_print("run_stop",
                                     None,
                                     metadata={"status": "success"})
                import pdb
                pdb.set_trace()
                break

            if enable_tracing and self.cur_step > self.train_steps // 4:
                self.launch_profiler()
                enable_tracing = False

        if not success:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "abort"})

        mlp_log.mlperf_print("run_final", None)

        if output_summaries:
            summary_writer.close()
Code example #27
def main(argv):
    del argv  # Unused.

    # TODO(b/132208296): remove this workaround that uses control flow v2.
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

    tpu = FLAGS.tpu or FLAGS.master
    tpu_cluster_resolver = runner_utils.create_tpu_cluster_resolver(
        FLAGS.use_tpu, tpu, FLAGS.tpu_zone, FLAGS.gcp_project)
    if tpu_cluster_resolver:
        tpu_grpc_url = tpu_cluster_resolver.get_master()
        tf.Session.reset(tpu_grpc_url)

    # Check data path
    run_train = FLAGS.mode in ('train', 'train_and_eval')
    if run_train and FLAGS.training_file_pattern is None:
        raise RuntimeError(
            'You must specify --training_file_pattern for training.')
    run_eval = FLAGS.mode in ('eval', 'train_and_eval') or (
        FLAGS.mode == 'train' and FLAGS.eval_after_training)
    if run_eval:
        if FLAGS.validation_file_pattern is None:
            raise RuntimeError('You must specify --validation_file_pattern '
                               'for evaluation.')
        if FLAGS.val_json_file is None:
            raise RuntimeError(
                'You must specify --val_json_file for evaluation.')

    # Parse hparams
    hparams = mask_rcnn_params.default_hparams()
    hparams.parse(FLAGS.hparams)

    # The following is for spatial partitioning. `features` has one tensor while
    # `labels` has 4 + (`max_level` - `min_level` + 1) * 2 tensors. The input
    # partition is performed on `features` and all partitionable tensors of
    # `labels`, see the partition logic below.
    # Note: In the below code, TPUEstimator uses both `shard` and `replica` (with
    # the same meaning).
    # Note that spatial partition is part of the model-parallelism optimization.
    # See core_assignment_utils.py for more details about model parallelism.
    if FLAGS.input_partition_dims:
        labels_partition_dims = {
            'gt_boxes': None,
            'gt_classes': None,
            'cropped_gt_masks': None,
        }
        for level in range(hparams.get('min_level'),
                           hparams.get('max_level') + 1):
            labels_partition_dims['box_targets_%d' % level] = None
            labels_partition_dims['score_targets_%d' % level] = None
        num_cores_per_replica = int(np.prod(FLAGS.input_partition_dims))
        image_partition_dims = [
            FLAGS.input_partition_dims[i] for i in [1, 0, 2]
        ] if hparams.get('transpose_input') else FLAGS.input_partition_dims
        features_partition_dims = {
            'images': image_partition_dims,
            'source_ids': None,
            'image_info': None,
        }
        input_partition_dims = [features_partition_dims, labels_partition_dims]
        num_shards = FLAGS.num_cores // num_cores_per_replica
    else:
        num_cores_per_replica = None
        input_partition_dims = None
        num_shards = FLAGS.num_cores

    params = dict(hparams.values(),
                  num_shards=num_shards,
                  num_cores_per_replica=num_cores_per_replica,
                  use_tpu=FLAGS.use_tpu,
                  resnet_checkpoint=FLAGS.resnet_checkpoint,
                  val_json_file=FLAGS.val_json_file,
                  model_dir=FLAGS.model_dir)

    tpu_config = tf.contrib.tpu.TPUConfig(
        params['iterations_per_loop'],
        num_shards=num_shards,
        num_cores_per_replica=params['num_cores_per_replica'],
        input_partition_dims=input_partition_dims,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.
        PER_HOST_V2,
        tpu_job_name=FLAGS.tpu_job_name,
    )

    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        log_step_count_steps=params['iterations_per_loop'],
        tpu_config=tpu_config,
        save_checkpoints_steps=params['iterations_per_loop'],
    )

    train_replicas_per_worker = (
        params['cores_per_worker'] // params['num_cores_per_replica']
    ) if params['num_cores_per_replica'] else params['cores_per_worker']
    train_params = dict(
        params,
        replicas_per_worker=train_replicas_per_worker,
    )
    eval_params = dict(
        params,
        input_rand_hflip=False,
        resnet_checkpoint=None,
        is_training_bn=False,
        transpose_input=False,
    )

    # MLPerf logging.
    mlp_log.mlperf_print(key='init_start', value=None)
    mlp_log.mlperf_print(key='global_batch_size',
                         value=params['train_batch_size'])
    runner = None
    if run_train and run_eval:
        if params['train_use_tpu_estimator'] or params[
                'eval_use_tpu_estimator']:
            raise RuntimeError(
                'train_and_eval runner does not support TPUEstimator.')
        dist_eval_params = dict(
            eval_params,
            replicas_per_worker=train_replicas_per_worker,
        )
        runner = mask_rcnn_runner.TrainEvalRunner(
            model_fn=mask_rcnn_model.MaskRcnnModelFn(),
            input_fn=dataloader.InputReader(FLAGS.training_file_pattern,
                                            mode=tf.estimator.ModeKeys.TRAIN,
                                            use_fake_data=FLAGS.use_fake_data),
            eval_input_fn=dataloader.InputReader(
                FLAGS.validation_file_pattern,
                mode=tf.estimator.ModeKeys.PREDICT,
                distributed_eval=True),
            eval_metric=coco_metric.EvaluationMetric(FLAGS.val_json_file,
                                                     use_cpp_extension=True),
            train_params=train_params,
            eval_params=dist_eval_params,
            run_config=run_config)
    elif run_train:
        # Check low-level train runner compatibility.
        if not params['train_use_tpu_estimator']:
            if FLAGS.mode == 'train_and_eval':
                raise RuntimeError(
                    'Low level train runner does not support mode '
                    'train_and_eval yet.')
        train_params = dict(
            params,
            replicas_per_worker=train_replicas_per_worker,
        )
        runner = mask_rcnn_runner.TrainRunner(
            model_fn=mask_rcnn_model.MaskRcnnModelFn(),
            input_fn=dataloader.InputReader(FLAGS.training_file_pattern,
                                            mode=tf.estimator.ModeKeys.TRAIN,
                                            use_fake_data=FLAGS.use_fake_data),
            params=train_params,
            run_config=run_config,
            use_tpu_estimator=train_params['train_use_tpu_estimator'])
    else:
        sidecar_eval_params = dict(
            eval_params,
            # sidecar eval only uses one worker and does not use spatial partition.
            replicas_per_worker=FLAGS.num_cores,
        )
        runner = mask_rcnn_runner.EvalRunner(
            mask_rcnn_model.MaskRcnnModelFn(),
            dataloader.InputReader(FLAGS.validation_file_pattern,
                                   mode=tf.estimator.ModeKeys.PREDICT),
            coco_metric.EvaluationMetric(FLAGS.val_json_file,
                                         use_cpp_extension=True),
            sidecar_eval_params,
            run_config,
            use_tpu_estimator=sidecar_eval_params['eval_use_tpu_estimator'])

    if FLAGS.mode == 'train':
        runner.train()
    elif FLAGS.mode == 'eval':

        def terminate_eval():
            tf.logging.info(
                'Terminating eval after %d seconds of no checkpoints' %
                FLAGS.eval_timeout)
            return True

        run_success = False
        # Run evaluation when there's a new checkpoint
        for ckpt in tf.contrib.training.checkpoints_iterator(
                params['model_dir'],
                min_interval_secs=FLAGS.min_eval_interval,
                timeout=FLAGS.eval_timeout,
                timeout_fn=terminate_eval):

            tf.logging.info('Starting to evaluate.')
            try:

                eval_results = runner.evaluate(ckpt)
                current_step, _ = runner.get_step_and_epoch_number(ckpt)

                if (eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET
                        and eval_results['mask_AP'] >=
                        mask_rcnn_params.MASK_EVAL_TARGET):
                    mlp_log.mlperf_print(key='run_stop',
                                         metadata={'status': 'success'})
                    run_success = True
                    break

                if int(current_step) >= params['total_steps']:
                    tf.logging.info(
                        'Evaluation finished after training step %d' %
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint' %
                    ckpt)
        if not run_success:
            mlp_log.mlperf_print(key='run_stop',
                                 metadata={'status': 'aborted'})

    elif FLAGS.mode == 'train_and_eval':
        runner.train_and_eval()
    else:
        tf.logging.info('Mode not found.')
Code example #28
                    tf.logging.info(
                        "Evaluation finished but failed to reach target score."
                    )
                    break

            except tf.errors.NotFoundError:
                tf.logging.info(
                    "Checkpoint %s no longer exists, skipping checkpoint" %
                    ckpt)


if __name__ == "__main__":
    tf.logging.set_verbosity(tf.logging.INFO)
    nmt_parser = argparse.ArgumentParser()
    add_arguments(nmt_parser)
    FLAGS, unparsed = nmt_parser.parse_known_args()
    mlp_log.mlperf_print("global_batch_size", FLAGS.batch_size)
    mlp_log.mlperf_print("opt_learning_rate_alt_decay_func", "True")
    mlp_log.mlperf_print("opt_base_learning_rate", FLAGS.learning_rate)
    mlp_log.mlperf_print("opt_learning_rate_decay_interval",
                         FLAGS.decay_interval)
    mlp_log.mlperf_print("opt_learning_rate_decay_factor", FLAGS.decay_factor)
    mlp_log.mlperf_print("opt_learning_rate_decay_steps", FLAGS.decay_steps)
    mlp_log.mlperf_print("opt_learning_rate_remain_steps", FLAGS.decay_start)
    mlp_log.mlperf_print("opt_learning_rate_alt_warmup_func",
                         FLAGS.warmup_scheme)
    mlp_log.mlperf_print("opt_learning_rate_warmup_steps", FLAGS.warmup_steps)
    mlp_log.mlperf_print("max_sequence_length", FLAGS.src_max_len)

    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Code example #29
    def train_and_eval(self, train_steps):
        """Run the Train and Eval loop on the TPU device."""
        output_dir = os.path.join(FLAGS.model_dir, "eval")
        tf.gfile.MakeDirs(output_dir)
        # Summary writer writes out eval metrics.
        summary_writer = tf.summary.FileWriter(output_dir)
        self.run_success = False

        def log_eval_result_fn(results):
            """Log eval results."""
            cur_step, eval_results = results
            if cur_step == _STOP:
                return
            epoch = cur_step // self.params["steps_per_epoch"]
            with tf.Graph().as_default():
                summaries = []
                for metric in eval_results:
                    summaries.append(
                        tf.Summary.Value(tag=metric,
                                         simple_value=eval_results[metric]))
                    tf_summary = tf.Summary(value=list(summaries))
                    summary_writer.add_summary(tf_summary, cur_step)
                mlp_log.mlperf_print("eval_accuracy",
                                     eval_results["COCO/AP"],
                                     metadata={"epoch_num": epoch + 1})
                mlp_log.mlperf_print("eval_stop",
                                     None,
                                     metadata={"epoch_num": epoch + 1})

                if epoch in self.epoch_count:
                    epoch_count = self.epoch_count[epoch]
                else:
                    epoch_count = 1

                mlp_log.mlperf_print("block_stop",
                                     None,
                                     metadata={
                                         "first_epoch_num":
                                         epoch - epoch_count + 1,
                                         "epoch_count": epoch_count
                                     })

                self.log_epochs[epoch] = True
                if eval_results["COCO/AP"] >= ssd_constants.EVAL_TARGET:
                    self.run_success = True
                    if epoch < self.success_epoch:
                        self.success_epoch = epoch
                log_run_final = self.run_success
                for epoch in self.log_epochs:
                    if (epoch < self.success_epoch
                            and not self.log_epochs[epoch]):
                        log_run_final = False
                        break
                # Log run_final when all the previous eval results are logged.
                if log_run_final and not self.log_run_success:
                    mlp_log.mlperf_print("run_stop",
                                         None,
                                         metadata={"status": "success"})
                    self.log_run_success = True

        tf.logging.info(
            "TrainAndEvalLowLevelRunner: train for %d steps in total",
            train_steps)
        if train_steps % self.iterations != 0:
            tf.logging.warning(
                "train_steps %d is not divisible by iterations_per_loop %d",
                train_steps, self.iterations)
            train_steps = self.iterations * int(
                math.ceil(train_steps / self.iterations))

        # Start train and eval op on the background.
        def train_eval_thread_fn(sess, train_eval_op):
            sess.run([train_eval_op])

        train_eval_thread = threading.Thread(target=train_eval_thread_fn,
                                             args=(self.sess,
                                                   self.train_eval_op))
        train_eval_thread.start()

        # pylint: disable=line-too-long
        q_in = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
        q_out = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
        processes = [
            multiprocessing.Process(target=predict_post_processing,
                                    args=(q_in, q_out))
            for _ in range(self.num_multiprocessing_workers)
        ]
        # pylint: enable=line-too-long

        time.sleep(self.sleep_seconds)
        mlp_log.mlperf_print("init_stop", None)
        mlp_log.mlperf_print("run_start", None)

        for p in processes:
            p.start()
        self.infeed_thread.start()

        def log_eval_results_fn():
            result = q_out.get()
            cur_step, _ = result
            while cur_step != _STOP:
                log_eval_result_fn(result)
                result = q_out.get()
                cur_step, _ = result

        log_eval_result_thread = threading.Thread(target=log_eval_results_fn)
        log_eval_result_thread.start()

        cur_step = 0
        current_epoch = 0
        # Train and eval loop.
        while cur_step < train_steps:
            if self.run_success:
                break
            tf.logging.info("TrainAndEvalLowLevelRunner: start train step:%d",
                            cur_step)
            cur_step += self.iterations
            current_epoch = cur_step // self.params["steps_per_epoch"]
            if self.run_success:
                break
            if (self.params["eval_every_checkpoint"]
                    or current_epoch in self.eval_epochs):
                if current_epoch in self.epoch_count:
                    epoch_count = self.epoch_count[current_epoch]
                else:
                    epoch_count = 1
                mlp_log.mlperf_print("block_start",
                                     None,
                                     metadata={
                                         "first_epoch_num":
                                         current_epoch - epoch_count + 1,
                                         "epoch_count":
                                         epoch_count
                                     })
                mlp_log.mlperf_print("eval_start",
                                     None,
                                     metadata={"epoch_num": current_epoch + 1})
                # Run predict on device.
                start = time.time()
                predictions = list(self.predict())
                end = time.time()
                tf.logging.info(
                    "TrainAndEvalRunner: step {} step time {} sec".format(
                        cur_step, end - start))
                # Run predict post processing.
                q_in.put((cur_step, predictions))

        train_eval_thread.join()
        # Turn off predict thread.
        for _ in processes:
            q_in.put((_STOP, None))

        for p in processes:
            p.join(timeout=self.sleep_seconds)

        q_out.put((_STOP, None))
        log_eval_result_thread.join()

        # Clear out all the queues to avoid deadlock.
        while not q_out.empty():
            log_eval_result_fn(q_out.get())
        while not q_in.empty():
            q_in.get()

        summary_writer.close()
        if not self.run_success:
            mlp_log.mlperf_print("run_stop",
                                 None,
                                 metadata={"status": "abort"})

    def initialize(self, input_fn, eval_input_fn, model_fn, params):
        """Build graph and do initialization for training."""
        tf.logging.info("TrainAndEvalLowLevelRunner: initialize method")
        mlp_log.mlperf_print("init_start", None)

        self.params = params
        self.build_enqueue_ops(input_fn, params, host_id=0)

        def infeed_thread_fn():
            """Build and infeed session.run calls in a background thread."""
            # Enqueue one loop of training batches per iteration; on eval
            # iterations, also re-initialize and enqueue the eval dataset.
            for i in range(self.max_train_iterations):
                tf.logging.info(
                    "TrainAndEvalRunner: start infeed for %d steps",
                    self.iterations)
                self.input_sess.run([self.enqueue_ops])
                if (self.params["eval_every_checkpoint"]
                        or i in self.eval_iterations):
                    self.input_sess.run(self.eval_dataset_initializer)
                    self.input_sess.run([self.eval_enqueue_ops])

        def tpu_train_step(loss):
            """Generate the TPU graph."""
            del loss
            values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.feature_structure, values)
            features = unflattened_inputs["features"]
            labels = unflattened_inputs["labels"]
            estimator_spec = model_fn(features, labels,
                                      tf.estimator.ModeKeys.TRAIN, params)
            loss, train_op = estimator_spec.loss, estimator_spec.train_op
            self.scaffold_fn = estimator_spec.scaffold_fn
            with tf.control_dependencies([train_op]):
                return tf.identity(loss)

        def train_loop():
            return training_loop.repeat(self.iterations, tpu_train_step,
                                        [_INITIAL_LOSS])

        # Start the build of the train graph.
        self.train_loop = train_loop

        for i in range(1, self.num_hosts):
            self.build_enqueue_ops(input_fn, params, host_id=i)

        # Init for eval.
        self.initialize_eval(eval_input_fn, model_fn, params)

        with self.graph.as_default():
            if self.scaffold_fn:
                self.scaffold_fn()
            global_initializer = tf.global_variables_initializer()
            local_initializer = tf.local_variables_initializer()
            graph_io.write_graph(self.graph.as_graph_def(add_shapes=True),
                                 FLAGS.model_dir, "graph.pbtxt")

        # Build tpu train model session and initialize graph
        self.sess = tf.Session(self.master,
                               graph=self.graph,
                               config=self.session_config)
        self.input_sess = tf.Session(self.master,
                                     graph=self.input_graph,
                                     config=self.session_config)

        self.sess.run(global_initializer)
        self.sess.run(local_initializer)
        self.input_sess.run(self.dataset_initializer)
        self.input_sess.run(self.eval_dataset_initializer)

        # Build the infeed thread; it is started later in train_and_eval().
        self.infeed_thread = threading.Thread(target=infeed_thread_fn)
        # Compile.
        self.sess.run([self.train_eval_compile_op])