def train_fn(hparams, num_workers):
    """Train the model with a DistributedPipeline as the input source.

    Copy of the train function from estimator.py.

    Two paths exist:

    * ``hparams.use_tpu_low_level_api`` is truthy: the runner returned by
      ``create_train_runner`` is initialized with the pipeline, builds the
      model, and ``runner.train(0, hparams.num_train_steps)`` is called.
    * otherwise: a ``TPUEstimator`` is trained with
      ``max_steps=hparams.num_train_steps``, using the pipeline both as
      ``input_fn`` and as a training hook.

    Args:
      hparams: Hyperparameter object. ``tgt_sos_id`` and ``tgt_eos_id`` are
        overwritten on it as a side effect.
      num_workers: Forwarded to ``create_train_runner`` and
        ``DistributedPipeline``.

    Returns:
      Always ``0.0``; the return value is not used.

    Raises:
      ValueError: If neither ``use_tpu_low_level_api`` nor ``use_tpu`` is set.
    """
    # TODO: Merge improvements into the original.
    # pylint: disable=protected-access
    sos_id, eos_id = nmt_estimator._get_tgt_sos_eos_id(hparams)
    hparams.tgt_sos_id = sos_id
    hparams.tgt_eos_id = eos_id
    model_fn = nmt_estimator.make_model_fn(hparams)

    def _log_train_start():
        """Emit the TRAIN_LOOP, TRAIN_EPOCH (0) and INPUT_SIZE log entries."""
        mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
        mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=0)
        mlperf_log.gnmt_print(
            key=mlperf_log.INPUT_SIZE, value=hparams.num_examples_per_epoch)

    if hparams.use_tpu_low_level_api:
        # Low-level path: no Estimator involved, the runner does everything.
        low_level_runner = create_train_runner(hparams, num_workers)
        mlperf_log.gnmt_print(key=mlperf_log.RUN_START)
        feed = DistributedPipeline(hparams, num_workers)
        low_level_runner.initialize(feed, {})
        low_level_runner.build_model(model_fn, {})
        _log_train_start()
        low_level_runner.train(0, hparams.num_train_steps)
        return 0.0

    # NOTE: num_workers comes from the caller. Disabled code that used to sit
    # here instead looked it up via
    # tf.contrib.cluster_resolver.TPUClusterResolver(hparams.tpu_name)
    #     .cluster_spec().num_tasks('tpu_worker').

    pipeline = DistributedPipeline(hparams, num_workers)
    _log_train_start()

    # The pipeline is created and the start-of-training entries are logged
    # before this check, so both happen even when the check fails.
    if not hparams.use_tpu:
        raise ValueError("Distributed input pipeline only supported on TPUs.")

    tpu_run_config = nmt_estimator._get_tpu_run_config(hparams, True)
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        config=tpu_run_config,
        use_tpu=hparams.use_tpu,
        train_batch_size=hparams.batch_size,
        eval_batch_size=hparams.batch_size,
        predict_batch_size=hparams.infer_batch_size)

    # The pipeline object doubles as a training hook.
    train_hooks = [pipeline]
    if hparams.use_async_checkpoint:
        # Checkpoint once per epoch's worth of steps.
        checkpoint_hook = async_checkpoint.AsyncCheckpointSaverHook(
            checkpoint_dir=hparams.out_dir,
            save_steps=int(
                hparams.num_examples_per_epoch / hparams.batch_size))
        train_hooks.append(checkpoint_hook)

    estimator.train(
        input_fn=pipeline,
        max_steps=hparams.num_train_steps,
        hooks=train_hooks)
    # Return value is not used
    return 0.0


