Example #1
def secondary_fn_tmp(hparams, identity, model_dir, model, eval_model, eval_sess,
                     name, worker_fn):
  """secondary helper function for inference and evaluation."""
  steps_per_external_eval = 10
  # initialize summary writer
  summary_writer_path = os.path.join(hparams.out_dir, identity + name + "_log")
  print("summary_writer_path", summary_writer_path)
  summary_writer = tf.summary.FileWriter(summary_writer_path, model.graph)
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  # create session
  sess = tf.Session(config=config_proto, graph=model.graph)

  # wait for the checkpoints
  latest_ckpt = None
  last_external_eval_step = 0

  # main inference loop
  while True:
    latest_ckpt = tf.contrib.training.wait_for_new_checkpoint(
        model_dir, latest_ckpt)
    with model.graph.as_default():
      _, global_step = model_helper.create_or_load_model(
          model.model, model_dir, sess, name)
    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step
      worker_fn(model, sess, eval_model, eval_sess, latest_ckpt, summary_writer,
                global_step, hparams)
    if not hparams.eval_forever:
      break  # if eval_forever is disabled, we only evaluate once
  summary_writer.close()
  sess.close()
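
Every example in this collection routes its session options through utils.get_config_proto. For reference, here is a minimal sketch of what that helper typically looks like in these NMT-style codebases; the exact signature and defaults vary by repo, so treat this as an assumption rather than the canonical implementation:

def get_config_proto(log_device_placement=False,
                     allow_soft_placement=True,
                     num_intra_threads=0,
                     num_inter_threads=0):
    # Assemble a tf.ConfigProto; thread counts of 0 let TensorFlow choose.
    config_proto = tf.ConfigProto(
        log_device_placement=log_device_placement,
        allow_soft_placement=allow_soft_placement)
    # Allocate GPU memory on demand instead of grabbing it all up front.
    config_proto.gpu_options.allow_growth = True
    if num_intra_threads:
        config_proto.intra_op_parallelism_threads = num_intra_threads
    if num_inter_threads:
        config_proto.inter_op_parallelism_threads = num_inter_threads
    return config_proto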
Example #2
    def get_answer(self, question, id_spk):
        infer_data = [clean_text(question)]
        infer_model = self.infer_model
        with tf.Session(
                graph=infer_model.graph,
                config=utils.get_config_proto()) as sess:
            loaded_infer_model = model_helper.load_model(
                    infer_model.model, self.ckpt, sess, "infer")
            sess.run(
                    infer_model.iterator.initializer,
                    feed_dict={
                        infer_model.src_placeholder: infer_data,
                        infer_model.batch_size_placeholder: 1,
                        infer_model.src_speaker_placeholder: id_spk,
                        infer_model.tgt_speaker_placeholder: self.id
                        })

            nmt_outputs, _ = loaded_infer_model.decode(sess)
            translation = []
            for beam_id in range(self.num_translations_per_input):

                # Set sent_id to 0 because batch_size is 1
                translation.append(nmt_utils.get_translation(
                        nmt_outputs=nmt_outputs[beam_id],
                        sent_id=0,
                        tgt_eos=self.hparams.eos,
                        subword_option=self.hparams.subword_option))

        return translation
Example #3
    def __init__(self, hparams):
        self.hparams = hparams
        # print("====test__init__==\n")
        # Data locations
        self.out_dir = hparams.out_dir
        # print("our_dir:", self.out_dir)
        self.model_dir = os.path.join(self.out_dir, 'ckpts')
        # print("model_dir:", self.model_dir)
        # Create models
        attention_option = hparams.attention_option

        if attention_option:
            model_creator = AttentionModel
        else:
            model_creator = BasicModel

        self.infer_model = model_helper.create_infer_model(
            hparams=hparams, model_creator=model_creator)

        # Sessions
        config_proto = utils.get_config_proto()
        self.infer_sess = tf.Session(config=config_proto,
                                     graph=self.infer_model.graph)

        # EOS
        self.tgt_eos = Vocabulary.EOS.encode("utf-8")
        # Load infer model
        with self.infer_model.graph.as_default():
            self.loaded_infer_model, self.global_step = model_helper.create_or_load_model(
                self.infer_model.model, self.model_dir, self.infer_sess,
                "infer")
Example #4
def start_sess_and_load_model(infer_model, ckpt_path):
    """Start session and load model."""
    sess = tf.Session(graph=infer_model.graph, config=utils.get_config_proto())
    with infer_model.graph.as_default():
        loaded_infer_model = model_helper.load_model(infer_model.model,
                                                     ckpt_path, sess, "infer")
    return sess, loaded_infer_model
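
A typical call site for this helper, assuming the usual infer-model tuple with iterator, src_placeholder, and batch_size_placeholder fields; the checkpoint path and input sentence here are hypothetical:

sess, loaded_infer_model = start_sess_and_load_model(infer_model,
                                                     "/tmp/translate.ckpt")
sess.run(infer_model.iterator.initializer,
         feed_dict={infer_model.src_placeholder: ["hello world"],
                    infer_model.batch_size_placeholder: 1})
nmt_outputs, _ = loaded_infer_model.decode(sess)  # decode one batch
sess.close()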
Example #5
def single_worker_inference(infer_model,
                            ckpt,
                            inference_input_file,
                            inference_output_file,
                            hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(config=utils.get_config_proto(), graph=infer_model.graph) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt, sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder: hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        _decode_and_evaluate("infer",
                             loaded_infer_model,
                             sess,
                             output_infer,
                             ref_file=None,
                             subword_option=None,
                             beam_width=hparams.beam_width,
                             tgt_eos=hparams.eos,
                             num_translations_per_input=hparams.num_translations_per_input)
Example #6
def infer(hparams):
    infer_model = mc.create_infer_model(hparams)
    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=hparams.log_device_placement, allow_soft_placement=True)
    infer_sess = tf.Session(graph=infer_model.graph, config=config_proto)

    infer_util.run_infer(hparams, infer_sess, infer_model, 0, ["bleu", "rouge", "accuracy"])
Example #7
def export_model(config, model_creator):
    if not config.export_path:
        raise ValueError("Export path must be specified.")
    if not config.model_version:
        raise ValueError("Export model version must be specified.")

    utils.makedir(config.export_path)

    # Create model
    model = model_helper.create_model(model_creator, config, mode="infer")

    # TensorFlow model
    config_proto = utils.get_config_proto()
    sess = tf.Session(config=config_proto, graph=model.graph)

    with model.graph.as_default():
        loaded_model, global_step = model_helper.create_or_load_model(
            model.model, config.best_eval_loss_dir, sess, "infer")

        export_dir = os.path.join(config.export_path, config.model_version)
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        inputs = {
            "word_ids1":
            tf.saved_model.utils.build_tensor_info(loaded_model.word_ids1),
            "word_ids2":
            tf.saved_model.utils.build_tensor_info(loaded_model.word_ids2),
            "word_len1":
            tf.saved_model.utils.build_tensor_info(loaded_model.word_len1),
            "word_len2":
            tf.saved_model.utils.build_tensor_info(loaded_model.word_len2),
            "char_ids1":
            tf.saved_model.utils.build_tensor_info(loaded_model.char_ids1),
            "char_ids2":
            tf.saved_model.utils.build_tensor_info(loaded_model.char_ids2),
            "char_len1":
            tf.saved_model.utils.build_tensor_info(loaded_model.char_len1),
            "char_len2":
            tf.saved_model.utils.build_tensor_info(loaded_model.char_len2)
        }
        outputs = {
            "simscore":
            tf.saved_model.utils.build_tensor_info(loaded_model.simscore)
        }
        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs=inputs,
                outputs=outputs,
                method_name=tf.saved_model.signature_constants.
                PREDICT_METHOD_NAME))

        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING], {
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                prediction_signature
            })
        builder.save()
        logger.info("Export model succeed.")
Example #8
def run_sample_decode_pungan_prepare(hparams, scope=None, target_session=""):
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    if not hparams.attention:  # choose this model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)
    wsd_src_file = "%s" % (hparams.sample_prefix)
    wsd_src_data = inference.load_data(wsd_src_file)
    model_dir = hparams.out_dir
    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    eval_result = run_sample_decode_pungan(infer_model, infer_sess, model_dir,
                                           hparams, wsd_src_data)
    print('eval_result', eval_result)
    print('eval_result.len', len(eval_result))

    print('wsd_src_data', wsd_src_data)
    print('wsd_src_data.len', len(wsd_src_data))
    #len=16 wsd_src_data [u'problem%1:10:00::', u'problem%1:26:00::', u'drive%1:04:00::', u'drive%1:04:03::', u'identity%1:07:00::', u'identity%1:24:01::', u'point%1:10:01::', u'point%1:06:00::', u'tension%1:26:01::', u'tension%1:26:03::', u'log%2:32:00::', u'log%2:35:00::', u'fan%1:06:00::', u'fan%1:18:00::', u'file%2:32:00::', u'file%2:35:00::']

    eval_result_new = []
    # Each source word pair yields 2 * sample_size decoded sentences; use
    # floor division so the block count stays an int under Python 3 as well.
    for block in range(len(eval_result) // (2 * hparams.sample_size)):
        src_word1, src_word2 = wsd_src_data[2 * block], wsd_src_data[2 * block + 1]
        for sent_id in range(block * hparams.sample_size,
                             (block + 1) * hparams.sample_size):
            tgt_sent = src_word1.decode().encode(
                'utf-8') + ' ' + eval_result[sent_id]
            eval_result_new.append(tgt_sent)
    return wsd_src_data, eval_result_new
Example #9
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams, model_creator):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding single_worker_inference")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)

        else:

            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_translations_per_input=hparams.num_translations_per_input)
Example #10
def start_sess_and_load_model(infer_model, ckpt_path, hparams):
    """Start session and load model."""
    print("num_intra_threads = %d, num_inter_threads = %d \n" %
          (hparams.num_intra_threads, hparams.num_inter_threads))
    sess = tf.Session(graph=infer_model.graph,
                      config=utils.get_config_proto(
                          num_intra_threads=hparams.num_intra_threads,
                          num_inter_threads=hparams.num_inter_threads))
    with infer_model.graph.as_default():
        loaded_infer_model = model_helper.load_model(infer_model.model,
                                                     ckpt_path, sess, "infer")
    return sess, loaded_infer_model
Example #11
def test(config, model_creator):
    # for metric in config.metrics.split(","):
    best_metric_label = "best_eval_loss"
    model_dir = getattr(config, best_metric_label + "_dir")

    logger.info("Start evaluating saved best model on training-set.")
    eval_model = model_helper.create_model(model_creator, config, mode="eval")
    session_config = utils.get_config_proto()
    eval_sess = tf.Session(config=session_config, graph=eval_model.graph)
    run_test(config, eval_model, eval_sess, config.train_file, model_dir)

    logger.info("Start evaluating saved best model on dev-set.")
    eval_model = model_helper.create_model(model_creator, config, mode="eval")
    session_config = utils.get_config_proto()
    eval_sess = tf.Session(config=session_config, graph=eval_model.graph)
    run_test(config, eval_model, eval_sess, config.dev_file, model_dir)

    logger.info("Start evaluating saved best model on test-set.")
    eval_model = model_helper.create_model(model_creator, config, mode="eval")
    session_config = utils.get_config_proto()
    eval_sess = tf.Session(config=session_config, graph=eval_model.graph)
    run_test(config, eval_model, eval_sess, config.test_file, model_dir)
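
The three blocks above differ only in the evaluated data file; a loop keeps them in sync. A sketch reusing run_test exactly as it is called above:

for label, data_file in [("training-set", config.train_file),
                         ("dev-set", config.dev_file),
                         ("test-set", config.test_file)]:
    logger.info("Start evaluating saved best model on %s." % label)
    eval_model = model_helper.create_model(model_creator, config, mode="eval")
    eval_sess = tf.Session(config=utils.get_config_proto(),
                           graph=eval_model.graph)
    run_test(config, eval_model, eval_sess, data_file, model_dir)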
Example #12
def run_prediction(input_file_path, output_file_path):
    infile = 'input_file'
    word_split(input_file_path, infile, jieba_split)

    model_dir = 'jb_attention'
    hparams = utils.load_hparams(model_dir)
    hparams.inference_indices = list(range(150))
    sample_src_dataset = inference.load_data(infile)
    log_device_placement = hparams.log_device_placement

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:
        if (hparams.encoder_type == 'gnmt'
                or hparams.attention_architecture in ['gnmt', 'gnmt_v2']):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == 'standard':
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError('Unknown attention architecture %s' %
                             (hparams.attention_architecture))

    infer_model = model_helper.create_infer_model(model_creator,
                                                  hparams,
                                                  scope=None)

    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)

    infer_sess = tf.Session(target='',
                            config=config_proto,
                            graph=infer_model.graph)

    with infer_model.graph.as_default():
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, 'infer')

    iterator_feed_dict = {
        infer_model.src_placeholder: sample_src_dataset,
        infer_model.batch_size_placeholder: 1,
    }
    infer_sess.run(infer_model.iterator.initializer,
                   feed_dict=iterator_feed_dict)

    while True:
        try:
            # Decode with the loaded model; the infer_model tuple itself
            # has no decode method.
            nmt_outputs, _ = loaded_infer_model.decode(infer_sess)
        except tf.errors.OutOfRangeError:
            break
Example #13
    def __init__(self, hparams):
        self.hparams = hparams

        # Data locations
        self.out_dir = hparams.out_dir
        self.model_dir = os.path.join(self.out_dir, 'ckpts')
        if not tf.gfile.Exists(self.model_dir):
            tf.gfile.MakeDirs(self.model_dir)

        self.train_src_file = os.path.join(
            hparams.data_dir, hparams.train_prefix + '.' + hparams.src_suffix)
        self.train_tgt_file = os.path.join(
            hparams.data_dir, hparams.train_prefix + '.' + hparams.tgt_suffix)
        self.test_src_file = os.path.join(
            hparams.data_dir, hparams.test_prefix + '.' + hparams.src_suffix)
        self.test_tgt_file = os.path.join(
            hparams.data_dir, hparams.test_prefix + '.' + hparams.tgt_suffix)
        self.dev_src_file = os.path.join(
            hparams.data_dir, hparams.dev_prefix + '.' + hparams.src_suffix)
        self.dev_tgt_file = os.path.join(
            hparams.data_dir, hparams.dev_prefix + '.' + hparams.tgt_suffix)
        self.infer_out_file = os.path.join(self.out_dir, 'infer_output')
        self.eval_out_file = os.path.join(self.out_dir, 'eval_output')

        # Create models
        attention_option = hparams.attention_option

        if attention_option:
            model_creator = AttentionModel
        else:
            model_creator = BasicModel

        self.train_model = model_helper.create_train_model(
            hparams=hparams, model_creator=model_creator)
        self.eval_model = model_helper.create_eval_model(
            hparams=hparams, model_creator=model_creator)
        self.infer_model = model_helper.create_infer_model(
            hparams=hparams, model_creator=model_creator)

        # Sessions
        config_proto = utils.get_config_proto()
        self.train_sess = tf.Session(config=config_proto,
                                     graph=self.train_model.graph)
        self.eval_sess = tf.Session(config=config_proto,
                                    graph=self.eval_model.graph)
        self.infer_sess = tf.Session(config=config_proto,
                                     graph=self.infer_model.graph)

        # EOS
        self.tgt_eos = Vocabulary.EOS.encode("utf-8")
Example #14
def infer_fn(hparams, identity, scope=None, extra_args=None, target_session=""):
  """main entry point for inference and evaluation."""
  # create infer and eval models
  infer_model = model_helper.create_infer_model(
      diag_model.Model, hparams, scope, extra_args=extra_args)
  eval_model = model_helper.create_eval_model(diag_model.Model, hparams, scope)
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  # create the eval session
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)

  secondary_fn_tmp(hparams, identity, hparams.out_dir, infer_model, eval_model,
                   eval_sess, "infer", single_worker_inference)
Example #15
def inference(infer_data):
    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")

        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: [infer_data],
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })

        translation = decode_inference_indices(loaded_infer_model, sess)

    return translation
Example #16
def inference(config, model_creator):
    output_file = "output_" + os.path.split(config.infer_file)[-1].split(".")[0]
    # Inference output directory
    pred_file = os.path.join(config.model_dir, output_file)
    utils.makedir(pred_file)

    # Inference
    model_dir = config.best_eval_loss_dir

    # Create model
    # model_creator = my_model.MyModel
    infer_model = model_helper.create_model(model_creator, config, mode="infer")

    # TensorFlow model
    sess_config = utils.get_config_proto()
    infer_sess = tf.Session(config=sess_config, graph=infer_model.graph)

    with infer_model.graph.as_default():
        loaded_infer_model, _ = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    run_infer(config, loaded_infer_model, infer_sess, pred_file)
Example #17
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                bpe_delimiter=hparams.bpe_delimiter)
        else:
            nmt_utils.decode_and_evaluate("infer",
                                          loaded_infer_model,
                                          sess,
                                          output_infer,
                                          ref_file=None,
                                          metrics=hparams.metrics,
                                          bpe_delimiter=hparams.bpe_delimiter,
                                          beam_width=hparams.beam_width,
                                          tgt_eos=hparams.eos)
Example #18
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
    output_infer = inference_output_file

    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_util.load_model(infer_model.model, ckpt,
                                                   sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_translations_per_input=hparams.num_translations_per_input)
Example #19
def run_sample_decode_pungan_prepare(hparams, scope=None, target_session=""):
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    if not hparams.attention:  # choose this model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    def dealt(input_path, output_path):
        # Reverse the token order of each line; the renamed parameters
        # avoid shadowing the built-in input().
        with open(input_path) as f:
            with open(output_path, 'w') as fw:
                for line in f:
                    tokens = line.strip().split()
                    tokens.reverse()
                    fw.write(' '.join(tokens) + '\n')

    wsd_src_file = "%s" % (hparams.sample_prefix)
    wsd_src_file_new = wsd_src_file + '.new'
    dealt(wsd_src_file, wsd_src_file_new)

    wsd_src_data = inference.load_data(wsd_src_file_new)
    model_dir = hparams.out_dir
    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)
    print('len wsd_src_data', len(wsd_src_data))
    eval_result = []
    for i in range(len(wsd_src_data) // 32):  # decode in batches of 32
        eval_result += run_sample_decode_pungan(
            infer_model, infer_sess, model_dir, hparams,
            wsd_src_data[i * 32:(i + 1) * 32])
    print('eval_result')
    print(eval_result)
    print(len(eval_result))
    backward_step1_in = []
    with open(PUNGAN_ROOT_PATH +
              '/Pun_Generation/data/1backward/backward_step1.in') as f:
        for line in f:
            backward_step1_in.append(line.strip())

    def wsd_input_format(wsd_src_data, eval_result):
        '''
        test_data[0] {'target_word': u'art#n', 'target_sense': None, 'id': 'senseval2.d000.s000.t000', 'context': ['the', '<target>', 'of', 'change_ringing', 'be', 'peculiar', 'to', 'the', 'english', ',', 'and', ',', 'like', 'most', 'english', 'peculiarity', ',', 'unintelligible', 'to', 'the', 'rest', 'of', 'the', 'world', '.'], 'poss': ['DET', 'NOUN', 'ADP', 'NOUN', 'VERB', 'ADJ', 'PRT', 'DET', 'NOUN', '.', 'CONJ', '.', 'ADP', 'ADJ', 'ADJ', 'NOUN', '.', 'ADJ', 'PRT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', '.']}
        '''
        wsd_input = []
        senses_input = []

        for i in range(len(eval_result)):
            block = i // 32  # 32 decoded sentences per source word pair
            src_word1, src_word2 = backward_step1_in[
                2 * block], backward_step1_in[2 * block + 1]
            tgt_sent = wsd_src_data[i].decode().encode(
                'utf-8') + ' ' + eval_result[i]
            tgt_word = src_word1

            synset = wn.lemma_from_key(tgt_word).synset()
            s = synset.name()
            target_word = '#'.join(s.split('.')[:2])
            context = tgt_sent.split(' ')

            for j in range(len(context)):
                if context[j] == tgt_word:
                    context[j] = '<target>'
            poss_list = ['.' for _ in range(len(context))]
            tmp_dict = {
                'target_word': target_word,
                'target_sense': None,
                'id': None,
                'context': context,
                'poss': poss_list
            }
            wsd_input.append(tmp_dict)
            senses_input.append((src_word1, src_word2))
        return wsd_input, senses_input

    wsd_input, senses_input = wsd_input_format(wsd_src_data, eval_result)
    print('wsd_input', wsd_input)
    print("len of wsd_input", len(wsd_input))
    return wsd_input, senses_input, wsd_src_data, eval_result
Example #20
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    # Create model
    model_creator = get_model_creator(hparams)
    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data,
                  avg_ckpts)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info, get_best_results(hparams),
                            log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            add_info_summaries(summary_writer, global_step, info)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = (run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts))
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step,
                            info, result_summary, log_f)
            summary_writer.close()

    return final_eval_metrics, global_step
Example #21
def train(hparams, identity, scope=None, target_session=""):
  """main loop to train the dialogue model. identity is used."""
  out_dir = hparams.out_dir
  steps_per_stats = hparams.steps_per_stats
  steps_per_internal_eval = 3 * steps_per_stats

  model_creator = diag_model.Model

  train_model = model_helper.create_train_model(model_creator, hparams, scope)

  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, identity+"log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  avg_step_time = 0.0

  # load TensorFlow session and model
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)

  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)

  train_handle = train_sess.run(train_model.train_iterator.string_handle())

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # initialize summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)

  last_stats_step = global_step
  last_eval_step = global_step

  # initialize training stats.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()),
      log_f)

  # initialize iterators
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.train_iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  # main training loop
  while global_step < hparams.num_train_steps:
    start_time = time.time()
    try:  #  run a step
      step_result = loaded_train_model.train(train_sess, train_handle)
      (_, step_loss, all_summaries, step_predict_count, step_summary,
       global_step, step_word_count, batch_size, _, _, words1, words2, mask1,
       mask2) = step_result
      hparams.epoch_step += 1

    except tf.errors.OutOfRangeError:  # finished an epoch
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.train_iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    for key in all_summaries:
      utils.add_summary(summary_writer, global_step, key, all_summaries[key])

    # update statistics
    step_time += (time.time() - start_time)

    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    if global_step - last_stats_step >= steps_per_stats:
      # print statistics for the previous epoch and save the model.
      last_stats_step = global_step

      avg_step_time = step_time / steps_per_stats
      utils.add_summary(summary_writer, global_step, "step_time", avg_step_time)
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      if math.isnan(train_ppl):
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

      # save the model
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "dialogue.ckpt"),
          global_step=global_step)

      # print the dialogue if in debug mode
      if hparams.debug:
        utils.print_current_dialogue(words1, words2, mask1, mask2)

    #  write out internal evaluation
    if global_step - last_eval_step >= steps_per_internal_eval:
      last_eval_step = global_step

      utils.print_out("# Internal Evaluation. global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

  # finished training
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "dialogue.ckpt"),
      global_step=global_step)
  result_summary = ""
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()),
      log_f)
  utils.print_time("# Done training!", start_train_time)
  utils.print_out("# Start evaluating saved best models.")
  summary_writer.close()
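
The train_ppl above is exp(total loss / predicted word count), guarded against overflow. utils.safe_exp is typically the small wrapper below; a sketch consistent with how it is called here, not necessarily this repo's exact code:

import math

def safe_exp(value):
    # Return inf rather than raising OverflowError for very large losses.
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")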
Example #22
def train(hps, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hps.log_device_placement
    out_dir = hps.out_dir
    num_train_steps = hps.num_train_steps
    steps_per_stats = hps.steps_per_stats
    steps_per_external_eval = hps.steps_per_external_eval
    steps_per_eval = 100 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hps.attention_architecture == "baseline":
        model_creator = AttentionModel
    else:
        model_creator = AttentionHistoryModel

    train_model = model_helper.create_train_model(model_creator, hps, scope)
    eval_model = model_helper.create_eval_model(model_creator, hps, scope)
    infer_model = model_helper.create_infer_model(model_creator, hps, scope)

    # Preload data for sample decoding.

    article_filenames = []
    abstract_filenames = []
    art_dir = hps.data_dir + '/article'
    abs_dir = hps.data_dir + '/abstract'
    for file in os.listdir(art_dir):
        if file.startswith(hps.dev_prefix):
            article_filenames.append(art_dir + "/" + file)
    for file in os.listdir(abs_dir):
        if file.startswith(hps.dev_prefix):
            abstract_filenames.append(abs_dir + "/" + file)
    # Randomly pick one dev article/abstract pair for sample decoding.
    decode_id = random.randint(0, len(article_filenames) - 1)
    single_article_file = article_filenames[decode_id]
    single_abstract_file = abstract_filenames[decode_id]

    dev_src_file = single_article_file
    dev_tgt_file = single_abstract_file
    sample_src_data = inference_base_model.load_data(dev_src_file)
    sample_tgt_data = inference_base_model.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hps.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    # run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess, hps,
    #     summary_writer,sample_src_data,sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hps.batch_size * hps.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})
    epoch_step = 0
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = update_stats(stats, summary_writer, start_time,
                                   step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = check_stats(stats, global_step, steps_per_stats, hps,
                                      log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hps,
                                                  summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "summarized.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hps,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hps.metrics:
        best_model_dir = getattr(hps, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hps, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)
        summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
Example #23
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_output_file, hparams, num_workers,
                           job_id):
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, job_id)
    output_infer_done = "%s_done_%d" % (inference_output_file, job_id)

    infer_data = load_data(inference_input_file, hparams)

    total_load = len(infer_data)
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = job_id * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_util.load_model(infer_model.model, ckpt,
                                                   sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        utils.print_out("# Start decoding")
        nmt_utils.decode_and_evaluate(
            "infer",
            loaded_infer_model,
            sess,
            output_infer,
            ref_file=None,
            metrics=hparams.metrics,
            subword_option=hparams.subword_option,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)

        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        if job_id != 0: return

        with codecs.getwriter("utf-8")(tf.gfile.GFile(final_output_infer,
                                                      mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waitting job %d to complete." %
                                    worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(tf.gfile.GFile(
                        worker_infer_done, mode="rb")) as f:
                    for translation in f:
                        final_f.write("%s" % translation)

            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                tf.gfile.Remove(worker_infer_done)
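
The load_per_worker arithmetic above is ceiling division in disguise: each worker takes at most ceil(total_load / num_workers) lines, and the last slice is clipped to the input length. The same split written as a standalone helper for clarity (the name shard_bounds is illustrative, not from the source):

def shard_bounds(total_load, num_workers, job_id):
    # Ceiling division without floats: -(-a // b) == ceil(a / b).
    load_per_worker = -(-total_load // num_workers)
    start = job_id * load_per_worker
    return start, min(start + load_per_worker, total_load)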
Example #24
def run_main(flags,
             default_hparams,
             train_fn,
             inference_fn,
             target_session=""):
    """Run main."""
    # Job
    jobid = flags.jobid
    num_workers = flags.num_workers
    utils.print_out("# Job id %d" % jobid)

    # Random
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        utils.print_out("# Set random seed to %d" % random_seed)
        random.seed(random_seed + jobid)
        np.random.seed(random_seed + jobid)

    # Model output directory
    out_dir = flags.out_dir
    if out_dir and not tf.gfile.Exists(out_dir):
        utils.print_out("# Creating output directory %s ..." % out_dir)
        tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    loaded_hparams = False
    if flags.ckpt:  # Try to load hparams from the same directory as ckpt
        ckpt_dir = os.path.dirname(flags.ckpt)
        ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
        if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
            hparams = create_or_load_hparams(ckpt_dir,
                                             default_hparams,
                                             flags.hparams_path,
                                             save_hparams=False)
            loaded_hparams = True
    if not loaded_hparams:  # Try to load from out_dir
        assert out_dir
        hparams = create_or_load_hparams(out_dir,
                                         default_hparams,
                                         flags.hparams_path,
                                         save_hparams=(jobid == 0))


    # GPU device
    config_proto = utils.get_config_proto(
        allow_soft_placement=True,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    utils.print_out("# Devices visible to TensorFlow: %s" %
                    repr(tf.Session(config=config_proto).list_devices()))

    ## Train / Decode
    if flags.inference_input_file:
        # Inference output directory
        trans_file = flags.inference_output_file
        assert trans_file
        trans_dir = os.path.dirname(trans_file)
        if not tf.gfile.Exists(trans_dir): tf.gfile.MakeDirs(trans_dir)

        # Inference indices
        hparams.inference_indices = None
        if flags.inference_list:
            hparams.inference_indices = [
                int(token) for token in flags.inference_list.split(",")
            ]

        # Inference
        ckpt = flags.ckpt
        if not ckpt:
            ckpt = tf.train.latest_checkpoint(out_dir)
        inference_fn(flags.run, flags.iterations, ckpt,
                     flags.inference_input_file, trans_file, hparams,
                     num_workers, jobid)

        # Evaluation
        if flags.run == 'accuracy':
            ref_file = flags.inference_ref_file
            if ref_file and tf.gfile.Exists(trans_file):
                for metric in hparams.metrics:
                    score = evaluation_utils.evaluate(ref_file, trans_file,
                                                      metric,
                                                      hparams.subword_option)
                    utils.print_out("  %s: %.1f" % (metric, score))
    else:
        # Train
        train_fn(hparams, target_session=target_session)
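
# A minimal sketch of how run_main is typically wired up from the command line.
# Everything below is illustrative: the flag set, create_hparams, my_train_fn and
# my_inference_fn are assumed stand-ins, not names from this codebase. The
# inference_fn must accept (run, iterations, ckpt, input_file, output_file,
# hparams, num_workers, jobid), matching the call inside run_main above.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--jobid", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--random_seed", type=int, default=None)
    parser.add_argument("--out_dir", type=str, default="/tmp/nmt_model")
    parser.add_argument("--ckpt", type=str, default="")
    parser.add_argument("--hparams_path", type=str, default="")
    parser.add_argument("--inference_input_file", type=str, default="")
    parser.add_argument("--inference_output_file", type=str, default="")
    parser.add_argument("--inference_ref_file", type=str, default="")
    parser.add_argument("--inference_list", type=str, default="")
    parser.add_argument("--run", type=str, default="accuracy")
    parser.add_argument("--iterations", type=int, default=1)
    flags = parser.parse_args()
    run_main(flags, create_hparams(flags), my_train_fn, my_inference_fn)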
Example #25
0
def train(hparams, scope=None, target_session="", compute_ppl=0):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:  # no attention: use the plain NMT model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)
    wsd_src_file = "%s" % (hparams.sample_prefix)

    wsd_src_data = inference.load_data(wsd_src_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation (disabled here)
    # run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess, hparams,
    #     summary_writer, sample_src_data,
    #     sample_tgt_data, avg_ckpts)
    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    end_step = global_step + 100  # NOTE: capped at 100 steps here; use num_train_steps for a full run
    while global_step < end_step:
        ### Run a step ###
        start_time = time.time()
        try:
            # TODO: forward the inference result to WSD and obtain a reward.
            step_result = loaded_train_model.train(train_sess)
            # TODO: feed the reward into a placeholder of loaded_train_model and
            # train with a reward-weighted objective, i.e. loss = loss * reward.
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            # run_sample_decode(infer_model, infer_sess, model_dir, hparams,
            #                   summary_writer, sample_src_data, sample_tgt_data)

            # only for pretrain
            # run_external_eval(infer_model, infer_sess, model_dir, hparams,
            #                   summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})

            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result, hparams)
        summary_writer.add_summary(step_summary, global_step)
        if compute_ppl:
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)
        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)
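
# The TODO comments in the step loop above sketch a reward-weighted objective.
# A minimal illustration: raw_loss, optimizer and global_step_var stand in for
# tensors/ops the real model class would already own; none of these names come
# from this codebase.
def reward_weighted_update(raw_loss, optimizer, global_step_var):
    """Illustrative only: scale the training loss by an external reward."""
    reward_placeholder = tf.placeholder(tf.float32, shape=[], name="reward")
    weighted_loss = raw_loss * reward_placeholder
    update_op = optimizer.minimize(weighted_loss, global_step=global_step_var)
    # At each step, after obtaining `reward` from the WSD module:
    #   sess.run(update_op, feed_dict={reward_placeholder: reward, ...})
    return reward_placeholder, update_op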
Example #26
0
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_output_file, hparams, num_workers, jobid):
    """Inference using multiple workers."""
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, jobid)
    output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    # Split data to multiple workers
    total_load = len(infer_data)
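    # Ceiling division: each worker takes ceil(total_load / num_workers) items.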
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = jobid * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:

        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 {infer_model.src_placeholder: infer_data})

        # Decode
        utils.print_out("# Start decoding")
        nmt_utils.decode_and_evaluate("infer",
                                      loaded_infer_model,
                                      sess,
                                      output_infer,
                                      ref_file=None,
                                      metrics=hparams.metrics,
                                      bpe_delimiter=hparams.bpe_delimiter,
                                      beam_width=hparams.beam_width,
                                      tgt_eos=hparams.eos)

        # Change file name to indicate the file writing is completed.
        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        # Job 0 is responsible for the clean up.
        if jobid != 0: return

        # Now write all translations
        with codecs.getwriter("utf-8")(tf.gfile.GFile(final_output_infer,
                                                      mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waitting job %d to complete." %
                                    worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(tf.gfile.GFile(
                        worker_infer_done, mode="rb")) as f:
                    for translation in f:
                        final_f.write("%s" % translation)
                tf.gfile.Remove(worker_infer_done)
Example #27
0
def train(hparams):
    """Train a seq2seq model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats

    model_creator = model.Model

    train_model = model_helper.create_train_model(model_creator, hparams)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_files=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(config=config_proto, graph=train_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    last_stats_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(
        loaded_train_model, train_model, train_sess, global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out("# Finished an epoch, step %d." % global_step)

            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step, steps_per_stats, log_f)
            print_step_info("  ", global_step, info, log_f)

            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

    # Done training
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "seq2seq.ckpt"),
        global_step=global_step)

    summary_writer.close()

    return global_step
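
# A usage sketch: the hparams fields this train() reads directly, assembled with
# tf.contrib.training.HParams. The values are placeholders, and model
# construction (model_helper.create_train_model) will need additional fields
# beyond these.
example_hparams = tf.contrib.training.HParams(
    log_device_placement=False,
    out_dir="/tmp/seq2seq_model",
    num_train_steps=12000,
    steps_per_stats=100,
    num_intra_threads=0,  # 0 lets TensorFlow choose
    num_inter_threads=0,
    epoch_step=0,  # position within the current epoch, used on resume
)
# final_step = train(example_hparams)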
Example #28
0
def multi_worker_selfplay(hparams,
                          identity,
                          scope=None,
                          target_session='',
                          is_chief=True,
                          ps_tasks=0,
                          num_workers=1,
                          jobid=0,
                          startup_delay_steps=0):
    """This is the multi worker selfplay, mostly used for self play

  distributed training.
  identity is used.
  """
    immutable_model_reload_freq = hparams.immutable_model_reload_freq
    # 1. models and summary writer
    model_creator = diag_model.Model
    extra_args = model_helper.ExtraArgs(
        single_cell_fn=None,
        model_device_fn=tf.train.replica_device_setter(ps_tasks),
        attention_mechanism_fn=None)

    mutable_model = model_helper.create_selfplay_model(model_creator,
                                                       is_mutable=True,
                                                       num_workers=num_workers,
                                                       jobid=jobid,
                                                       hparams=hparams,
                                                       scope=scope,
                                                       extra_args=extra_args)
    immutable_hparams = copy.deepcopy(hparams)
    immutable_hparams.num_gpus = 0
    immutable_model = model_helper.create_selfplay_model(
        model_creator,
        is_mutable=False,
        num_workers=num_workers,
        jobid=jobid,
        hparams=immutable_hparams,
        scope=scope)

    if hparams.self_play_immutable_gpu:
        print('using GPU for immutable')
        immutable_sess = tf.Session(
            graph=immutable_model.graph,
            config=tf.ConfigProto(allow_soft_placement=True))
    else:
        print('not using GPU for immutable')
        immutable_sess = tf.Session(graph=immutable_model.graph,
                                    config=tf.ConfigProto(
                                        allow_soft_placement=True,
                                        device_count={'GPU': 0}))

    immutable_model, immutable_sess = load_self_play_model(
        immutable_model, immutable_sess, 'immutable',
        hparams.self_play_pretrain_dir, hparams.out_dir)
    global_step = immutable_model.model.global_step.eval(
        session=immutable_sess)

    if is_chief:
        ckpt = tf.train.latest_checkpoint(hparams.out_dir)
        if not ckpt:
            print('no checkpoint found; saving the pretrained model at step',
                  global_step, 'to', hparams.out_dir)
            immutable_model.model.saver.save(  # saving now avoids an Adam slot-variable restore error
                immutable_sess,
                os.path.join(hparams.out_dir, 'dialogue.ckpt'),
                global_step=global_step)
            print('save finished')

    if is_chief:
        summary_writer_path = os.path.join(
            hparams.out_dir, identity + task_SP_DISTRIBUTED + '_log')
        summary_writer = tf.summary.FileWriter(summary_writer_path,
                                               mutable_model.graph)
        print('summary writer established at', summary_writer_path)
    else:
        summary_writer = None
    # 2. supervisor and sessions

    sv = tf.train.Supervisor(
        graph=mutable_model.graph,
        is_chief=is_chief,
        saver=mutable_model.model.saver,
        save_model_secs=0,  # disable automatic save checkpoints
        summary_op=None,
        logdir=hparams.out_dir,
        checkpoint_basename='dialogue.ckpt')

    mutable_config = utils.get_config_proto(
        log_device_placement=hparams.log_device_placement,
        allow_soft_placement=True)
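    # device_count caps how many GPUs the mutable session may use.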
    mutable_config.device_count['GPU'] = hparams.num_gpus

    mutable_sess = sv.prepare_or_wait_for_session(target_session,
                                                  config=mutable_config)

    # 3. additional preparations
    global_step = mutable_model.model.global_step.eval(session=mutable_sess)
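    # Stagger worker startup: worker j idles until global_step reaches
    # j * (j + 1) / 2 * startup_delay_steps.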
    while global_step < (jobid * (jobid + 1) * startup_delay_steps / 2):
        time.sleep(1)
        global_step = mutable_model.model.global_step.eval(
            session=mutable_sess)

    # save first model
    if is_chief:
        print('saving the first checkpoint to', hparams.out_dir)
        mutable_model.model.saver.save(mutable_sess,
                                       os.path.join(hparams.out_dir,
                                                    'dialogue.ckpt'),
                                       global_step=global_step)
        last_save_step = global_step

    # Read data
    selfplay_data = dialogue_utils.load_data(hparams.self_play_train_data)
    selfplay_kb = dialogue_utils.load_data(hparams.self_play_train_kb)

    dialogue = SelfplayDialogue(mutable_model,
                                immutable_model,
                                mutable_sess,
                                immutable_sess,
                                hparams.max_dialogue_turns,
                                hparams.train_threadhold,
                                hparams.start_of_turn1,
                                hparams.start_of_turn2,
                                hparams.end_of_dialogue,
                                summary_writer=summary_writer,
                                dialogue_mode=task_SP_DISTRIBUTED,
                                hparams=hparams)

    # 4. main loop
    last_immutable_model_reload = global_step
    last_save_step = global_step
    batch_size = dialogue.batch_size
    assert batch_size <= len(selfplay_data)

    # Start index into the self-play data; set past the end to force a shuffle
    # on the first iteration.
    i = len(selfplay_data)
    train_stats = [0, 0]
    while global_step < hparams.num_self_play_train_steps:
        # a. Reload the immutable model; the mutable model is managed
        # automatically by the supervisor.
        if immutable_model_reload_freq > 0 and global_step - last_immutable_model_reload > immutable_model_reload_freq:
            immutable_model, immutable_sess = load_self_play_model(
                immutable_model, immutable_sess, 'immutable',
                hparams.self_play_pretrain_dir, hparams.out_dir)
            last_immutable_model_reload = global_step
        # b. Possibly flip between speakers (or rolled-out models),
        # based on either a random policy or on step counts.
        agent1, agent2, mutable_agent_index = dialogue.flip_agent(
            (mutable_model, mutable_sess, dialogue.mutable_handles),
            (immutable_model, immutable_sess, dialogue.immutable_handles))
        train_stats[mutable_agent_index] += 1
        # read selfplay data
        start_time = time.time()
        if i * batch_size + batch_size > len(selfplay_data):  # reached the end
            input_data = list(zip(selfplay_data, selfplay_kb))
            random.shuffle(input_data)  # reshuffle the input data
            i = 0
            selfplay_data, selfplay_kb = zip(*input_data)

        start_ind, end_ind = i * batch_size, i * batch_size + batch_size
        batch_data, batch_kb = selfplay_data[start_ind:end_ind], selfplay_kb[
            start_ind:end_ind]
        train_example, _, _ = dialogue.talk(hparams.max_dialogue_len,
                                            batch_data, batch_kb, agent1,
                                            agent2, batch_size, global_step)
        possible_global_step = dialogue.maybe_train(train_example,
                                                    mutable_agent_index,
                                                    global_step,
                                                    force=True)
        if possible_global_step:
            global_step = possible_global_step
        if is_chief and global_step - last_save_step > hparams.self_play_dist_save_freq:
            mutable_model.model.saver.save(mutable_sess,
                                           os.path.join(
                                               hparams.out_dir,
                                               'dialogue.ckpt'),
                                           global_step=global_step)
            last_save_step = global_step
        end_time = time.time()

        if is_chief:
            utils.add_summary(summary_writer, global_step,
                              task_SP_DISTRIBUTED + '_' + 'time',
                              end_time - start_time)
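            # Ratio of the mutable model playing side 0 vs side 1;
            # the +0.1 guards against division by zero.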
            utils.add_summary(summary_writer, global_step,
                              task_SP_DISTRIBUTED + '_' + 'train_ratio',
                              train_stats[0] * 1.0 / (train_stats[1] + 0.1))
        i += 1

    if is_chief:
        summary_writer.close()

    mutable_sess.close()
    immutable_sess.close()
Example #29
0
def eval_fn(hparams, scope=None, target_session=""):
    """Evaluate a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    # Evaluate every checkpoint recorded in the checkpoint state.
    ckpt_size = len(
        tf.train.get_checkpoint_state(model_dir).all_model_checkpoint_paths)
    for ckpt_index in range(ckpt_size):
        train.run_full_eval(model_dir,
                            infer_model,
                            infer_sess,
                            eval_model,
                            eval_sess,
                            hparams,
                            None,
                            sample_src_data,
                            sample_tgt_data,
                            avg_ckpts,
                            ckpt_index=ckpt_index)
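
# Note: ckpt_index presumably selects the corresponding entry in
# all_model_checkpoint_paths, so the loop above scores every retained
# checkpoint rather than only the latest one.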
Example #30
0
def train(hparams, scope=None, target_session=''):
    """Train the chatbot"""
    # Initialize some local hyperparameters
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator
        get_iterator = iterator_utils.get_iterator
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel
        # Parse some of the arguments now
        def curry_get_infer_iterator(dataset, vocab_table, batch_size,
                                     src_reverse, eos, src_max_len):
            return end2end_iterator_utils.get_infer_iterator(
                dataset, vocab_table, batch_size, src_reverse, eos,
                src_max_len=src_max_len, eou=hparams.eou,
                dialogue_max_len=hparams.dialogue_max_len)
        get_infer_iterator = curry_get_infer_iterator

        def curry_get_iterator(src_dataset,
                 tgt_dataset,
                 vocab_table,
                 batch_size,
                 sos,
                 eos,
                 src_reverse,
                 random_seed,
                 num_buckets,
                 src_max_len=None,
                 tgt_max_len=None,
                 num_threads=4,
                 output_buffer_size=None,
                 skip_count=None):
            return end2end_iterator_utils.get_iterator(src_dataset, tgt_dataset, vocab_table, batch_size, sos, eos,
                                                eou=hparams.eou, src_reverse=src_reverse, random_seed=random_seed,
                                                num_dialogue_buckets=num_buckets, src_max_len=src_max_len,
                                                tgt_max_len=tgt_max_len, num_threads=num_threads,
                                                output_buffer_size=output_buffer_size, skip_count=skip_count)

        get_iterator = curry_get_iterator
    else:
        raise ValueError("Unkown architecture", hparams.architecture)

    # Create three models which share parameters through the use of checkpoints
    train_model = create_train_model(model_creator, get_iterator, hparams, scope)
    eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
    infer_model = inference.create_infer_model(model_creator, get_infer_iterator, hparams, scope)
    # TODO: adapt this for the other architectures
    # Preload the data to use for sample decoding

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # Create the configurations for the sessions
    config_proto = utils.get_config_proto(log_device_placement=log_device_placement)
    # Create three sessions, one for each model
    train_sess = tf.Session(target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto, graph=infer_model.graph)

    # Load the train model from checkpoint or create a new one
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(train_model.model, model_dir,
                                                                            train_sess, name="train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)
    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    # Initialize the hyperparameters for the loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # epoch_step records where we were within an epoch; used to skip already-seen examples.
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    # Initialize the training iterator
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    # Train until we reach num_steps.
    while global_step < num_train_steps:
        # Run a step
        start_step_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,  # The _ is the output of the update op
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # Decode and print a random sentence
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Perform external evaluation to save checkpoints if this is the best for some metric
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)
            # Reinitialize the iterator from the beginning
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_step_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
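            # checkpoint_total_count is in words, so speed is thousands of words/sec.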
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # Training has diverged (perplexity is NaN).
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            # Perform evaluation. Start by reassigning the last_eval_step variable to the current step
            last_eval_step = global_step
            # Print the progress and add summary
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run internal evaluation and update the ppl variables. The data iterator is instantiated inside the method.
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            # Run the external evaluation
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run external evaluation, updating metric scores along the way; the ignored output is the global step.
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)

    # Done training. Save the model
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "chatbot.ckpt"),
        global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out("# Best %s, step %d "
                        "step-time %.2f wps %.2fK, %s, %s" %
                        (metric, best_global_step, avg_step_time, speed,
                         result_summary, time.ctime()), log_f)

    summary_writer.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)