def decode_and_evaluate(name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        subword_option,
                        beam_width,
                        tgt_eos,
                        num_translations_per_input=1,
                        decode=True):
  """Decode a test set and compute a score according to the evaluation task."""
  # Decode
  if decode:
    utils.print_out("  decoding to output %s." % trans_file)

    start_time = time.time()
    num_sentences = 0
    with codecs.getwriter("utf-8")(
        tf.gfile.GFile(trans_file, mode="wb")) as trans_f:
      trans_f.write("")  # Write empty string to ensure file is created.

      num_translations_per_input = max(
          min(num_translations_per_input, beam_width), 1)
      while True:
        try:
          nmt_outputs, _ = model.decode(sess)
          if beam_width == 0:
            nmt_outputs = np.expand_dims(nmt_outputs, 0)

          batch_size = nmt_outputs.shape[1]
          num_sentences += batch_size

          for sent_id in range(batch_size):
            for beam_id in range(num_translations_per_input):
              translation = get_translation(
                  nmt_outputs[beam_id],
                  sent_id,
                  tgt_eos=tgt_eos,
                  subword_option=subword_option)
              trans_f.write((translation + b"\n").decode("utf-8"))
        except tf.errors.OutOfRangeError:
          utils.print_time(
              "  done, num sentences %d, num translations per input %d" %
              (num_sentences, num_translations_per_input), start_time)
          break

  # Evaluation
  evaluation_scores = {}
  if ref_file and tf.gfile.Exists(trans_file):
    for metric in metrics:
      score = evaluation_utils.evaluate(
          ref_file,
          trans_file,
          metric,
          subword_option=subword_option)
      evaluation_scores[metric] = score
      utils.print_out("  %s %s: %.1f" % (metric, name, score))

  return evaluation_scores
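A minimal usage sketch for the helper above, assuming an already-built inference model, session, and hyperparameter object; the wrapper name and the hparams attributes below are hypothetical stand-ins, not part of the snippet itself.

import os

def run_external_eval(infer_model, infer_sess, hparams, label="test"):
  # Hypothetical driver: decode the test set once and score it with the
  # metrics configured in hparams (e.g. ["bleu"]).
  trans_file = os.path.join(hparams.out_dir, "output_%s" % label)
  return decode_and_evaluate(
      label,
      infer_model,
      infer_sess,
      trans_file,
      ref_file=hparams.test_tgt_file,
      metrics=hparams.metrics,
      subword_option=hparams.subword_option,
      beam_width=hparams.beam_width,
      tgt_eos=hparams.eos,
      num_translations_per_input=hparams.num_translations_per_input)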
Example #2
def do_evaluate(arg):
    test_data = arg.test_data
    ground_truth_file = arg.ground_truth
    trans_file = arg.trans_file
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        test_gen = generate_batch(test_data, batch_size=batch_size)
        # Start each evaluation from empty output files; remove only the
        # files that actually exist.
        if os.path.isfile(ground_truth_file):
            os.remove(ground_truth_file)
        if os.path.isfile(trans_file):
            os.remove(trans_file)
        f1 = open(ground_truth_file, 'a')
        f2 = open(trans_file, 'a')
        for j in range(int(128 / batch_size)):
            encoder_inputs_test, decoder_inputs_test, _, encoder_length_test, _ = next(
                test_gen)
            test_out = sess.run(predict_out_id,
                                feed_dict={
                                    encoder_inputs_x: encoder_inputs_test,
                                    source_sequence_length: encoder_length_test
                                })
            test_out_no1 = test_out[:, 0, :]
            ground_truth = id2words(decoder_inputs_test)
            sentences = id2words(test_out_no1)
            for line1 in ground_truth:
                f1.write(line1 + '\n')
            for line2 in sentences:
                f2.write(line2 + '\n')
        f1.close()
        f2.close()
        bleu_score = evaluate(ground_truth_file, trans_file, 'bleu')
        logger.info('evaluation done! out_path:{}'.format(trans_file) +
                    "\tbleu_score:{:.5f}".format(bleu_score))
Example #3
  def testEvaluate(self):
    output = "testdata/deen_output"
    ref_bpe = "testdata/deen_ref_bpe"
    ref_spm = "testdata/deen_ref_spm"
    chinese = "testdata/cutdev128.zh"
    chinese2 = "testdata/dev2.zh"
    expected_bleu_score = 22.5855084573
    expected_rouge_score = 50.8429782599

    bpe_bleu_score = evaluation_utils.evaluate(
        ref_bpe, output, "bleu", "bpe")
    bpe_rouge_score = evaluation_utils.evaluate(
        ref_bpe, output, "rouge", "bpe")
    self.assertAlmostEqual(expected_bleu_score, bpe_bleu_score)
    self.assertAlmostEqual(expected_rouge_score, bpe_rouge_score)

    bleu_score = evaluation_utils.evaluate(
        chinese, chinese2, "bleu")
    print(bleu_score)

  def testAccuracy(self):
    pred_output = "nmt/testdata/pred_output"
    label_ref = "nmt/testdata/label_ref"

    expected_accuracy_score = 60.00

    accuracy_score = evaluation_utils.evaluate(label_ref, pred_output,
                                               "accuracy")
    self.assertAlmostEqual(expected_accuracy_score, accuracy_score)
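A standalone sketch of the call these tests exercise, assuming two plain-text files with one sentence per line; the helper name and file paths are placeholders.

def score_files(ref_file, trans_file):
  # Hypothetical helper: score a hypothesis file against a reference file
  # with no subword postprocessing.
  return {metric: evaluation_utils.evaluate(ref_file, trans_file, metric)
          for metric in ("bleu", "rouge")}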
Example #5
def decode_and_evaluate(name,
                        model,
                        sess,
                        trans_file,
                        ref_file,
                        metrics,
                        bpe_delimiter,
                        beam_width,
                        tgt_sos,
                        tgt_eos,
                        decode=True):
    """Decode a test set and compute a score according to the evaluation task."""
    # Decode
    if decode:
        utils.print_out("  decoding to output %s." % trans_file)

        start_time = time.time()
        num_sentences = 0
        with codecs.getwriter("utf-8")(tf.gfile.GFile(trans_file,
                                                      mode="wb")) as trans_f:
            trans_f.write("")  # Write empty string to ensure file is created.

            while True:
                try:
                    nmt_outputs, _, _ = model.decode(sess)

                    if beam_width > 0:
                        # get the top translation.
                        nmt_outputs = nmt_outputs[0]

                    num_sentences += len(nmt_outputs)
                    for sent_id in range(len(nmt_outputs)):
                        translation = get_translation(
                            nmt_outputs,
                            sent_id,
                            tgt_sos=tgt_sos,
                            tgt_eos=tgt_eos,
                            bpe_delimiter=bpe_delimiter)
                        trans_f.write((translation + b"\n").decode("utf-8"))
                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  done, num sentences %d" % num_sentences, start_time)
                    break

    # Evaluation
    evaluation_scores = {}
    if ref_file and tf.gfile.Exists(trans_file):
        for metric in metrics:
            score = evaluation_utils.evaluate(ref_file,
                                              trans_file,
                                              metric,
                                              bpe_delimiter=bpe_delimiter)
            evaluation_scores[metric] = score
            utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores
def decode_and_evaluate(name,
                        model,
                        sess,
                        output_file,
                        reference_file,
                        metrics,
                        bpe_delimiter,
                        beam_width,
                        eos,
                        number_token=None,
                        name_token=None,
                        decode=True):
    """Decode a test set and compute a score according to the evaluation task."""
    # Decode
    if decode:
        utils.print_out("  decoding to output %s." % output_file)
        start_time = time.time()
        num_sentences = 0
        with tf.gfile.GFile(output_file, mode="w+") as out_f:
            out_f.write("")  # Write empty string to ensure file is created.

            while True:
                try:
                    # Get the response(s) for each input in the batch (whole file in this case)
                    # ToDo: adapt for architectures
                    outputs, infer_summary = model.decode(sess)

                    if beam_width > 0:
                        # Get the top response if we used beam_search
                        outputs = outputs[0]

                    num_sentences += len(outputs)
                    # Iterate over the outputs an write them to file
                    for sent_id in range(len(outputs)):
                        response = postprocess_output(outputs, sent_id, eos,
                                                      bpe_delimiter,
                                                      number_token, name_token)
                        out_f.write("%s\n" % response)
                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  done, num sentences %d" % num_sentences, start_time)
                    break

    # Evaluation
    evaluation_scores = {}
    if reference_file and tf.gfile.Exists(output_file):
        for metric in metrics:
            score = evaluation_utils.evaluate(ref_file=reference_file,
                                              trans_file=output_file,
                                              metric=metric,
                                              bpe_delimiter=bpe_delimiter)
            evaluation_scores[metric] = score
            utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores
Example #7
def run_main(flags,
             default_hparams,
             train_fn,
             inference_fn,
             target_session=""):
    """Run main."""
    # Job
    jobid = flags.jobid
    num_workers = flags.num_workers
    utils.print_out("# Job id %d" % jobid)

    # Random
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        utils.print_out("# Set random seed to %d" % random_seed)
        random.seed(random_seed + jobid)
        np.random.seed(random_seed + jobid)

    ## Train / Decode
    out_dir = flags.out_dir
    if not tf.gfile.Exists(out_dir): tf.gfile.MakeDirs(out_dir)

    # Load hparams.
    hparams = create_or_load_hparams(out_dir, default_hparams,
                                     flags.hparams_path)

    if flags.inference_input_file:
        # Inference indices
        hparams.inference_indices = None
        if flags.inference_list:
            (hparams.inference_indices) = ([
                int(token) for token in flags.inference_list.split(",")
            ])

        # Inference
        trans_file = flags.inference_output_file
        ckpt = flags.ckpt
        if not ckpt:
            ckpt = tf.train.latest_checkpoint(out_dir)
        inference_fn(ckpt, flags.inference_input_file, trans_file, hparams,
                     num_workers, jobid)

        # Evaluation
        ref_file = flags.inference_ref_file
        if ref_file and tf.gfile.Exists(trans_file):
            for metric in hparams.metrics:
                score = evaluation_utils.evaluate(ref_file, trans_file, metric,
                                                  hparams.bpe_delimiter)
                utils.print_out("  %s: %.1f" % (metric, score))
    else:
        # Train
        train_fn(hparams, target_session=target_session)
    def testEvaluate(self):
        output = "nmt/testdata/deen_output"
        ref_bpe = "nmt/testdata/deen_ref_bpe"
        ref_spm = "nmt/testdata/deen_ref_spm"

        expected_bleu_score = 22.5855084573
        expected_rouge_score = 50.8429782599

        bpe_bleu_score = evaluation_utils.evaluate(ref_bpe, output, "bleu",
                                                   "bpe")
        bpe_rouge_score = evaluation_utils.evaluate(ref_bpe, output, "rouge",
                                                    "bpe")

        self.assertAlmostEqual(expected_bleu_score, bpe_bleu_score)
        self.assertAlmostEqual(expected_rouge_score, bpe_rouge_score)

        spm_bleu_score = evaluation_utils.evaluate(ref_spm, output, "bleu",
                                                   "spm")
        spm_rouge_score = evaluation_utils.evaluate(ref_spm, output, "rouge",
                                                    "spm")

        self.assertAlmostEqual(expected_rouge_score, spm_rouge_score)
        self.assertAlmostEqual(expected_bleu_score, spm_bleu_score)
def decode_and_evaluate(mode,
                        sess,
                        out_tensor,
                        trans_file,
                        ref_file,
                        metric='bleu',
                        beam_width=10,
                        num_translations_per_input=1,
                        iterations=1):
    """Decode a test set and compute a score according to the evaluation task."""
    utils.print_out("  Decoding to output %s" % trans_file)

    with codecs.getwriter("utf-8")(tf.io.gfile.GFile(trans_file,
                                                     mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        num_translations_per_input = min(num_translations_per_input,
                                         beam_width)

        print("  Running inference with beam_width %g, num translations per input %d. " \
              % (beam_width, num_translations_per_input))
        print("  Total iterations count %d." % iterations)

        # Warmup for the first batch to take out the very first runtime
        # session overhead.
        nmt_outputs = sess.run(out_tensor)  # time x batch_size x beam_width
        nmt_outputs = nmt_outputs.transpose()  # beam_width x batch_size x time
        batch_size = nmt_outputs.shape[1]
        for sent_id in range(batch_size):
            translation = get_translation(nmt_outputs[0],
                                          sent_id,
                                          tgt_eos='</s>')
            if mode == 'accuracy':
                trans_f.write((translation + b"\n").decode("utf-8"))

        # prediction time is the time for the model prediction only
        # overall time is the time for data pre-processing and data post-processing
        prediction_times = list()
        overall_start = time.time()
        num_sentences = 0
        n = 0
        while n < iterations:
            n += 1
            while True:
                try:
                    start = time.time()
                    nmt_outputs = sess.run(
                        out_tensor)  # time x batch_size x beam_width
                    nmt_outputs = nmt_outputs.transpose(
                    )  # beam_width x batch_size x time
                    prediction_times.append(time.time() - start)
                    batch_size = nmt_outputs.shape[1]
                    num_sentences += batch_size
                    for sent_id in range(batch_size):
                        for beam_id in range(num_translations_per_input):
                            translation = get_translation(nmt_outputs[beam_id],
                                                          sent_id,
                                                          tgt_eos='</s>')
                            if mode == 'accuracy':
                                trans_f.write(
                                    (translation + b"\n").decode("utf-8"))

                except tf.errors.OutOfRangeError:
                    utils.print_time(
                        "  Done, num sentences %d, num translations per input %d"
                        % (num_sentences, num_translations_per_input),
                        overall_start)
                    break

    overall_time = (time.time() - overall_start)
    print("\nAverage Prediction Latency: {:.5f} sec per batch.".format(
        sum(prediction_times) / float(len(prediction_times))))
    print("Overall Latency: {:.5f} sec for the entire test "
          "dataset.".format(overall_time / float(iterations)))
    print("Overall Throughput : {:.3f} sentences per sec.".format(
        num_sentences / float(overall_time)))

    # Evaluation
    if mode == 'accuracy':
        if ref_file and tf.io.gfile.exists(trans_file):
            score = evaluation_utils.evaluate(ref_file, trans_file, metric)
            utils.print_out("  Accuracy metric %s: %.1f" % (metric, score))
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('base_directory')
    parser.add_argument('--count-params-only',
                        default=False,
                        action='store_true')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    recurrent_model = nn.GRU(input_size=2,
                             hidden_size=N_HIDDEN,
                             num_layers=1,
                             bias=True,
                             batch_first=True,
                             dropout=0.0,
                             bidirectional=False)
    recurrent_model.to(device)

    output_model = nn.Sequential(
        nn.ReLU(),
        nn.Linear(N_HIDDEN, 3, bias=True),
    )
    output_model.to(device)

    parameters = []
    for p in recurrent_model.parameters():
        parameters.append(p)

    for p in output_model.parameters():
        parameters.append(p)

    if args.count_params_only:
        n_total = 0
        for p in parameters:
            n_total += np.prod(list(p.size()))
        print('n_total', n_total)
        exit()

    optimizer = torch.optim.Adam(parameters, lr=1e-3, betas=(0.9, 0.999))

    base_directory = args.base_directory
    train_loader = get_data_loader(os.path.join(base_directory, 'train'),
                                   'RandomSampler')

    # we have to decode the individual notes for individual pieces separately, of course ...
    valid_sequences = get_dataset_individually(
        os.path.join(base_directory, 'valid'))

    valid_loaders = []
    for sequence in valid_sequences:
        loader = DataLoader(sequence,
                            batch_size=1,
                            sampler=SequentialSampler(sequence),
                            drop_last=False)
        valid_loaders.append((sequence.midifilename, loader))

    print('len(train_loader)', len(train_loader))

    log_dir = 'runs/rnn_gru_maps_spec2labels_swd'
    logger = SummaryWriter(log_dir=log_dir)

    best_f = -np.inf
    global_step = 0
    for i_epoch in range(100):
        print('i_epoch', i_epoch)
        global_step = train(logger, device, recurrent_model, output_model,
                            optimizer, train_loader, global_step)
        to_log = evaluate(logger, 'valid', device, recurrent_model,
                          output_model, valid_loaders, global_step)

        model_state = dict(recurrent_model=recurrent_model.state_dict(),
                           output_model=output_model.state_dict())
        torch.save(model_state,
                   os.path.join(log_dir, 'model_state_{}.pkl'.format(i_epoch)))

        if best_f < to_log['valid_prf/f']:
            best_f = to_log['valid_prf/f']
            torch.save(
                model_state,
                os.path.join(log_dir, 'model_state_best.pkl'))
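The train and evaluate helpers are defined elsewhere; under the assumption that the GRU consumes (batch, time, 2) features and the head emits 3 logits per frame, a forward pass presumably looks roughly like this hypothetical sketch.

def forward(recurrent_model, output_model, batch):
    # batch: float tensor of shape (batch, time, 2); with batch_first=True the
    # GRU returns hidden states of shape (batch, time, N_HIDDEN), which the
    # ReLU + Linear head maps to per-frame logits of shape (batch, time, 3).
    hidden, _ = recurrent_model(batch)
    return output_model(hidden)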
Example #11
def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""):
  """Run main."""
  # Job
  jobid = flags.jobid
  num_workers = flags.num_workers
  utils.print_out("# Job id %d" % jobid)

  # GPU device
  utils.print_out(
      "# Devices visible to TensorFlow: %s" % repr(tf.Session().list_devices()))

  # Random
  random_seed = flags.random_seed
  if random_seed is not None and random_seed > 0:
    utils.print_out("# Set random seed to %d" % random_seed)
    random.seed(random_seed + jobid)
    np.random.seed(random_seed + jobid)

  # Model output directory
  out_dir = flags.out_dir
  if out_dir and not tf.gfile.Exists(out_dir):
    utils.print_out("# Creating output directory %s ..." % out_dir)
    tf.gfile.MakeDirs(out_dir)

  # Load hparams.
  loaded_hparams = False
  if flags.ckpt:  # Try to load hparams from the same directory as ckpt
    ckpt_dir = os.path.dirname(flags.ckpt)
    ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
    if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
      hparams = create_or_load_hparams(
          ckpt_dir, default_hparams, flags.hparams_path,
          save_hparams=False)
      loaded_hparams = True
  if not loaded_hparams:  # Try to load from out_dir
    assert out_dir
    hparams = create_or_load_hparams(
        out_dir, default_hparams, flags.hparams_path,
        save_hparams=(jobid == 0))
  
  ## Train / Decode
  if flags.inference_input_file:
    # Inference output directory
    trans_file = flags.inference_output_file
    assert trans_file
    trans_dir = os.path.dirname(trans_file)
    if not tf.gfile.Exists(trans_dir): tf.gfile.MakeDirs(trans_dir)

    # Inference indices
    hparams.inference_indices = None
    if flags.inference_list:
      (hparams.inference_indices) = (
          [int(token)  for token in flags.inference_list.split(",")])

    # Inference
    ckpt = flags.ckpt
    if not ckpt:
      ckpt = tf.train.latest_checkpoint(out_dir)
    inference_fn(ckpt, flags.inference_input_file,
                 trans_file, hparams, num_workers, jobid)

    # Evaluation
    ref_file = flags.inference_ref_file
    if ref_file and tf.gfile.Exists(trans_file):
      for metric in hparams.metrics:
        score = evaluation_utils.evaluate(
            ref_file,
            trans_file,
            metric,
            hparams.subword_option)
        utils.print_out("  %s: %.1f" % (metric, score))
  else:
    # Train
    train_fn(hparams, target_session=target_session)
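In the upstream NMT tutorial, run_main is usually driven by an entry point along these lines; add_arguments, create_hparams, train.train and inference.inference come from that codebase and are assumptions here, as are the argparse and sys imports.

def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)


if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)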
Example #12
def decode_and_evaluate(name,
                        model,
                        sess,
                        slot_trans_file,
                        intent_trans_file,
                        ref_file,
                        ref_lbl_file,
                        metrics,
                        subword_option,
                        beam_width,
                        tgt_eos,
                        task,
                        num_translations_per_input=1,
                        decode=True,
                        infer_mode="greedy"):
    """Decode a test set and compute a score according to the evaluation task."""
    # Decode
    if decode:
        utils.print_out("  decoding to output %s and %s" %
                        (slot_trans_file, intent_trans_file))

        start_time = time.time()
        num_sentences = 0
        with codecs.getwriter("utf-8")(tf.gfile.GFile(slot_trans_file,
                                                      mode="wb")) as trans_f:
            with codecs.getwriter("utf-8")(tf.gfile.GFile(
                    intent_trans_file, mode="wb")) as trans_intent_f:

                trans_f.write("")  # Write empty string to ensure file is created.

                if infer_mode == "greedy":
                    num_translations_per_input = 1
                elif infer_mode == "beam_search":
                    num_translations_per_input = min(
                        num_translations_per_input, beam_width)

                while True:
                    try:
                        if task == "joint":
                            nmt_outputs, intent_pred, src_seq_length, _ = model.decode(
                                sess)
                            if infer_mode != "beam_search":
                                nmt_outputs = np.expand_dims(nmt_outputs, 0)

                            batch_size = nmt_outputs.shape[1]
                        elif task == "intent":
                            intent_pred, _ = model.decode(sess)
                            batch_size = len(intent_pred)

                        num_sentences += batch_size

                        for sent_id in range(batch_size):
                            if task == "intent":
                                trans_intent_f.write((intent_pred[sent_id] +
                                                      b"\n").decode("utf-8"))
                            if task == "joint":
                                for beam_id in range(
                                        num_translations_per_input):
                                    translation = get_translation(
                                        nmt_outputs[beam_id],
                                        src_seq_length,
                                        sent_id,
                                        tgt_eos=tgt_eos,
                                        subword_option=subword_option)
                                    trans_f.write(
                                        (translation + b"\n").decode("utf-8"))

                    except tf.errors.OutOfRangeError:
                        utils.print_time(
                            "  done, num sentences %d, num translations per input %d"
                            % (num_sentences, num_translations_per_input),
                            start_time)
                        break

    # Evaluation
    evaluation_scores = {}
    if task == "joint":
        if ref_file and tf.gfile.Exists(slot_trans_file) and tf.gfile.Exists(
                intent_trans_file):
            for metric in metrics:
                if metric == "f1":
                    score = evaluation_utils.evaluate(
                        ref_file,
                        slot_trans_file,
                        metric,
                        subword_option=subword_option)
                    evaluation_scores[metric] = score
                    utils.print_out("  %s %s: %.1f" % (metric, name, score))
                elif metric == "accuracy":
                    score = evaluation_utils.evaluate(ref_lbl_file,
                                                      intent_trans_file,
                                                      metric)
                    evaluation_scores[metric] = score
                    utils.print_out("  %s %s: %.1f" % (metric, name, score))
    elif task == "intent":
        if ref_lbl_file and tf.gfile.Exists(intent_trans_file):
            for metric in metrics:
                score = evaluation_utils.evaluate(ref_lbl_file,
                                                  intent_trans_file, metric)
                evaluation_scores[metric] = score
                utils.print_out("  %s %s: %.1f" % (metric, name, score))

    return evaluation_scores
Example #13
                    feed_dict={
                        encoder_inputs_x: encoder_inputs_dev,
                        decoder_inputs_y1: decoder_inputs_dev,
                        decoder_outputs_y2: decoder_outputs_dev,
                        source_sequence_length: encoder_length_dev,
                        target_sequence_length: decoder_length_dev
                    })
                ground_truth = id2words(decoder_outputs_batch)
                sentences = id2words(dev_out)
                for line1 in ground_truth:
                    f1.write(line1 + '\n')
                for line2 in sentences:
                    f2.write(line2 + '\n')
            f1.close()
            f2.close()
            bleu_score = evaluate(ground_truth_file, trans_file, 'bleu')
            print(bleu_score)
            # print(evaluate)
            print("eval" + str(epoch) + ':' + str(j) + str(loss))

    saver.save(sess, checkpoint)
    print('Model Trained and Saved')

# with tf.Session() as sess:
#     saver.restore(sess, checkpoint)
#     words = [10,22,34]
#     inputs = np.tile(words,(batch_size,1))
#     acc_test = predicting_logits.eval(feed_dict={encoder_inputs_x: inputs,
#                                                  source_sequence_length: [len(words)]*batch_size, target_sequence_length: [len(words)]*batch_size})[0]
#     print(acc_test)
def main(flags, unused_argv):
    # Get the random seed and seed numpy and random with it
    random_seed = flags.random_seed
    if random_seed is not None and random_seed > 0:
        if flags.verbose:
            utils.print_out("# Set random seed to %d" % random_seed)
        random.seed(random_seed)
        np.random.seed(random_seed)

    # Create the output directory
    out_dir = flags.out_dir
    if not tf.gfile.Exists(out_dir):
        tf.gfile.MakeDirs(out_dir)

    # Load the hyperparameters
    default_hparams = argument_parser.create_hparams(flags)
    hparams = argument_parser.create_or_load_hparams(out_dir, default_hparams, flags)


    # The place where we decide if we train or if we do inference
    # ToDo: Add ability to chat based on the chat argument
    if flags.chat:
        chat_logs_output_file = flags.chat_logs_output_file
        ckpt = flags.ckpt
        if not ckpt:
            # If a checkpoint has not been provided then load the latest one
            ckpt = tf.train.latest_checkpoint(out_dir)
        # Initiate chat mode
        inference.chat(checkpoint=ckpt, chat_logs_output_file=chat_logs_output_file, hparams=hparams)

    elif flags.inference_input_file:
        # Inference indices
        hparams.inference_indices = None
        if flags.inference_list:
            (hparams.inference_indices) = (
                [int(token) for token in flags.inference_list.split(",")])

        # Inference
        inference_output_file = flags.inference_output_file
        ckpt = flags.ckpt
        if not ckpt:
            # If a checkpoint has not been provided then load the latest one
            ckpt = tf.train.latest_checkpoint(out_dir)
        # Get responses to the utterances and write them to file
        inference.inference(ckpt, flags.inference_input_file, inference_output_file, hparams)

        # Compute scores for the reference file
        ref_file = flags.inference_ref_file
        if ref_file and tf.gfile.Exists(inference_output_file):
            for metric in hparams.metrics:
                score = evaluation_utils.evaluate(
                    ref_file,
                    inference_output_file,
                    metric,
                    hparams.bpe_delimiter)
                if flags.verbose:
                    utils.print_out("  %s: %.1f" % (metric, score))

    else:
        # Start training
        train.train(hparams)
def infer_main(flags, default_hparams, inference_fn, target_session=""):
  """Run main."""
  # Job
  jobid = flags.jobid
  num_workers = flags.num_workers
  utils.print_out("# Job id %d" % jobid)

  # Random
  random_seed = flags.random_seed
  if random_seed is not None and random_seed > 0:
    utils.print_out("# Set random seed to %d" % random_seed)
    random.seed(random_seed + jobid)
    np.random.seed(random_seed + jobid)

  ## Train / Decode
  # Read the model output directory from the first line of the
  # out_model_info file.
  out_model_file = flags.out_model_info
  with open(out_model_file, 'r') as infile:
    out_dir = infile.readline().rstrip('\r\n')

  flags.out_dir = out_dir
  default_hparams.out_dir = out_dir

  hparams = load_hparams_Alveo(out_dir, default_hparams, flags.hparams_path,
                               save_hparams=True)

  # Debug output for a few of the loaded hyperparameters.
  print(hparams.num_units)
  print(hparams.dropout)
  print(hparams.attention)
  print(hparams.train_src)
  print(hparams.dev_src)
  print(hparams.out_model_info)
  print(hparams.out_dir)

  hparams.inference_indices = None
  if flags.inference_list:
    (hparams.inference_indices) = (
        [int(token) for token in flags.inference_list.split(",")])

  # Inference
  trans_file = flags.inference_output_file
  ckpt = flags.ckpt
  if not ckpt:
    ckpt = tf.train.latest_checkpoint(out_dir)
  inference_fn(ckpt, flags.inference_input_file,
               trans_file, hparams, num_workers, jobid)

  # Evaluation
  ref_file = flags.inference_ref_file
  if ref_file and tf.gfile.Exists(trans_file):
    for metric in hparams.metrics:
      score = evaluation_utils.evaluate(
          ref_file,
          trans_file,
          metric,
          hparams.subword_option)
      utils.print_out("  %s: %.1f" % (metric, score))