Example #1
    def __init__(self, iterations, hparams, per_host_v1=False):
        tf.logging.info("TrainLowLevelRunner: constructor")

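        # Bookkeeping containers for the input pipeline and TPU infeed; they are
        # populated later, once the input graph and infeed ops are built.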
        self.feature_structure = {}
        self.loss = None
        self.infeed_queue = []
        self.enqueue_ops = []
        self.dataset_initializer = []
        self.is_local = ((hparams.master == "") and (hparams.tpu_name is None))
        self.per_host_v1 = per_host_v1
        self.iterations = iterations
        self.sess = None
        self.graph = tf.Graph()
        self.hparams = hparams
        with self.graph.as_default():
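            # TPU system init/shutdown ops are built inside the runner's own graph.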
            self.tpu_init = [tpu.initialize_system()]
            self.tpu_shutdown = tpu.shutdown_system()

        self.resolver = get_resolver(hparams)
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        isolate_session_state=True)
        if self.hparams.tpu_name is None:
            master = self.hparams.master
        else:
            cluster_spec = self.resolver.cluster_spec()
            tf.logging.info(cluster_spec)
            if cluster_spec:
                session_config.cluster_def.CopyFrom(
                    cluster_spec.as_cluster_def())
            master = self.resolver.get_master()
        self.sess = tf.Session(master, graph=self.graph, config=session_config)
        self.sess.run(self.tpu_init)

        self.hooks = lottery.hooks_from_flags(hparams.values())
def train_and_eval_fn(hparams):
    """Train and evaluation function."""
    hooks = lottery.hooks_from_flags(hparams.values())

    mlperf_log.gnmt_print(key=mlperf_log.RUN_START)
    hparams.tgt_sos_id, hparams.tgt_eos_id = 1, 2
    model_fn = make_model_fn(hparams, hooks)
    input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)
    run_config = _get_tpu_run_config(hparams, False)
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        config=run_config,
        use_tpu=hparams.use_tpu,
        train_batch_size=hparams.batch_size,
        eval_batch_size=hparams.batch_size,
        predict_batch_size=hparams.infer_batch_size,
    )

    score = 0.0
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
    mlperf_log.gnmt_print(key=mlperf_log.EVAL_TARGET,
                          value=hparams.target_bleu)

    for i in range(hparams.max_train_epochs):
        mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=i)
        tf.logging.info("Start training epoch %d", i)
        mlperf_log.gnmt_print(key=mlperf_log.INPUT_SIZE,
                              value=hparams.num_examples_per_epoch)
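        # max_steps is cumulative: estimator.train() resumes from the last saved
        # global step, so training for epoch i stops at steps_per_epoch * (i + 1).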
        steps_per_epoch = int(hparams.num_examples_per_epoch /
                              hparams.batch_size)
        max_steps = steps_per_epoch * (i + 1)
        estimator.train(input_fn=input_fn, max_steps=max_steps)
        mlperf_log.gnmt_print(key=mlperf_log.TRAIN_CHECKPOINT,
                              value=("Under " + hparams.out_dir))
        tf.logging.info("End training epoch %d", i)

        mlperf_log.gnmt_print(key=mlperf_log.EVAL_START)
        score = get_metric_from_estimator(hparams, estimator)
        tf.logging.info("Score after epoch %d: %f", i, score)
        mlperf_log.gnmt_print(key=mlperf_log.EVAL_ACCURACY,
                              value={
                                  "value": score,
                                  "epoch": i
                              })
        mlperf_log.gnmt_print(key=mlperf_log.EVAL_STOP, value=i)

    mlperf_log.gnmt_print(key=mlperf_log.RUN_STOP, value={"success": False})
    return score
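A small worked example of the cumulative max_steps arithmetic in the loop above; the numbers are illustrative only and do not come from the source.

# Illustrative sketch: max_steps grows by one epoch of steps per call, since
# estimator.train() resumes from the last saved global step.
num_examples_per_epoch = 1000000  # hypothetical
batch_size = 512                  # hypothetical
steps_per_epoch = int(num_examples_per_epoch / batch_size)  # 1953

for i in range(3):
    max_steps = steps_per_epoch * (i + 1)
    print("epoch %d trains until global step %d" % (i, max_steps))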
Example #3
def vgg_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for VGG Models.

  Args:
    flags_obj: An object containing parsed flags. See define_vgg_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This is
      used for logging purposes.
    shape: list of ints representing the shape of the images used for training.
      This is only used if flags_obj.export_dir is passed.

  Returns:
    Dict of results of the run.
  """

  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Creates the session config. allow_soft_placement=True is required for
  # multi-GPU training and is not harmful in other modes.
  session_config = tf.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=True)

  distribution_strategy = distribution_utils.get_distribution_strategy(
      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

  # Creates a `RunConfig` that checkpoints every 24 hours, which effectively
  # means checkpoints are determined only by `epochs_between_evals`.
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      save_checkpoints_secs=60*60*24)

  # Initializes model with all but the dense layer from pretrained VGG.
  if flags_obj.pretrained_model_checkpoint_path is not None:
    warm_start_settings = tf.estimator.WarmStartSettings(
        flags_obj.pretrained_model_checkpoint_path,
        vars_to_warm_start='^(?!.*dense)')
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
      warm_start_from=warm_start_settings, params={
          'vgg_size': flags_obj.vgg_size,
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'loss_scale': flags_core.get_loss_scale(flags_obj),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'vgg_size': flags_obj.vgg_size,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('vgg', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  train_hooks = list(train_hooks) + lottery.hooks_from_flags(flags_obj.flag_values_dict())

  def input_fn_train(num_epochs):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_tf_dtype(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        num_parallel_batches=flags_obj.datasets_num_parallel_batches)

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_tf_dtype(flags_obj))

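  # Prediction-dump mode: restore the latest checkpoint, run predict() over the
  # eval input, attach ground-truth labels, and pickle the results (no training).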
  if flags_obj.lth_generate_predictions:
    ckpt = tf.train.latest_checkpoint(flags_obj.model_dir)

    if flags_obj.lth_no_pruning:
      m_hooks = []
    else:
      m_hooks = lottery.hooks_from_flags(flags_obj.flag_values_dict())

    eval_results = classifier.predict(
        input_fn=input_fn_eval,
        checkpoint_path=ckpt,
        hooks=m_hooks,
    )

    assert flags_obj.lth_prediction_result_dir
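    # CIFAR-10-style binary test batch: each record is 1 label byte followed by
    # 32*32*3 image bytes, so a stride of 32*32*3 + 1 picks out the label bytes.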
    with tf.gfile.Open(os.path.join(flags_obj.data_dir, 'test_batch.bin'), 'rb') as f:
      labels = list(f.read()[::32*32*3+1])

    eval_results = list(eval_results)
    if not tf.gfile.Exists(flags_obj.lth_prediction_result_dir):
      tf.gfile.MakeDirs(flags_obj.lth_prediction_result_dir)
    with tf.gfile.Open(os.path.join(flags_obj.lth_prediction_result_dir, 'predictions'), 'wb') as f:
      for label, res in zip(labels, eval_results):
        res['label'] = label
      pickle.dump(eval_results, f)
    return

  try:
    cpr = tf.train.NewCheckpointReader(tf.train.latest_checkpoint(flags_obj.model_dir))
    current_step = cpr.get_tensor('global_step')
  except:  # pylint: disable=bare-except
    # No readable checkpoint yet; start training from step 0.
    current_step = 0

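  # Train in chunks of at most 10,000 steps, evaluating after each chunk.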
  while current_step < flags_obj.max_train_steps:
    next_checkpoint = min(current_step + 10000, flags_obj.max_train_steps)
    classifier.train(input_fn=lambda: input_fn_train(1000), hooks=train_hooks, max_steps=next_checkpoint)
    current_step = next_checkpoint
    tf.logging.info('Starting to evaluate.')
    eval_results = classifier.evaluate(input_fn=input_fn_eval)
    benchmark_logger.log_evaluation_result(eval_results)

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)
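An aside on the warm-start settings in this example: vars_to_warm_start='^(?!.*dense)' is a negative-lookahead regex, so only variables whose names do not contain "dense" are restored from the pretrained checkpoint, and the final dense layer keeps its fresh initialization. A minimal, self-contained illustration of that pattern (the variable names below are made up):

import re

pattern = re.compile(r'^(?!.*dense)')
print(bool(pattern.match('vgg_model/conv1/kernel')))  # True: would be warm-started
print(bool(pattern.match('vgg_model/dense/kernel')))  # False: excluded from warm start
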
def main(unused_argv):
    params = hyperparameters.get_hyperparameters(FLAGS.default_hparams_file,
                                                 FLAGS.hparams_file, FLAGS,
                                                 FLAGS.hparams)
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu if (FLAGS.tpu or params['use_tpu']) else '',
        zone=FLAGS.tpu_zone,
        project=FLAGS.gcp_project)

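    # With async checkpointing, the RunConfig saver is disabled
    # (save_checkpoints_steps=None) so the AsyncCheckpointSaverHook added in
    # train mode can take over saving.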
    if params['use_async_checkpointing']:
        save_checkpoints_steps = None
    else:
        save_checkpoints_steps = max(5000, params['iterations_per_loop'])
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.model_dir,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=FLAGS.log_step_count_steps,
        session_config=tf.ConfigProto(
            graph_options=tf.GraphOptions(
                rewrite_options=rewriter_config_pb2.RewriterConfig(
                    disable_meta_optimizer=True))),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=params['iterations_per_loop'],
            num_shards=params['num_cores'],
            per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))  # pylint: disable=line-too-long

    resnet_classifier = tf.contrib.tpu.TPUEstimator(
        use_tpu=params['use_tpu'],
        model_fn=resnet_model_fn,
        config=config,
        params=params,
        train_batch_size=params['train_batch_size'],
        eval_batch_size=params['eval_batch_size'],
        predict_batch_size=params['eval_batch_size'],
        export_to_tpu=FLAGS.export_to_tpu)

    assert (params['precision'] == 'bfloat16' or params['precision']
            == 'float32'), ('Invalid value for precision parameter; '
                            'must be bfloat16 or float32.')
    tf.logging.info('Precision: %s', params['precision'])
    use_bfloat16 = params['precision'] == 'bfloat16'

    # Input pipelines are slightly different (with regards to shuffling and
    # preprocessing) between training and evaluation.
    if FLAGS.bigtable_instance:
        tf.logging.info('Using Bigtable dataset, table %s',
                        FLAGS.bigtable_table)
        select_train, select_eval = _select_tables_from_flags()
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetBigtableInput(
                is_training=is_training,
                use_bfloat16=use_bfloat16,
                transpose_input=params['transpose_input'],
                selection=selection)
            for (is_training,
                 selection) in [(True, select_train), (False, select_eval)]
        ]
    else:
        if FLAGS.data_dir == FAKE_DATA_DIR:
            tf.logging.info('Using fake dataset.')
        else:
            tf.logging.info('Using dataset: %s', FLAGS.data_dir)
        imagenet_train, imagenet_eval = [
            imagenet_input.ImageNetInput(
                is_training=is_training,
                data_dir=FLAGS.data_dir,
                transpose_input=params['transpose_input'],
                cache=params['use_cache'] and is_training,
                image_size=params['image_size'],
                num_parallel_calls=params['num_parallel_calls'],
                use_bfloat16=use_bfloat16) for is_training in [True, False]
        ]

    steps_per_epoch = params['num_train_images'] // params['train_batch_size']
    eval_steps = params['num_eval_images'] // params['eval_batch_size']
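    # Hooks built from the lottery-ticket (lth_*) flags; reused for both
    # training and prediction below.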
    lottery_hooks = lottery.hooks_from_flags(params)

    if FLAGS.mode == 'eval':

        # Run evaluation when there's a new checkpoint
        for ckpt in evaluation.checkpoints_iterator(
                FLAGS.model_dir, timeout=FLAGS.eval_timeout):
            tf.logging.info('Starting to evaluate.')
            try:
                start_timestamp = time.time()  # This time will include compilation time
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=eval_steps,
                    checkpoint_path=ckpt)
                elapsed_time = int(time.time() - start_timestamp)
                tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                                eval_results, elapsed_time)

                # Terminate eval job when final checkpoint is reached
                current_step = int(os.path.basename(ckpt).split('-')[1])
                if current_step >= params['train_steps']:
                    tf.logging.info(
                        'Evaluation finished after training step %d',
                        current_step)
                    break

            except tf.errors.NotFoundError:
                # Since the coordinator is on a different job than the TPU worker,
                # sometimes the TPU worker does not finish initializing until long after
                # the CPU job tells it to start evaluating. In this case, the checkpoint
                # file could have been deleted already.
                tf.logging.info(
                    'Checkpoint %s no longer exists, skipping checkpoint',
                    ckpt)
    elif FLAGS.mode == 'predict':
        tf.logging.info('Starting to predict.')
        ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)

        if FLAGS.lth_no_pruning:
            m_hooks = []
        else:
            m_hooks = lottery_hooks

        start_timestamp = time.time()  # This time will include compilation time
        eval_results = resnet_classifier.predict(
            input_fn=imagenet_eval.input_fn,
            checkpoint_path=ckpt,
            hooks=m_hooks,
        )

        assert FLAGS.lth_prediction_result_dir
        with tf.gfile.Open(
                os.path.join(FLAGS.lth_prediction_result_dir, 'predictions'),
                'wb') as f:
            pickle.dump(list(eval_results), f)

    else:  # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
        current_step = estimator._load_global_step_from_checkpoint_dir(
            FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
        steps_per_epoch = params['num_train_images'] // params[
            'train_batch_size']
        tf.logging.info(
            'Training for %d steps (%.2f epochs in total). Current'
            ' step %d.', params['train_steps'],
            params['train_steps'] / steps_per_epoch, current_step)

        start_timestamp = time.time()  # This time will include compilation time

        if FLAGS.mode == 'train':
            hooks = []
            if params['use_async_checkpointing']:
                hooks.append(
                    async_checkpoint.AsyncCheckpointSaverHook(
                        checkpoint_dir=FLAGS.model_dir,
                        save_steps=max(5000, params['iterations_per_loop'])))
            if FLAGS.profile_every_n_steps > 0:
                hooks.append(
                    tpu_profiler_hook.TPUProfilerHook(
                        save_steps=FLAGS.profile_every_n_steps,
                        output_dir=FLAGS.model_dir,
                        tpu=FLAGS.tpu))
            resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                    max_steps=int(params['train_steps']),
                                    hooks=hooks + lottery_hooks)

        else:
            assert FLAGS.mode == 'train_and_eval'
            while current_step < params['train_steps']:
                # Train for up to steps_per_eval number of steps.
                # At the end of training, a checkpoint will be written to --model_dir.
                next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                                      params['train_steps'])
                resnet_classifier.train(input_fn=imagenet_train.input_fn,
                                        max_steps=int(next_checkpoint),
                                        hooks=lottery_hooks)
                current_step = next_checkpoint

                tf.logging.info(
                    'Finished training up to step %d. Elapsed seconds %d.',
                    next_checkpoint, int(time.time() - start_timestamp))

                # Evaluate the model on the most recent model in --model_dir.
                # Since evaluation happens in batches of --eval_batch_size, some images
                # may be excluded modulo the batch size. As long as the batch size is
                # consistent, the evaluated images are also consistent.
                tf.logging.info('Starting to evaluate.')
                eval_results = resnet_classifier.evaluate(
                    input_fn=imagenet_eval.input_fn,
                    steps=params['num_eval_images'] //
                    params['eval_batch_size'])
                tf.logging.info('Eval results at step %d: %s', next_checkpoint,
                                eval_results)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info(
                'Finished training up to step %d. Elapsed seconds %d.',
                params['train_steps'], elapsed_time)

        if FLAGS.export_dir is not None:
            # The guide to serving an exported TensorFlow model is at:
            #    https://www.tensorflow.org/serving/serving_basic
            tf.logging.info('Starting to export model.')
            export_path = resnet_classifier.export_saved_model(
                export_dir_base=FLAGS.export_dir,
                serving_input_receiver_fn=imagenet_input.image_serving_input_fn
            )
            if FLAGS.add_warmup_requests:
                inference_warmup.write_warmup_requests(
                    export_path,
                    FLAGS.model_name,
                    params['image_size'],
                    batch_sizes=FLAGS.inference_batch_sizes,
                    image_format='JPEG')
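
A small sketch of reading back the predictions file written in the 'predict' branch above; the directory value is hypothetical and stands in for whatever --lth_prediction_result_dir pointed at.

import os
import pickle

import tensorflow as tf

result_dir = '/tmp/lth_predictions'  # hypothetical --lth_prediction_result_dir value
with tf.gfile.Open(os.path.join(result_dir, 'predictions'), 'rb') as f:
    predictions = pickle.load(f)
tf.logging.info('Loaded %d prediction records', len(predictions))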