# --- Example no. 2 ---
def train_and_eval_with_low_level_api(hparams, num_workers):
    """Alternate training and evaluation using the train/eval runners.

    For each of ``hparams.max_train_epochs`` epochs the train runner is
    invoked twice, each time for half of the epoch's steps, and then the eval
    runner's predictions are scored with ``nmt_estimator.get_metric``.
    Progress is reported through ``mlperf_log`` and ``tf.logging``.

    Args:
      hparams: Hyperparameter object. ``tgt_sos_id`` and ``tgt_eos_id`` are
        set to 1 and 2 on it as a side effect.
      num_workers: Forwarded to ``create_train_runner`` and
        ``DistributedPipeline``.

    Returns:
      The score computed after the last epoch, or ``0.0`` if no epoch ran.
    """
    # pylint: disable=protected-access
    hparams.tgt_sos_id = 1
    hparams.tgt_eos_id = 2
    model_fn = nmt_estimator.make_model_fn(hparams)

    # Both runners are created before RUN_START is logged.
    trainer = create_train_runner(hparams, num_workers)
    evaluator = nmt_estimator.create_eval_runner(hparams, model_fn)
    mlperf_log.gnmt_print(key=mlperf_log.RUN_START)

    # Training side: input comes from the distributed pipeline.
    train_feed = DistributedPipeline(hparams, num_workers)
    trainer.initialize(train_feed, {})
    trainer.build_model(model_fn, {})

    # Evaluation side: the inference batch size is divided across shards.
    eval_feed = nmt_estimator.make_input_fn(
        hparams, tf.contrib.learn.ModeKeys.INFER)
    eval_params = {
        "infer_batch_size": int(hparams.infer_batch_size / hparams.num_shards)
    }
    evaluator.initialize(eval_feed, eval_params)
    evaluator.build_model(model_fn, eval_params)

    score = 0.0
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
    mlperf_log.gnmt_print(
        key=mlperf_log.EVAL_TARGET, value=hparams.target_bleu)

    global_step = 0
    for epoch in range(hparams.max_train_epochs):
        mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=epoch)
        tf.logging.info("Start training epoch %d", epoch)
        mlperf_log.gnmt_print(
            key=mlperf_log.INPUT_SIZE, value=hparams.num_examples_per_epoch)

        steps_per_epoch = int(
            hparams.num_examples_per_epoch / hparams.batch_size)
        # Each epoch is run as two train calls of half an epoch each; with an
        # odd steps_per_epoch the floor division leaves one step untrained.
        half_epoch = steps_per_epoch // 2
        for _ in range(2):
            trainer.train(global_step, half_epoch)
            global_step += half_epoch

        mlperf_log.gnmt_print(
            key=mlperf_log.TRAIN_CHECKPOINT,
            value=("Under " + hparams.out_dir))
        tf.logging.info("End training epoch %d", epoch)

        mlperf_log.gnmt_print(key=mlperf_log.EVAL_START)
        predictions = list(evaluator.predict())
        score = nmt_estimator.get_metric(hparams, predictions, global_step)
        tf.logging.info("Score after epoch %d: %f", epoch, score)
        accuracy_entry = {"value": score, "epoch": epoch}
        mlperf_log.gnmt_print(
            key=mlperf_log.EVAL_ACCURACY, value=accuracy_entry)
        mlperf_log.gnmt_print(key=mlperf_log.EVAL_STOP, value=epoch)
        # NOTE: stopping early on reaching the target is disabled here, so
        # every epoch runs and RUN_STOP below always reports success=False.
        # if score >= hparams.target_bleu:
        #   mlperf_log.gnmt_print(mlperf_log.RUN_STOP, {"success": True})
        #   return score

    mlperf_log.gnmt_print(mlperf_log.RUN_STOP, {"success": False})
    return score
def train_and_eval_fn(hparams, num_workers):
    """Alternate TPUEstimator training and evaluation, one epoch at a time.

    After every epoch the score from
    ``nmt_estimator.get_metric_from_estimator`` is compared against
    ``hparams.target_bleu``; the function returns as soon as the target is
    met. Progress is reported through ``mlperf_log`` and ``tf.logging``.

    Args:
      hparams: Hyperparameter object. ``tgt_sos_id`` and ``tgt_eos_id`` are
        set to 1 and 2 on it as a side effect.
      num_workers: Forwarded to ``DistributedPipeline``.

    Returns:
      The most recently computed score: the first one that reaches
      ``hparams.target_bleu``, otherwise the score after the final epoch
      (``0.0`` if no epoch ran).
    """
    # pylint: disable=protected-access
    mlperf_log.gnmt_print(key=mlperf_log.RUN_START)
    hparams.tgt_sos_id = 1
    hparams.tgt_eos_id = 2
    model_fn = nmt_estimator.make_model_fn(hparams)
    pipeline = DistributedPipeline(hparams, num_workers)

    tpu_run_config = nmt_estimator._get_tpu_run_config(hparams)
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        config=tpu_run_config,
        use_tpu=hparams.use_tpu,
        train_batch_size=hparams.batch_size,
        eval_batch_size=hparams.batch_size,
        predict_batch_size=hparams.infer_batch_size)

    score = 0.0
    mlperf_log.gnmt_print(key=mlperf_log.TRAIN_LOOP)
    mlperf_log.gnmt_print(
        key=mlperf_log.EVAL_TARGET, value=hparams.target_bleu)

    for epoch in range(hparams.max_train_epochs):
        mlperf_log.gnmt_print(key=mlperf_log.TRAIN_EPOCH, value=epoch)
        tf.logging.info("Start training epoch %d", epoch)
        mlperf_log.gnmt_print(
            key=mlperf_log.INPUT_SIZE, value=hparams.num_examples_per_epoch)

        steps_per_epoch = int(
            hparams.num_examples_per_epoch / hparams.batch_size)
        # max_steps is a cumulative bound, so it grows by one epoch's worth of
        # steps on every iteration.
        step_limit = steps_per_epoch * (epoch + 1)
        # The pipeline serves as both the input function and a training hook.
        estimator.train(
            input_fn=pipeline, max_steps=step_limit, hooks=[pipeline])

        mlperf_log.gnmt_print(
            key=mlperf_log.TRAIN_CHECKPOINT,
            value=("Under " + hparams.out_dir))
        tf.logging.info("End training epoch %d", epoch)

        mlperf_log.gnmt_print(key=mlperf_log.EVAL_START)
        score = nmt_estimator.get_metric_from_estimator(hparams, estimator)
        tf.logging.info("Score after epoch %d: %f", epoch, score)
        accuracy_entry = {"value": score, "epoch": epoch}
        mlperf_log.gnmt_print(
            key=mlperf_log.EVAL_ACCURACY, value=accuracy_entry)
        mlperf_log.gnmt_print(key=mlperf_log.EVAL_STOP, value=epoch)

        # Stop as soon as the target score has been reached.
        target_reached = score >= hparams.target_bleu
        if target_reached:
            mlperf_log.gnmt_print(mlperf_log.RUN_STOP, {"success": True})
            return score

    # All epochs ran without reaching the target.
    mlperf_log.gnmt_print(mlperf_log.RUN_STOP, {"success": False})
    return score