reranker.train(text_seqs, da_seqs, scores, log_probs, cfg["epoch"],
               cfg["valid_size"], cfg["num_ranks"],
               cfg.get("only_bottom", False),
               cfg.get("only_top", False),
               cfg.get("min_training_passes", 5))

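# Optionally report post-training statistics: regenerate the width-10 test
# beams if they are not cached, then compare them against the references.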
if cfg["show_reranker_post_training_stats"]:
    test_das = get_test_das()
    test_texts = get_true_sents()
    final_beam_path = TEST_BEAM_SAVE_FORMAT.format(10)

    if not os.path.exists(final_beam_path):
        print("Creating final beams file")
        models = TGEN_Model(da_embedder, text_embedder,
                            cfg['tgen_seq2seq_config'])
        models.load_models()
        scorer = get_score_function('identity', cfg, models, None, 10)
        run_beam_search_with_rescorer(scorer,
                                      models,
                                      test_das,
                                      10,
                                      only_rerank_final=True,
                                      save_final_beam_path=final_beam_path)

    bleu = BLEUScore()
    test_da_embs = da_embedder.get_embeddings(test_das)
    with open(final_beam_path, 'rb') as f:
        final_beam = pickle.load(f)
    all_reals = []
    all_preds = []
    for da_emb, beam, true in zip(test_da_embs, final_beam, test_texts):
        real_scores = []
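        # The rest of this loop was truncated in the original file; the lines
        # below are a minimal, hypothetical reconstruction. It only reports
        # the oracle BLEU of the saved beams (best hypothesis per beam,
        # assuming each element of test_texts is a list of references); the
        # original presumably also scored hypotheses with the trained
        # reranker using da_emb, which is omitted here.
        for path in beam:
            bleu.reset()
            hyp = [x for x in text_embedder.reverse_embedding(path[1])
                   if x not in [START_TOK, END_TOK, PAD_TOK]]
            bleu.append(hyp, true)
            real_scores.append((bleu.score(), hyp))
        all_preds.append(max(real_scores, key=lambda x: x[0])[1])
        all_reals.append(true)
    bleu.reset()
    for pred, refs in zip(all_preds, all_reals):
        bleu.append(pred, refs)
    print("Oracle BLEU over final beams:", bleu.score())
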
def get_scores_ordered_beam(cfg,
                            da_embedder,
                            text_embedder,
                            beam_save_path=None):
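    """Build reranker training data from the saved training beams.

    Every hypothesis on each final beam is scored with BLEU against the
    references, and the ranked hypotheses are then labelled according to
    cfg["output_type"] (raw BLEU, one-hot rank, relative rank, or rank
    section). Returns embedded text sequences, embedded DA sequences,
    target scores, and per-hypothesis log probabilities.
    """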
    print("Loading Training Data")
    beam_size = cfg["beam_size"]
    train_texts, train_das = get_multi_reference_training_variables()
    if beam_save_path is None:
        beam_save_path = TRAIN_BEAM_SAVE_FORMAT.format(
            beam_size, cfg["tgen_seq2seq_config"].split('.')[0].split('/')[-1])
    if not os.path.exists(beam_save_path):
        models = TGEN_Model(da_embedder, text_embedder,
                            cfg["tgen_seq2seq_config"])
        models.load_models()
        print("Creating test final beams")
        scorer = get_score_function('identity', cfg, models, None, beam_size)
        run_beam_search_with_rescorer(scorer,
                                      models,
                                      train_das,
                                      beam_size,
                                      cfg,
                                      only_rerank_final=True,
                                      save_final_beam_path=beam_save_path)
    bleu = BLEUScore()
    with open(beam_save_path, "rb") as f:
        final_beam = pickle.load(f)
    text_seqs = []
    da_seqs = []
    scores = []
    log_probs = []
    with_ref_train_flag = cfg["with_refs_train"]
    num_ranks = cfg["num_ranks"]
    cut_offs = get_section_cutoffs(num_ranks)
    regression_vals = get_regression_vals(num_ranks, with_ref_train_flag)
    if cfg["output_type"] != 'pair':
        print("Cut off values:", cut_offs)
        print("Regression vals:", regression_vals)

    only_top = cfg.get("only_top", False)
    only_bottom = cfg.get("only_bottom", False)
    merge_middles = cfg["merge_middle_sections"]
    if only_top:
        print("Only using top value")
    if merge_middles and only_top:
        print("Ignoring only_top since merge_middle_sections is set")
    training_vals = list(zip(final_beam, train_texts, train_das))
    training_vals = training_vals[:cfg.get("use_size", len(training_vals))]
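    # Score every hypothesis in each beam against its references with BLEU,
    # then turn the resulting ranking into training targets.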
    for beam, real_texts, da in tqdm(training_vals):
        beam_scores = []
        if with_ref_train_flag:
            # References are added with a target score of 0 (the top rank
            # under the rank-based output types). No log probs are appended
            # for them, so log_probs can end up shorter than text_seqs.
            text_seqs.extend(real_texts)
            da_seqs.extend([da for _ in real_texts])
            scores.extend([0 for _ in real_texts])

        for i, path in enumerate(beam):
            bleu.reset()
            hyp = [
                x for x in text_embedder.reverse_embedding(path[1])
                if x not in [START_TOK, END_TOK, PAD_TOK]
            ]
            # strip delimiter tokens from each reference before scoring
            refs = [[tok for tok in ref if tok not in [START_TOK, END_TOK]]
                    for ref in real_texts]
            bleu.append(hyp, refs)
            beam_scores.append((bleu.score(), hyp, path))

            # log_probs.append(i)

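        # Rank hypotheses by descending BLEU; the rank index i (not the raw
        # BLEU score) drives most of the target encodings below.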
        for i, (score, hyp, path) in enumerate(
                sorted(beam_scores, key=lambda x: x[0], reverse=True)):
            text_seqs.append([START_TOK] + hyp + [END_TOK])
            da_seqs.append(da)
            if cfg["output_type"] in ['bleu', 'pair']:
                scores.append(score)
            elif cfg["output_type"] == 'order_discrete':
                scores.append(to_categorical([i], num_classes=beam_size))
            elif cfg["output_type"] in [
                    'regression_ranker', 'regression_reranker_relative'
            ]:
                scores.append(i / (beam_size - 1))
            elif cfg["output_type"] in [
                    'regression_sections', 'binary_classif'
            ]:
                val = i / (beam_size - 1)
                # maps the [0, 1] rank position onto zero-mean section
                # targets in [-1, 1]
                regression_val = get_section_value(val, cut_offs,
                                                   regression_vals,
                                                   merge_middles, only_top,
                                                   only_bottom)
                scores.append(regression_val)
            else:
                raise ValueError(
                    "Unknown output type: {}".format(cfg["output_type"]))

            log_probs.append([path[0]])

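    # Convert the token and DA sequences into padded embedding matrices.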
    text_seqs = np.array(
        text_embedder.get_embeddings(text_seqs, pad_from_end=False))
    da_seqs = np.array(da_embedder.get_embeddings(da_seqs))

    if cfg["output_type"] in [
            'regression_ranker', 'bleu', 'regression_reranker_relative',
            'pair', 'regression_sections', 'binary_classif'
    ]:
        # print("SCORES: ", Counter(scores))
        scores = np.array(scores).reshape((-1, 1))
    elif cfg["output_type"] == 'order_discrete':
        scores = np.array(scores).reshape((-1, beam_size))

    log_probs = np.array(log_probs)
    return text_seqs, da_seqs, scores, log_probs
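
# If no config path was given, fall back to the most recently modified
# file in the model configs directory.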
if cfg_path is None:
    filenames = os.listdir(CONFIGS_MODEL_DIR)
    filepaths = [
        os.path.join(CONFIGS_MODEL_DIR, filename) for filename in filenames
    ]
    cfg_path = max(filepaths, key=os.path.getmtime)

with open(cfg_path, 'r') as f:
    cfg = yaml.safe_load(f)
texts, das = get_training_variables()
text_embedder = TokEmbeddingSeq2SeqExtractor(texts)
da_embedder = DAEmbeddingSeq2SeqExtractor(das)

texts_mr, da_mr = get_multi_reference_training_variables()
# train_text = np.array(text_embedder.get_embeddings(texts, pad_from_end=True) + [text_embedder.empty_embedding])
# da_embs = da_embedder.get_embeddings(das) + [da_embedder.empty_embedding]

seq2seq = TGEN_Model(da_embedder, text_embedder, cfg_path)
seq2seq.load_models()
seq2seq.full_model.summary()
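# Optionally train on only the first `use_prop` fraction of the data.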
if "use_prop" in cfg:
    da_mr = da_mr[:int(len(da_mr) * cfg['use_prop'])]
    texts_mr = texts_mr[:int(len(da_mr) * cfg['use_prop'])]
seq2seq.train(da_seq=da_mr,
              text_seq=texts_mr,
              n_epochs=cfg["epoch"],
              valid_size=cfg["valid_size"],
              early_stop_point=cfg["min_epoch"],
              minimum_stop_point=20,
              multi_ref=True)