Example #1
def main(argv):
  del argv  # Unused.

  # TODO(b/132208296): remove this workaround that uses control flow v2.
  control_flow_util.ENABLE_CONTROL_FLOW_V2 = True

  # Parse hparams
  hparams = mask_rcnn_params.default_hparams()
  hparams.parse(FLAGS.hparams)

  params = dict(
      hparams.values(),
      transpose_input=FLAGS.input_partition_dims is None,
      resnet_checkpoint=FLAGS.resnet_checkpoint,
      val_json_file=FLAGS.val_json_file,
      num_cores_per_replica=int(np.prod(FLAGS.input_partition_dims))
      if FLAGS.input_partition_dims else 1,
      replicas_per_host=FLAGS.replicas_per_host)

  # MLPerf logging.
  mlp_log.mlperf_print(key='cache_clear', value=True)
  mlp_log.mlperf_print(key='init_start', value=None)
  mlp_log.mlperf_print(key='global_batch_size', value=FLAGS.train_batch_size)
  mlp_log.mlperf_print(key='train_samples', value=FLAGS.num_examples_per_epoch)
  mlp_log.mlperf_print(key='eval_samples', value=FLAGS.eval_samples)
  mlp_log.mlperf_print(
      key='min_image_size', value=params['short_side_image_size'])
  mlp_log.mlperf_print(
      key='max_image_size', value=params['long_side_max_image_size'])
  mlp_log.mlperf_print(key='num_image_candidates',
                       value=params['rpn_post_nms_topn'])

  train_steps = (
      FLAGS.num_epochs * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size)
  eval_steps = int(math.ceil(float(FLAGS.eval_samples) / FLAGS.eval_batch_size))
  if eval_steps > 0:
    # The eval dataset is not evenly divided across hosts; one extra step
    # ensures all eval samples are covered.
    # TODO(b/151732586): regenerate the eval dataset to make all hosts get the
    #                    same amount of work.
    eval_steps += 1
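  # Worked example with illustrative numbers: 5,000 COCO eval samples at an
  # eval batch size of 256 gives ceil(5000 / 256) = 20 steps, padded to 21
  # here, so eval_input_fn below requests 21 * 256 = 5376 examples and every
  # host runs the same number of batches.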
  runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.num_examples_per_epoch // FLAGS.train_batch_size, train_steps,
      eval_steps, FLAGS.num_shards)
  train_input_fn = dataloader.InputReader(
      FLAGS.training_file_pattern,
      mode=tf.estimator.ModeKeys.TRAIN,
      use_fake_data=FLAGS.use_fake_data)
  eval_input_fn = functools.partial(
      dataloader.InputReader(
          FLAGS.validation_file_pattern,
          mode=tf.estimator.ModeKeys.PREDICT,
          distributed_eval=True),
      num_examples=eval_steps * FLAGS.eval_batch_size)
  eval_metric = coco_metric.EvaluationMetric(
      FLAGS.val_json_file, use_cpp_extension=True)

  def init_fn():
    if FLAGS.resnet_checkpoint:
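      # Warm-start the backbone from the ResNet checkpoint, remapping its
      # 'resnet/' variable scope onto this model's 'resnet50/' scope.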
      tf.train.init_from_checkpoint(FLAGS.resnet_checkpoint,
                                    {'resnet/': 'resnet50/'})

  runner.initialize(train_input_fn, eval_input_fn,
                    mask_rcnn_model.MaskRcnnModelFn(params),
                    FLAGS.train_batch_size, FLAGS.eval_batch_size,
                    FLAGS.input_partition_dims, init_fn, params=params)
  mlp_log.mlperf_print('init_stop', None)
  mlp_log.mlperf_print('run_start', None)

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
    mlp_log.mlperf_print(
        'block_start',
        None,
        metadata={
            'first_epoch_num': cur_epoch,
            'epoch_count': 1
        })

  def eval_finish_fn(cur_step, eval_output, _):
    """Callback function that's executed after each eval."""
    if eval_steps == 0:
      return False
    # Concatenate eval_output: each entry is a list of per-host arrays.
    for key in eval_output:
      eval_output[key] = np.concatenate(eval_output[key], axis=0)
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    cur_epoch = 0 if steps_per_epoch == 0 else cur_step // steps_per_epoch
    mlp_log.mlperf_print(
        'block_stop',
        None,
        metadata={
            'first_epoch_num': cur_epoch,
            'epoch_count': 1
        })
    eval_multiprocess.eval_multiprocessing(eval_output, eval_metric,
                                           mask_rcnn_params.EVAL_WORKER_COUNT)

    mlp_log.mlperf_print(
        'eval_start', None, metadata={'epoch_num': cur_epoch + 1})
    _, eval_results = eval_metric.evaluate()
    mlp_log.mlperf_print(
        'eval_accuracy',
        {'BBOX': float(eval_results['AP']),
         'SEGM': float(eval_results['mask_AP'])},
        metadata={'epoch_num': cur_epoch + 1})
    mlp_log.mlperf_print(
        'eval_stop', None, metadata={'epoch_num': cur_epoch + 1})
    if (eval_results['AP'] >= mask_rcnn_params.BOX_EVAL_TARGET and
        eval_results['mask_AP'] >= mask_rcnn_params.MASK_EVAL_TARGET):
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
      return True
    return False

  def run_finish_fn(success):
    if not success:
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'abort'})

  runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
Example #2
def run_pretraining(hparams):
  """Run pretraining with given hyperparameters."""

  global masked_lm_accuracy
  global run_steps

  masked_lm_accuracy = 0
  run_steps = 0

  def eval_init_fn(cur_step):
    """Executed before every eval."""
    # BERT pretraining has no notion of epochs; to keep the logging consistent
    # with other MLPerf models, every mlp_log call reports steps as epochs and
    # sequences as examples.
    mlp_log.mlperf_print(
        "block_start",
        None,
        metadata={
            "first_epoch_num": cur_step + FLAGS.iterations_per_loop,
            "epoch_count": FLAGS.iterations_per_loop
        })

  def eval_finish_fn(cur_step, eval_output, summary_writer):
    """Executed after every eval."""
    global run_steps
    global masked_lm_accuracy
    cur_step_corrected = cur_step + FLAGS.iterations_per_loop
    run_steps = cur_step_corrected
    masked_lm_weighted_correct = eval_output["masked_lm_weighted_correct"]
    masked_lm_weighted_count = eval_output["masked_lm_weighted_count"]

    masked_lm_accuracy = np.sum(masked_lm_weighted_correct) / np.sum(
        masked_lm_weighted_count)
    # The eval_output may deliver the two arrays in swapped order; in that case
    # the ratio exceeds 1, so invert it to recover the true accuracy.
    if masked_lm_accuracy > 1:
      masked_lm_accuracy = 1 / masked_lm_accuracy

    if summary_writer:
      with tf.Graph().as_default():
        summary_writer.add_summary(
            tf.Summary(value=[
                tf.Summary.Value(tag="masked_lm_accuracy",
                                 simple_value=masked_lm_accuracy)
            ]), cur_step_corrected)

    mlp_log.mlperf_print(
        "block_stop",
        None,
        metadata={
            "first_epoch_num": cur_step_corrected,
        })
    # BERT pretraining has no notion of epochs; to keep the logging consistent
    # with other MLPerf models, every mlp_log call reports steps as epochs and
    # sequences as examples.
    mlp_log.mlperf_print(
        "eval_accuracy",
        float(masked_lm_accuracy),
        metadata={"epoch_num": cur_step_corrected})
    if (masked_lm_accuracy >= FLAGS.stop_threshold and
        cur_step_corrected >= FLAGS.iterations_per_loop * 6):
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "success"})
      return True
    else:
      return False

  def run_finish_fn(success):
    if not success:
      mlp_log.mlperf_print("run_stop", None, metadata={"status": "abort"})
    mlp_log.mlperf_print("run_final", None)

  def init_fn():
    if FLAGS.init_checkpoint:
      tf.train.init_from_checkpoint(FLAGS.init_checkpoint, {
          "bert/": "bert/",
          "cls/": "cls/",
      })

  # Apply the hyperparameter overrides to the corresponding flags.
  if "learning_rate" in hparams:
    FLAGS.learning_rate = hparams.learning_rate
  if "lamb_weight_decay_rate" in hparams:
    FLAGS.lamb_weight_decay_rate = hparams.lamb_weight_decay_rate
  if "lamb_beta_1" in hparams:
    FLAGS.lamb_beta_1 = hparams.lamb_beta_1
  if "lamb_beta_2" in hparams:
    FLAGS.lamb_beta_2 = hparams.lamb_beta_2
  if "epsilon" in hparams:
    FLAGS.epsilon = hparams.epsilon
  if "num_warmup_steps" in hparams:
    FLAGS.num_warmup_steps = hparams.num_warmup_steps
  if "num_train_steps" in hparams:
    FLAGS.num_train_steps = hparams.num_train_steps

  # Input handling
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.repeatable:
    tf.set_random_seed(123)

  if not FLAGS.do_train and not FLAGS.do_eval:
    raise ValueError("At least one of `do_train` or `do_eval` must be True.")

  train_input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    train_input_files.extend(tf.gfile.Glob(input_pattern))

  eval_input_file = "/REDACTED/je-d/home/staging-REDACTED-gpu-dedicated/bert/eval_original_dataset/part-*"
  eval_input_files = []
  for input_pattern in eval_input_file.split(","):
    eval_input_files.extend(tf.gfile.Glob(input_pattern))

  tf.logging.info("*** Input Files ***")
  tf.logging.info("%s Files." % len(train_input_files))

  dataset_train = dataset_input.input_fn_builder(
      input_files=train_input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=True,
      num_cpu_threads=8)

  dataset_eval = dataset_input.input_fn_builder(
      input_files=eval_input_files,
      max_seq_length=FLAGS.max_seq_length,
      max_predictions_per_seq=FLAGS.max_predictions_per_seq,
      is_training=False,
      num_cpu_threads=8,
      num_eval_samples=FLAGS.num_eval_samples)

  # Create the low-level runner. As in the other examples, the arguments are
  # the number of steps per training loop (between evals), the total number of
  # train steps, the number of eval steps, and the number of replicas.
  low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
      FLAGS.iterations_per_loop, FLAGS.stop_steps + 1, FLAGS.max_eval_steps,
      FLAGS.num_tpu_cores // FLAGS.num_partitions)

  mlp_log.mlperf_print("cache_clear", True)
  mlp_log.mlperf_print("init_start", None)
  mlp_log.mlperf_print("global_batch_size", FLAGS.train_batch_size)
  mlp_log.mlperf_print("opt_learning_rate_warmup_steps", FLAGS.num_warmup_steps)
  mlp_log.mlperf_print("num_warmup_steps", FLAGS.num_warmup_steps)
  mlp_log.mlperf_print("start_warmup_step", FLAGS.start_warmup_step)
  mlp_log.mlperf_print("max_sequence_length", FLAGS.max_seq_length)
  mlp_log.mlperf_print("opt_base_learning_rate", FLAGS.learning_rate)
  mlp_log.mlperf_print("opt_lamb_beta_1", FLAGS.lamb_beta_1)
  mlp_log.mlperf_print("opt_lamb_beta_2", FLAGS.lamb_beta_2)
  mlp_log.mlperf_print("opt_epsilon", 10 ** FLAGS.log_epsilon)
  mlp_log.mlperf_print("opt_learning_rate_training_steps",
                       FLAGS.num_train_steps)
  mlp_log.mlperf_print("opt_lamb_weight_decay_rate",
                       FLAGS.lamb_weight_decay_rate)
  mlp_log.mlperf_print("opt_lamb_learning_rate_decay_poly_power", 1)
  mlp_log.mlperf_print("opt_gradient_accumulation_steps", 0)
  mlp_log.mlperf_print("max_predictions_per_seq", FLAGS.max_predictions_per_seq)

  low_level_runner.initialize(
      dataset_train,
      dataset_eval,
      bert_model_fn,
      FLAGS.train_batch_size,
      FLAGS.eval_batch_size,
      input_partition_dims=None,
      init_fn=init_fn,
      train_has_labels=False,
      eval_has_labels=False,
      num_partitions=FLAGS.num_partitions)

  mlp_log.mlperf_print("init_stop", None)

  mlp_log.mlperf_print("run_start", None)

  # To keep the logging consistent with other MLPerf models, steps are
  # reported as epochs and sequences as examples in all mlp_log calls.
  mlp_log.mlperf_print("train_samples",
                       FLAGS.num_train_steps * FLAGS.train_batch_size)
  mlp_log.mlperf_print("eval_samples", FLAGS.num_eval_samples)
  low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn, run_finish_fn)
  return masked_lm_accuracy, run_steps
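
The cross-host reduction in eval_finish_fn above is just a weighted sum; the sketch below reproduces it with made-up per-host values (only NumPy is required, and the numbers are purely illustrative).

import numpy as np

# One entry per host: weighted counts of correct masked-LM predictions and of
# all masked-LM predictions, as delivered in eval_output.
masked_lm_weighted_correct = np.array([7600.0, 7550.0, 7625.0])
masked_lm_weighted_count = np.array([10000.0, 10000.0, 10000.0])

masked_lm_accuracy = (
    np.sum(masked_lm_weighted_correct) / np.sum(masked_lm_weighted_count))
# If the two arrays arrived in swapped order the ratio exceeds 1, so it is
# inverted, mirroring the workaround in eval_finish_fn.
if masked_lm_accuracy > 1:
  masked_lm_accuracy = 1 / masked_lm_accuracy
print(masked_lm_accuracy)  # ~0.7592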
Example #3
def main(argv):
    del argv  # Unused.

    params = construct_run_config(FLAGS.iterations_per_loop)
    mlp_log.mlperf_print(key='cache_clear', value=True)
    mlp_log.mlperf_print(key='init_start', value=None)
    mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         params['base_learning_rate'])
    mlp_log.mlperf_print(
        'opt_learning_rate_decay_boundary_epochs',
        [params['first_lr_drop_epoch'], params['second_lr_drop_epoch']])
    mlp_log.mlperf_print('opt_weight_decay', params['weight_decay'])
    mlp_log.mlperf_print(
        'model_bn_span', FLAGS.train_batch_size // FLAGS.num_shards *
        params['distributed_group_size'])
    mlp_log.mlperf_print('max_samples', ssd_constants.NUM_CROP_PASSES)
    mlp_log.mlperf_print('train_samples', FLAGS.num_examples_per_epoch)
    mlp_log.mlperf_print('eval_samples', FLAGS.eval_samples)

    params['batch_size'] = FLAGS.train_batch_size // FLAGS.num_shards
    input_partition_dims = FLAGS.input_partition_dims
    train_steps = FLAGS.num_epochs * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    eval_steps = int(math.ceil(FLAGS.eval_samples / FLAGS.eval_batch_size))
    runner = train_and_eval_runner.TrainAndEvalRunner(
        FLAGS.iterations_per_loop, train_steps, eval_steps, FLAGS.num_shards)

    train_input_fn = dataloader.SSDInputReader(
        FLAGS.training_file_pattern,
        params['transpose_input'],
        is_training=True,
        use_fake_data=FLAGS.use_fake_data,
        params=params)
    eval_input_fn = dataloader.SSDInputReader(
        FLAGS.validation_file_pattern,
        is_training=False,
        use_fake_data=FLAGS.use_fake_data,
        distributed_eval=True,
        count=eval_steps * FLAGS.eval_batch_size,
        params=params)

    def init_fn():
        tf.train.init_from_checkpoint(
            params['resnet_checkpoint'], {
                'resnet/': 'resnet%s/' % ssd_constants.RESNET_DEPTH,
            })

    runner.initialize(train_input_fn, eval_input_fn,
                      functools.partial(ssd_model.ssd_model_fn,
                                        params), FLAGS.train_batch_size,
                      FLAGS.eval_batch_size, input_partition_dims, init_fn)
    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)

    if FLAGS.run_cocoeval:
        # copybara:strip_begin
        q_in, q_out = REDACTEDprocess.get_user_data()
        processes = [
            REDACTEDprocess.Process(target=REDACTED_predict_post_processing)
            for _ in range(4)
        ]
        # copybara:strip_end_and_replace_begin
        # q_in = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
        # q_out = multiprocessing.Queue(maxsize=ssd_constants.QUEUE_SIZE)
        # processes = [
        #     multiprocessing.Process(
        #         target=predict_post_processing, args=(q_in, q_out))
        #     for _ in range(self.num_multiprocessing_workers)
        # ]
        # copybara:replace_end
        for p in processes:
            p.start()

        def log_eval_results_fn():
            """Print out MLPerf log."""
            result = q_out.get()
            success = False
            while result[0] != _STOP:
                if not success:
                    steps_per_epoch = (FLAGS.num_examples_per_epoch //
                                       FLAGS.train_batch_size)
                    epoch = (result[0] +
                             FLAGS.iterations_per_loop) // steps_per_epoch
                    mlp_log.mlperf_print('eval_accuracy',
                                         result[1]['COCO/AP'],
                                         metadata={'epoch_num': epoch})
                    mlp_log.mlperf_print('eval_stop',
                                         None,
                                         metadata={'epoch_num': epoch})
                    if result[1]['COCO/AP'] > ssd_constants.EVAL_TARGET:
                        success = True
                        mlp_log.mlperf_print('run_stop',
                                             None,
                                             metadata={'status': 'success'})
                result = q_out.get()
            if not success:
                mlp_log.mlperf_print('run_stop',
                                     None,
                                     metadata={'status': 'abort'})

        log_eval_result_thread = threading.Thread(target=log_eval_results_fn)
        log_eval_result_thread.start()

    def eval_init_fn(cur_step):
        """Executed before every eval."""
        steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
        epoch = cur_step // steps_per_epoch
        mlp_log.mlperf_print('block_start',
                             None,
                             metadata={
                                 'first_epoch_num':
                                 epoch,
                                 'epoch_count':
                                 FLAGS.iterations_per_loop // steps_per_epoch
                             })
        mlp_log.mlperf_print('eval_start',
                             None,
                             metadata={
                                 'epoch_num':
                                 epoch +
                                 FLAGS.iterations_per_loop // steps_per_epoch
                             })

    def eval_finish_fn(cur_step, eval_output, _):
        steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
        epoch = cur_step // steps_per_epoch
        mlp_log.mlperf_print('block_stop',
                             None,
                             metadata={
                                 'first_epoch_num':
                                 epoch,
                                 'epoch_count':
                                 FLAGS.iterations_per_loop // steps_per_epoch
                             })
        if FLAGS.run_cocoeval:
            q_in.put((cur_step, eval_output['detections']))

    runner.train_and_eval(eval_init_fn, eval_finish_fn)

    if FLAGS.run_cocoeval:
        for _ in processes:
            q_in.put((_STOP, None))

        for p in processes:
            try:
                p.join(timeout=10)
            except Exception:  #  pylint: disable=broad-except
                pass

        q_out.put((_STOP, None))
        log_eval_result_thread.join()

        # Clear out all the queues to avoid deadlock.
        while not q_out.empty():
            q_out.get()
        while not q_in.empty():
            q_in.get()
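
The copybara markers above strip an internal REDACTEDprocess pool, and the commented-out block sketches its open-source replacement built on multiprocessing. Below is a rough, self-contained sketch of that pattern under stated assumptions: the predict_post_processing body, the _STOP sentinel value, the queue sizes, and the worker count are illustrative stand-ins, not the real SSD post-processing.

import multiprocessing

_STOP = -1


def predict_post_processing(q_in, q_out):
  """Worker loop: consume (step, detections) tuples until the stop sentinel."""
  while True:
    step, detections = q_in.get()
    if step == _STOP:
      break
    del detections  # The real worker would run COCO post-processing here.
    q_out.put((step, {'COCO/AP': 0.0}))


if __name__ == '__main__':
  q_in = multiprocessing.Queue(maxsize=16)
  q_out = multiprocessing.Queue(maxsize=16)
  processes = [
      multiprocessing.Process(
          target=predict_post_processing, args=(q_in, q_out))
      for _ in range(4)
  ]
  for p in processes:
    p.start()

  q_in.put((0, [[0.0] * 7]))  # One fake detection batch.
  print(q_out.get())          # (0, {'COCO/AP': 0.0})

  for _ in processes:
    q_in.put((_STOP, None))   # One sentinel per worker.
  for p in processes:
    p.join()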
Example #4
def main(unused_argv):
    def eval_init_fn(cur_step):
        """Executed before every eval."""
        steps_per_epoch = FLAGS.num_train_images // FLAGS.train_batch_size
        epoch = cur_step // steps_per_epoch
        mlp_log.mlperf_print('block_start',
                             None,
                             metadata={
                                 'first_epoch_num': epoch,
                                 'epoch_count': 4
                             })

    def eval_finish_fn(cur_step, eval_output, summary_writer):
        """Executed after every eval."""
        steps_per_epoch = FLAGS.num_train_images // FLAGS.train_batch_size
        epoch = cur_step // steps_per_epoch
        eval_accuracy = float(np.sum(
            eval_output['total_correct'])) / FLAGS.num_eval_images

        if summary_writer:
            with tf.Graph().as_default():
                summary_writer.add_summary(
                    tf.Summary(value=[
                        tf.Summary.Value(tag='accuracy',
                                         simple_value=eval_accuracy)
                    ]), cur_step)
        mlp_log.mlperf_print('eval_accuracy',
                             eval_accuracy,
                             metadata={
                                 'epoch_num':
                                 epoch +
                                 FLAGS.iterations_per_loop // steps_per_epoch
                             })
        mlp_log.mlperf_print('block_stop',
                             None,
                             metadata={
                                 'first_epoch_num': epoch,
                                 'epoch_count': 4
                             })
        if eval_accuracy >= FLAGS.stop_threshold:
            mlp_log.mlperf_print('run_stop',
                                 None,
                                 metadata={'status': 'success'})
            return True
        else:
            return False

    def run_finish_fn(success):
        if not success:
            mlp_log.mlperf_print('run_stop',
                                 None,
                                 metadata={'status': 'abort'})
        mlp_log.mlperf_print('run_final', None)

    low_level_runner = train_and_eval_runner.TrainAndEvalRunner(
        FLAGS.iterations_per_loop, FLAGS.train_steps,
        int(math.ceil(FLAGS.num_eval_images / FLAGS.eval_batch_size)),
        FLAGS.num_replicas)

    mlp_log.mlperf_print('cache_clear', True)
    mlp_log.mlperf_print('init_start', None)
    mlp_log.mlperf_print('global_batch_size', FLAGS.train_batch_size)
    mlp_log.mlperf_print('lars_opt_weight_decay', FLAGS.weight_decay)
    mlp_log.mlperf_print('lars_opt_momentum', FLAGS.momentum)
    mlp_log.mlperf_print('submission_benchmark', 'resnet')
    mlp_log.mlperf_print('submission_division', 'closed')
    mlp_log.mlperf_print('submission_org', 'google')
    mlp_log.mlperf_print('submission_platform',
                         'tpu-v3-%d' % FLAGS.num_replicas)
    mlp_log.mlperf_print('submission_status', 'research')

    assert FLAGS.precision == 'bfloat16' or FLAGS.precision == 'float32', (
        'Invalid value for --precision flag; must be bfloat16 or float32.')
    input_dtype = tf.bfloat16 if FLAGS.precision == 'bfloat16' else tf.float32
    cache_decoded_image = FLAGS.num_replicas > 2048
    imagenet_train, imagenet_eval = [
        imagenet_input.get_input_fn(  # pylint: disable=g-complex-comprehension
            FLAGS.data_dir,
            is_training,
            input_dtype,
            FLAGS.image_size,
            FLAGS.input_partition_dims is None,
            cache_decoded_image=cache_decoded_image)
        for is_training in [True, False]
    ]

    low_level_runner.initialize(imagenet_train, imagenet_eval, resnet_model_fn,
                                FLAGS.train_batch_size, FLAGS.eval_batch_size,
                                FLAGS.input_partition_dims)

    mlp_log.mlperf_print('train_samples', FLAGS.num_train_images)
    mlp_log.mlperf_print('eval_samples', FLAGS.num_eval_images)
    mlp_log.mlperf_print('init_stop', None)
    mlp_log.mlperf_print('run_start', None)
    low_level_runner.train_and_eval(eval_init_fn, eval_finish_fn,
                                    run_finish_fn)
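
The top-1 accuracy in eval_finish_fn above is a plain reduction over per-host correct-prediction counts; here is a minimal sketch with made-up counts, assuming the 50,000-image ImageNet validation set and a stop threshold of 0.759 (the MLPerf ResNet-50 top-1 target) for FLAGS.stop_threshold.

import numpy as np

num_eval_images = 50000   # ImageNet validation set size.
stop_threshold = 0.759    # MLPerf ResNet-50 top-1 accuracy target.

# One correct-prediction count per host, as in eval_output['total_correct'].
total_correct = np.array([9500, 9450, 9520, 9480])
eval_accuracy = float(np.sum(total_correct)) / num_eval_images
print(eval_accuracy)                    # 0.759
print(eval_accuracy >= stop_threshold)  # True -> run_stop with status 'success'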