def testTrain(self):
        """Test the training loop is functional with basic hparams."""
        nmt_parser = argparse.ArgumentParser()
        nmt.add_arguments(nmt_parser)
        flags, _ = nmt_parser.parse_known_args()
        update_flags(flags, "nmt_train_test")
        default_hparams = nmt.create_hparams(flags)

        nmt.run_main(flags, default_hparams, estimator.train_and_eval_fn)
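
The update_flags helper these tests call is not defined in the snippets here. A minimal sketch of what such a helper presumably does, pointing the parsed flags at small test fixtures (every field value below is an assumption, not the suite's actual settings):

import os
import tempfile

def update_flags(flags, test_name):
    # Hypothetical sketch: aim the standard nmt.nmt flags at tiny IWSLT
    # fixtures and a throwaway output directory named after the test.
    flags.out_dir = os.path.join(tempfile.gettempdir(), test_name)
    flags.src = "en"
    flags.tgt = "vi"
    flags.train_prefix = "nmt/testdata/iwslt15.tst2013.100"
    flags.vocab_prefix = "nmt/testdata/iwslt15.vocab.100"
    flags.num_train_steps = 10
    flags.batch_size = 4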
Example 2
    def testTrain(self):
        """Test the training loop is functional with basic hparams."""
        nmt_parser = argparse.ArgumentParser()
        nmt.add_arguments(nmt_parser)
        FLAGS, unparsed = nmt_parser.parse_known_args()

        _update_flags(FLAGS, "nmt_train_test")

        default_hparams = nmt.create_hparams(FLAGS)

        train_fn = train.train
        nmt.run_main(FLAGS, default_hparams, train_fn, None)
Example 3
    def testInference(self):
        """Test inference is function with basic hparams."""
        nmt_parser = argparse.ArgumentParser()
        nmt.add_arguments(nmt_parser)
        FLAGS, unparsed = nmt_parser.parse_known_args()

        _update_flags(FLAGS, "nmt_train_infer")

        # Train one step so we have a checkpoint.
        FLAGS.num_train_steps = 1
        default_hparams = nmt.create_hparams(FLAGS)
        train_fn = train.train
        nmt.run_main(FLAGS, default_hparams, train_fn, None)

        # Update FLAGS for inference.
        FLAGS.inference_input_file = "nmt/testdata/iwslt15.tst2013.100.en"
        FLAGS.inference_output_file = os.path.join(FLAGS.out_dir, "output")
        FLAGS.inference_ref_file = "nmt/testdata/iwslt15.tst2013.100.vi"

        default_hparams = nmt.create_hparams(FLAGS)

        inference_fn = inference.inference
        nmt.run_main(FLAGS, default_hparams, None, inference_fn)
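
Each of these tests is an ordinary tf.test.TestCase method; the standard entry point for running such a module (assuming the usual layout) is:

import tensorflow as tf

if __name__ == "__main__":
    tf.test.main()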
  def testTrainInputFn(self):
    nmt_parser = argparse.ArgumentParser()
    nmt.add_arguments(nmt_parser)
    flags, _ = nmt_parser.parse_known_args()
    update_flags(flags, "input_fn_test")
    default_hparams = nmt.create_hparams(flags)
    hparams = nmt.extend_hparams(default_hparams)

    with self.test_session() as sess:
      input_fn = make_input_fn(hparams, tf.contrib.learn.ModeKeys.TRAIN)
      outputs = input_fn({})
      sess.run(tf.tables_initializer())
      iterator = outputs.make_initializable_iterator()
      sess.run(iterator.initializer)
      features = sess.run(iterator.get_next())
      tf.logging.info("source: %s", features["source"])
      tf.logging.info("target_input: %s", features["target_input"])
      tf.logging.info("target_output: %s", features["target_output"])
      tf.logging.info("source_sequence_length: %s",
                      features["source_sequence_length"])
      tf.logging.info("target_sequence_length: %s",
                      features["target_sequence_length"])
def main(_):
    default_hparams = nmt.create_hparams(FLAGS)
    ## Train / Decode
    out_dir = FLAGS.out_dir
    if not tf.gfile.Exists(out_dir):
        tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    hparams = nmt.create_or_load_hparams(out_dir,
                                         default_hparams,
                                         FLAGS.hparams_path,
                                         save_hparams=False)

    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_helper.create_train_model(
        model_creator, hparams, scope=None)

    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=1,
        num_inter_threads=36)
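    # NOTE (assumption): utils.get_config_proto is not defined in this
    # snippet; it presumably builds a tf.ConfigProto along the lines of
    #   tf.ConfigProto(log_device_placement=log_device_placement,
    #                  allow_soft_placement=True,
    #                  intra_op_parallelism_threads=num_intra_threads,
    #                  inter_op_parallelism_threads=num_inter_threads)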

    def run(train_sess, num_workers, worker_id, num_replicas_per_worker):

        # Random
        random_seed = FLAGS.random_seed
        if random_seed is not None and random_seed > 0:
            utils.print_out("# Set random seed to %d" % random_seed)
            random.seed(random_seed + worker_id)
            np.random.seed(random_seed + worker_id)

        # Log and output files
        log_file = os.path.join(out_dir, "log_%d" % time.time())
        log_f = tf.gfile.GFile(log_file, mode="a")
        utils.print_out("# log_file=%s" % log_file, log_f)

        global_step = train_sess.run(train_model.model.global_step)[0]
        last_stats_step = global_step

        # This is the training loop.
        stats, info, start_train_time = before_train(train_model, train_sess,
                                                     global_step, hparams,
                                                     log_f,
                                                     num_replicas_per_worker)

        # Integer division: the epoch boundary check below needs a whole
        # number of steps (plain "/" would yield a float under Python 3).
        epoch_steps = FLAGS.epoch_size // (FLAGS.batch_size * num_workers *
                                           num_replicas_per_worker)
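        # Worked example (hypothetical numbers): epoch_size=100000 sentences,
        # batch_size=128, 2 workers, 2 replicas per worker gives
        # epoch_steps = 100000 // (128 * 2 * 2) = 195 steps per epoch.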

        for i in range(FLAGS.max_steps):
            ### Run a step ###
            start_time = time.time()
            if hparams.epoch_step != 0 and hparams.epoch_step % epoch_steps == 0:
                hparams.epoch_step = 0
                skip_count = train_model.skip_count_placeholder
                feed_dict = {skip_count: [0] * num_replicas_per_worker}
                init = train_model.iterator.initializer
                train_sess.run(init, feed_dict=feed_dict)

            if worker_id == 0:
                results = train_sess.run([
                    train_model.model.update, train_model.model.train_loss,
                    train_model.model.predict_count,
                    train_model.model.train_summary,
                    train_model.model.global_step,
                    train_model.model.word_count, train_model.model.batch_size,
                    train_model.model.grad_norm,
                    train_model.model.learning_rate
                ])
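                # Parallax fetches return one value per in-graph replica;
                # keep replica 0's results so update_stats below sees
                # per-step scalars rather than lists.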
                step_result = [r[0] for r in results]

            else:
                global_step, _ = train_sess.run(
                    [train_model.model.global_step, train_model.model.update])
            hparams.epoch_step += 1

            if worker_id == 0:
                # Process step_result, accumulate stats, and write summary
                global_step, info["learning_rate"], step_summary = \
                    train.update_stats(stats, start_time, step_result)

                # Once in a while, we print statistics.
                if global_step - last_stats_step >= steps_per_stats:
                    last_stats_step = global_step
                    is_overflow = train.process_stats(stats, info, global_step,
                                                      steps_per_stats, log_f)
                    train.print_step_info("  ", global_step, info,
                                          train._get_best_results(hparams),
                                          log_f)
                    if is_overflow:
                        break

                    # Reset statistics
                    stats = train.init_stats()

    sess, num_workers, worker_id, num_replicas_per_worker = \
        parallax.parallel_run(train_model.graph,
                              FLAGS.resource_info_file,
                              sync=FLAGS.sync,
                              parallax_config=parallax_config.build_config())
    run(sess, num_workers, worker_id, num_replicas_per_worker)
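
The snippet defines main but omits the module entry point; the standard TF 1.x stub (assuming FLAGS is populated by an argparse or tf.app.flags block not shown here) would be:

if __name__ == "__main__":
    tf.app.run(main=main)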