Example #1
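These snippets are excerpts from a TF 1.x Estimator training script. They assume roughly the imports below; project-local helpers (make_model_fn, make_input_fn, BenchmarkHook, get_distribution_strategy, get_sacrebleu, misc_utils, nmt_utils) are defined elsewhere in the repository and are not shown here.

import codecs
import os

import horovod.tensorflow as hvd  # Example #1 only
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
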
def train_fn(hparams):
    """Train function."""
    model_fn = make_model_fn(hparams)
    input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)

    log_step_count_steps = hparams.log_step_count_steps
    save_checkpoints_steps = hparams.save_checkpoints_steps
    # if hparams.use_dist_strategy:
    #   distribution_strategy = get_distribution_strategy(hparams.num_gpus)
    #   config = tf.estimator.RunConfig(
    #       train_distribute=distribution_strategy,
    #       log_step_count_steps=log_step_count_steps,
    #       keep_checkpoint_max=None,
    #       save_checkpoints_steps=save_checkpoints_steps)
    # else:
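    # The DistributionStrategy branch above is commented out in this variant;
    # the Horovod session config below is always used. See Example #2 for the
    # version with the branch enabled.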
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
    if hparams.use_autojit_xla:
        sess_config.graph_options.optimizer_options.global_jit_level = (
            tf.OptimizerOptions.ON_1)
    if not hparams.use_pintohost_optimizer:
        sess_config.graph_options.rewrite_options.pin_to_host_optimization = (
            rewriter_config_pb2.RewriterConfig.OFF)
    config = tf.estimator.RunConfig(
        log_step_count_steps=log_step_count_steps,
        session_config=sess_config,
        keep_checkpoint_max=None,
        save_checkpoints_steps=save_checkpoints_steps)

    misc_utils.print_out("sess master is %s" % config.master)
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=hparams.output_dir,
                                       config=config)

    benchmark_hook = BenchmarkHook(hparams.batch_size,
                                   hparams.warmup_steps + 5)
    train_hooks = [hvd.BroadcastGlobalVariablesHook(0), benchmark_hook]
    if hparams.profile:
        train_hooks.append(
            tf.train.ProfilerHook(output_dir=hparams.output_dir,
                                  save_steps=hparams.profile_save_steps,
                                  show_dataflow=True,
                                  show_memory=True))

    max_steps = hparams.debug_num_train_steps
    global_step = estimator.train(
        input_fn=input_fn,
        max_steps=max_steps,
        hooks=train_hooks,
    ).get_variable_value(tf.GraphKeys.GLOBAL_STEP)

    sentences_sec, latencies = benchmark_hook.get_average_speed_and_latencies()
    return global_step, sentences_sec, latencies
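
Example #1 is the Horovod variant: setting visible_device_list to hvd.local_rank() pins each process to a single GPU, and hvd.BroadcastGlobalVariablesHook(0) broadcasts rank 0's initial weights to every worker before training. BenchmarkHook is project-local and not shown in these snippets; here is a minimal sketch of what it plausibly looks like, given how it is constructed and queried above (the timing details are assumptions, not the repository's actual code):

import time
import tensorflow as tf

class BenchmarkHook(tf.train.SessionRunHook):
    """Hypothetical sketch: times each session.run after a warmup period and
    reports (sentences/sec, per-step latencies)."""

    def __init__(self, batch_size, warmup_steps=0):
        self.batch_size = batch_size
        self.warmup_steps = warmup_steps
        self.step = 0
        self.latencies = []

    def before_run(self, run_context):
        # Mark the start of this training/inference step.
        self.start_time = time.time()

    def after_run(self, run_context, run_values):
        # Record latency once past the warmup steps.
        self.step += 1
        if self.step > self.warmup_steps:
            self.latencies.append(time.time() - self.start_time)

    def get_average_speed_and_latencies(self):
        if not self.latencies:
            return 0.0, []
        avg_latency = sum(self.latencies) / len(self.latencies)
        return self.batch_size / avg_latency, self.latencies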
Example #2
def train_fn(hparams):
    """Train function."""
    model_fn = make_model_fn(hparams)
    input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)

    log_step_count_steps = hparams.log_step_count_steps
    save_checkpoints_steps = hparams.save_checkpoints_steps
    if hparams.use_dist_strategy:
        distribution_strategy = get_distribution_strategy(hparams.num_gpus)
        config = tf.estimator.RunConfig(
            train_distribute=distribution_strategy,
            log_step_count_steps=log_step_count_steps,
            keep_checkpoint_max=None,
            save_checkpoints_steps=save_checkpoints_steps)
    else:
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        if hparams.use_autojit_xla:
            sess_config.graph_options.optimizer_options.global_jit_level = (
                tf.OptimizerOptions.ON_1)
        if not hparams.use_pintohost_optimizer:
            sess_config.graph_options.rewrite_options.pin_to_host_optimization = (
                rewriter_config_pb2.RewriterConfig.OFF)
        config = tf.estimator.RunConfig(
            log_step_count_steps=log_step_count_steps,
            session_config=sess_config,
            keep_checkpoint_max=None,
            save_checkpoints_steps=save_checkpoints_steps)

    misc_utils.print_out("sess master is %s" % config.master)
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=hparams.output_dir,
                                       config=config)

    benchmark_hook = BenchmarkHook(hparams.batch_size,
                                   hparams.warmup_steps + 5)
    train_hooks = [benchmark_hook]
    if hparams.profile:
        train_hooks.append(
            tf.train.ProfilerHook(output_dir=hparams.output_dir,
                                  save_steps=hparams.profile_save_steps,
                                  show_dataflow=True,
                                  show_memory=True))

    max_steps = hparams.debug_num_train_steps
    estimator.train(
        input_fn=input_fn,
        max_steps=max_steps,
        hooks=train_hooks,
    )

    return benchmark_hook.get_average_speed_and_latencies()
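
Example #2 swaps Horovod for tf.contrib.distribute: when use_dist_strategy is set, a DistributionStrategy is passed to RunConfig via train_distribute, so no broadcast hook is needed and the hook list holds only the benchmark hook. The get_distribution_strategy helper is not shown; a plausible sketch under the TF 1.x contrib APIs (hypothetical, not the repository's code):

def get_distribution_strategy(num_gpus):
    # Hypothetical sketch: mirror the model across the requested GPUs,
    # falling back to a single device for num_gpus <= 1.
    if num_gpus <= 1:
        return tf.contrib.distribute.OneDeviceStrategy("/device:GPU:0")
    return tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus)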
Example #3
def get_metrics(hparams, model_fn, ckpt=None, only_translate=False):
    """Run inference and compute metrics."""
    pred_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                            model_dir=hparams.output_dir)

    benchmark_hook = BenchmarkHook(hparams.infer_batch_size)

    predictions = pred_estimator.predict(
        make_input_fn(hparams, tf.contrib.learn.ModeKeys.INFER),
        checkpoint_path=ckpt,
        hooks=[benchmark_hook])
    translations = []
    output_tokens = []
    beam_id = 0
    for prediction in predictions:
        # Keep only the first beam, i.e. the top-scoring translation.
        if beam_id == 0:
            for sent_id in range(hparams.infer_batch_size):
                if sent_id >= prediction["predictions"].shape[0]:
                    break
                trans, output_length = nmt_utils.get_translation(
                    prediction["predictions"],
                    sent_id=sent_id,
                    tgt_eos=hparams.eos,
                    subword_option=hparams.subword_option)
                translations.append(trans)
                output_tokens.append(output_length)
        beam_id += 1
        if beam_id == hparams.beam_width:
            beam_id = 0

    if only_translate:
        trans_file = hparams.translate_file + '.trans.tok'
    else:
        trans_file = os.path.join(
            hparams.output_dir, "newstest2014_out_{}.tok.de".format(
                pred_estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)))
    trans_dir = os.path.dirname(trans_file)
    if not tf.gfile.Exists(trans_dir):
        tf.gfile.MakeDirs(trans_dir)
    tf.logging.info("Writing to file %s" % trans_file)
    with codecs.getwriter("utf-8")(tf.gfile.GFile(trans_file,
                                                  mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for translation in translations:
            trans_f.write((translation + b"\n").decode("utf-8"))

    if only_translate:
        return None, benchmark_hook.get_average_speed_and_latencies(), sum(
            output_tokens)

    # Evaluation
    output_dir = os.path.join(pred_estimator.model_dir, "eval")
    tf.gfile.MakeDirs(output_dir)
    summary_writer = tf.summary.FileWriter(output_dir)

    ref_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
    # Hardcoded.
    metric = "bleu"
    score = get_sacrebleu(trans_file, hparams.detokenizer_file)

    misc_utils.print_out("bleu is %.5f" % score)
    summaries = [tf.Summary.Value(tag=metric, simple_value=score)]
    tf_summary = tf.Summary(value=summaries)
    summary_writer.add_summary(
        tf_summary,
        pred_estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP))

    summary_writer.close()
    return score, benchmark_hook.get_average_speed_and_latencies(), sum(
        output_tokens)
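
The get_sacrebleu helper is also project-local and not shown. A hypothetical sketch, assuming detokenizer_file points at the Moses detokenizer.perl script and that the translations are scored against sacrebleu's built-in WMT14 en-de test set (both assumptions, inferred from the newstest2014 output filename above):

import subprocess

def get_sacrebleu(trans_file, detokenizer_file):
    # Hypothetical sketch: detokenize the tokenized output, then score it
    # with the sacrebleu CLI.
    detok_file = trans_file + ".detok"
    with open(trans_file, "rb") as fin, open(detok_file, "wb") as fout:
        subprocess.check_call(["perl", detokenizer_file, "-l", "de"],
                              stdin=fin, stdout=fout)
    with open(detok_file, "rb") as fin:
        output = subprocess.check_output(
            ["sacrebleu", "-t", "wmt14/full", "-l", "en-de", "--score-only"],
            stdin=fin)
    return float(output.decode("utf-8").strip())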