Exemple #1
0
    def test_multibatch(self):
        """Compare the batched encoder call with ``naive_call`` per sequence.

        Several sequences of different lengths are encoded together through
        ``enc(src_batch, src_mask)``; each one is then encoded alone through
        ``enc.naive_call`` and every position must agree within ``atol=1e-6``.
        """
        Vi, Ei, Hi = 12, 17, 7
        enc = EncoderNaive(Vi, Ei, Hi)

        # Sequences of unequal length, so the batch needs a mask.
        src_data = [
            [2, 5, 0, 3],
            [2, 5, 4, 3, 0, 0, 1, 11, 3],
            [2, 5, 4, 3, 0, 11, 3],
            [5, 3, 0, 0, 1, 11, 3],
        ]
        src_batch, src_mask = utils.make_batch_src(src_data)
        fb = enc(src_batch, src_mask)

        for seq_idx, raw_s in enumerate(src_data):
            # One single-element int32 Variable per token for the naive path.
            input_seq = []
            for token in raw_s:
                input_seq.append(Variable(np.array([token], dtype=np.int32)))
            fb_naive = enc.naive_call(input_seq, None)

            for pos in range(len(raw_s)):
                batched_out = fb.data[seq_idx][pos]
                naive_out = fb_naive[pos].data[0]
                print("maxdiff:", np.max(np.abs(batched_out - naive_out)))
                assert np.allclose(batched_out, naive_out, atol=1e-6)
Exemple #2
0
    def test_multibatch(self):
        """Compare the batched attention module with its naive counterpart.

        The encoder output of a batch of unequal-length sequences is fed to
        ``attn_model``; for each sequence the context vector and attention
        weights must match ``attn_model.naive_call`` run on that sequence
        alone, and attention beyond the sequence length must be exactly zero.
        """
        Vi, Ei, Hi = 12, 17, 7
        enc = EncoderNaive(Vi, Ei, Hi)

        # The attention module is built with an input size of 2 * Hi.
        Hi_a, Ha, Ho = 2 * Hi, 19, 23
        attn_model = AttentionModuleNaive(Hi_a, Ha, Ho)

        src_data = [
            [2, 5, 0, 3],
            [2, 5, 4, 3, 0, 0, 1, 11, 3],
            [2, 5, 4, 3, 0, 11, 3],
            [5, 3, 0, 0, 1, 11, 3],
        ]
        src_batch, src_mask = utils.make_batch_src(src_data)
        fb = enc(src_batch, src_mask)
        compute_ctxt = attn_model(fb, src_mask)

        # One random state row per sequence of the batch.
        state_raw = np.random.randn(4, Ho).astype(np.float32)
        state = Variable(state_raw)

        ci, attn = compute_ctxt(state)

        for seq_idx, raw_s in enumerate(src_data):
            seq_len = len(raw_s)
            input_seq = []
            for token in raw_s:
                input_seq.append(Variable(np.array([token], dtype=np.int32)))

            fb_naive = enc.naive_call(input_seq, None)
            compute_ctxt_naive = attn_model.naive_call(fb_naive, None)
            single_state = Variable(state_raw[seq_idx].reshape(1, -1))
            ci_naive, attn_naive = compute_ctxt_naive(single_state)

            # Context vectors must agree.
            print("maxdiff ci:",
                  np.max(np.abs(ci.data[seq_idx] - ci_naive.data[0])))
            assert np.allclose(ci.data[seq_idx], ci_naive.data[0], atol=1e-6)

            # Attention over the real positions must agree ...
            attn_valid = attn.data[seq_idx][:seq_len]
            print("maxdiff attn:",
                  np.max(np.abs(attn_valid - attn_naive.data[0])))
            assert np.allclose(attn_valid, attn_naive.data[0], atol=1e-6)

            # ... and the padded positions must carry no attention at all.
            attn_padding = attn.data[seq_idx][seq_len:]
            assert np.all(attn_padding == 0)
Exemple #3
0
def do_eval(config_eval):
    """Run the evaluation job described by ``config_eval``.

    The action is selected by ``config_eval.method.mode``:

    * ``"translate"``: greedy translation of the source data, one
      translation per line written to ``dest_fn``.
    * ``"beam_search"`` / ``"astar_search"``: translation through
      ``translate_to_file_with_beam_search``. When
      ``config_eval.process.server`` is set, a server is started instead.
    * ``"eval_bleu"`` / ``"astar_eval_bleu"``: as above, followed by
      unknown-word replacement and, when a reference is available, BLEU
      computed before and after the replacement.
    * ``"translate_attn"``: greedy translation plus attention plots.
    * ``"align"``: attention-based alignment figures between the source and
      target data.
    * ``"score_nbest"``: scores the entries of an n-best file and writes
      ``"<source index> <score>"`` lines to ``dest_fn``.

    When ``dest_fn`` is not None, the configuration is saved to
    ``dest_fn + ".eval.init.config.json"`` before the run and, completed
    with a ``translation_infos`` section (files, BLEU, timings), to
    ``dest_fn + ".eval.config.json"`` after it.

    Args:
        config_eval: configuration object exposing the ``process``,
            ``method`` and ``output`` sections read below, as well as
            ``trained_model``, ``save_to`` and ``copy``.

    Returns:
        None.

    Raises:
        SystemExit: when the ``@eval@/`` destination prefix cannot be
            resolved, or when no source file can be found either in the
            arguments or in the configuration files.
    """
    src_fn = config_eval.process.src_fn
    tgt_fn = config_eval.output.tgt_fn
    mode = config_eval.method.mode
    gpu = config_eval.process.gpu
    dest_fn = config_eval.process.dest_fn
    mb_size = config_eval.process.mb_size
    nb_steps = config_eval.method.nb_steps
    nb_steps_ratio = config_eval.method.nb_steps_ratio
    max_nb_ex = config_eval.process.max_nb_ex
    nbest_to_rescore = config_eval.output.nbest_to_rescore
    nbest = config_eval.output.nbest

    beam_width = config_eval.method.beam_width
    beam_pruning_margin = config_eval.method.beam_pruning_margin
    beam_score_length_normalization = config_eval.method.beam_score_length_normalization
    beam_score_length_normalization_strength = config_eval.method.beam_score_length_normalization_strength
    # Read from the `method` section, like every other beam-search setting.
    beam_score_coverage_penalty = config_eval.method.beam_score_coverage_penalty
    beam_score_coverage_penalty_strength = config_eval.method.beam_score_coverage_penalty_strength
    always_consider_eos_and_placeholders = config_eval.method.always_consider_eos_and_placeholders

    if config_eval.process.force_placeholders:
        # making it  default for now
        always_consider_eos_and_placeholders = True

    post_score_length_normalization = config_eval.method.post_score_length_normalization
    post_score_length_normalization_strength = config_eval.method.post_score_length_normalization_strength
    groundhog = config_eval.method.groundhog
    tgt_unk_id = config_eval.output.tgt_unk_id
    force_finish = config_eval.method.force_finish
    prob_space_combination = config_eval.method.prob_space_combination
    generate_attention_html = config_eval.output.generate_attention_html
    rich_output_filename = config_eval.output.rich_output_filename

    ref = config_eval.output.ref
    dic = config_eval.output.dic
    normalize_unicode_unk = config_eval.output.normalize_unicode_unk
    attempt_to_relocate_unk_source = config_eval.output.attempt_to_relocate_unk_source
    remove_unk = config_eval.output.remove_unk

    post_score_coverage_penalty = config_eval.method.post_score_coverage_penalty
    post_score_coverage_penalty_strength = config_eval.method.post_score_coverage_penalty_strength

    time_start = time.perf_counter()

    astar_params = beam_search.AStarParams(
        astar_batch_size=config_eval.method.astar_batch_size,
        astar_max_queue_size=config_eval.method.astar_max_queue_size,
        astar_prune_margin=config_eval.method.astar_prune_margin,
        astar_prune_ratio=config_eval.method.astar_prune_ratio,
        length_normalization_exponent=(
            config_eval.method.astar_length_normalization_exponent),
        length_normalization_constant=(
            config_eval.method.astar_length_normalization_constant),
        astar_priority_eval_string=(
            config_eval.method.astar_priority_eval_string),
        max_length_diff=config_eval.method.astar_max_length_diff)

    # Stays None unless at least one constraint builder is configured below.
    make_constraints_dict = None

    # NOTE(review): the models, indexers, src_data, constraints_list and
    # translation_infos are only bound when no server is requested; the code
    # after this block relies on them in every non-server code path.
    if config_eval.process.server is None:
        encdec_list, eos_idx, src_indexer, tgt_indexer, reverse_encdec, model_infos_list = create_encdec(
            config_eval)

        # A destination starting with "@eval@/" is relocated into an "eval"
        # directory next to the model file.
        eval_dir_placeholder = "@eval@/"
        if dest_fn.startswith(eval_dir_placeholder):
            if config_eval.trained_model is not None:
                training_model_filename = config_eval.trained_model
            else:
                if len(config_eval.process.load_model_config) == 0:
                    log.error("Cannot detect value for $eval$ placeholder")
                    sys.exit(1)
                training_model_filename = config_eval.process.load_model_config[
                    0]

            eval_dir = os.path.join(os.path.dirname(training_model_filename),
                                    "eval")
            dest_fn = os.path.join(eval_dir,
                                   dest_fn[len(eval_dir_placeholder):])
            log.info("$eval$ detected. dest_fn is: %s ", dest_fn)
            ensure_path(eval_dir)

        if src_fn is None:
            # Fall back on the test files declared in the configuration.
            (dev_src_from_config, dev_tgt_from_config, test_src_from_config,
             test_tgt_from_config
             ) = get_src_tgt_dev_from_config_eval(config_eval)
            if test_src_from_config is None:
                log.error(
                    "Could not find value for source text, either on command line or in config files"
                )
                sys.exit(1)
            log.info("using files from config as src:%s", test_src_from_config)
            src_fn = test_src_from_config
            if ref is None:
                log.info("using files from config as ref:%s",
                         test_tgt_from_config)
                ref = test_tgt_from_config

        if config_eval.process.force_placeholders:
            if make_constraints_dict is None:
                make_constraints_dict = {}
            make_constraints_dict[
                "ph_constraint"] = placeholder_constraints_builder(
                    src_indexer,
                    tgt_indexer,
                    units_placeholders=config_eval.process.units_placeholders)

        # The two constraint builders are independent: the placeholder
        # constraint registered above must survive when no bilingual
        # dictionary is configured, so there is no resetting `else` here.
        if config_eval.process.bilingual_dic_for_reranking:
            if make_constraints_dict is None:
                make_constraints_dict = {}

            print("**making ja en dic")
            ja_en_search, en_ja_search = dictionnary_handling.load_search_trie(
                config_eval.process.bilingual_dic_for_reranking,
                config_eval.process.invert_bilingual_dic_for_reranking)

            print("**define constraints")
            make_constraints_dict[
                "dic_constraint"] = dictionnary_handling.make_constraint(
                    ja_en_search, en_ja_search, tgt_indexer)

        log.info("opening source file %s" % src_fn)

        preprocessed_input = build_dataset_one_side_pp(
            src_fn,
            src_pp=src_indexer,
            max_nb_ex=max_nb_ex,
            make_constraints_dict=make_constraints_dict)

        # The preprocessing result has a third element (the per-sentence
        # constraints) only when constraint builders were given.
        if make_constraints_dict is not None:
            src_data, stats_src_pp, constraints_list = preprocessed_input
        else:
            src_data, stats_src_pp = preprocessed_input
            constraints_list = None
        log.info("src data stats:\n%s", stats_src_pp.make_report())

        translation_infos = OrderedNamespace()
        translation_infos["src"] = src_fn
        translation_infos["tgt"] = tgt_fn
        translation_infos["ref"] = ref

        for num_model, model_infos in enumerate(model_infos_list):
            translation_infos["model%i" % num_model] = model_infos

    if dest_fn is not None:
        save_eval_config_fn = dest_fn + ".eval.init.config.json"
        log.info("Saving initial eval config to %s" % save_eval_config_fn)
        config_eval.save_to(save_eval_config_fn)

    tgt_data = None
    if tgt_fn is not None:
        log.info("opening target file %s" % tgt_fn)
        tgt_data, stats_tgt_pp = build_dataset_one_side_pp(tgt_fn,
                                                           src_pp=tgt_indexer,
                                                           max_nb_ex=max_nb_ex)
        log.info("tgt data stats:\n%s", stats_tgt_pp.make_report())

    time_all_loaded = time.perf_counter()

    if mode == "translate":
        log.info("writing translation of to %s" % dest_fn)
        with cuda.get_device_from_id(gpu):
            assert len(encdec_list) == 1
            translations = greedy_batch_translate(
                encdec_list[0],
                eos_idx,
                src_data,
                batch_size=mb_size,
                gpu=gpu,
                nb_steps=nb_steps,
                use_chainerx=config_eval.process.use_chainerx)
        # The context manager guarantees the output is flushed and closed.
        with io.open(dest_fn, "wt", encoding="utf8") as out:
            for t in translations:
                # Drop the trailing end-of-sentence marker before decoding.
                if t[-1] == eos_idx:
                    t = t[:-1]
                ct = tgt_indexer.deconvert(t, unk_tag="#T_UNK#")
                out.write(ct + "\n")

    elif mode == "beam_search" or mode == "eval_bleu" or mode == "astar_search" or mode == "astar_eval_bleu":
        if config_eval.process.server is not None:
            from nmt_chainer.translation.server import do_start_server
            do_start_server(config_eval)
        else:

            def translate_closure(beam_width, nb_steps_ratio):
                """Translate with the given search settings.

                Returns the negated post-unk-replacement BLEU in the
                ``*eval_bleu`` modes when a reference is available (so that
                a minimizer maximizes BLEU), and None otherwise.
                """
                beam_search_params = beam_search.BeamSearchParams(
                    beam_width=beam_width,
                    beam_pruning_margin=beam_pruning_margin,
                    beam_score_coverage_penalty=beam_score_coverage_penalty,
                    beam_score_coverage_penalty_strength=(
                        beam_score_coverage_penalty_strength),
                    beam_score_length_normalization=(
                        beam_score_length_normalization),
                    beam_score_length_normalization_strength=(
                        beam_score_length_normalization_strength),
                    force_finish=force_finish,
                    use_unfinished_translation_if_none_found=True,
                    always_consider_eos_and_placeholders=(
                        always_consider_eos_and_placeholders))

                translate_to_file_with_beam_search(
                    dest_fn,
                    gpu,
                    encdec_list,
                    eos_idx,
                    src_data,
                    beam_search_params=beam_search_params,
                    nb_steps=nb_steps,
                    nb_steps_ratio=nb_steps_ratio,
                    post_score_length_normalization=(
                        post_score_length_normalization),
                    post_score_length_normalization_strength=(
                        post_score_length_normalization_strength),
                    post_score_coverage_penalty=post_score_coverage_penalty,
                    post_score_coverage_penalty_strength=(
                        post_score_coverage_penalty_strength),
                    groundhog=groundhog,
                    tgt_unk_id=tgt_unk_id,
                    tgt_indexer=tgt_indexer,
                    prob_space_combination=prob_space_combination,
                    reverse_encdec=reverse_encdec,
                    generate_attention_html=generate_attention_html,
                    src_indexer=src_indexer,
                    rich_output_filename=rich_output_filename,
                    unprocessed_output_filename=dest_fn + ".unprocessed",
                    nbest=nbest,
                    constraints_fn_list=constraints_list,
                    use_astar=(mode == "astar_search"
                               or mode == "astar_eval_bleu"),
                    astar_params=astar_params,
                    use_chainerx=config_eval.process.use_chainerx)

                translation_infos["dest"] = dest_fn
                translation_infos["unprocessed"] = dest_fn + ".unprocessed"

                if mode != "eval_bleu" and mode != "astar_eval_bleu":
                    return None

                if ref is not None:
                    bc = bleu_computer.get_bc_from_files(ref, dest_fn)
                    print("bleu before unk replace:", bc)
                    translation_infos["bleu"] = bc.bleu()
                    translation_infos["bleu_infos"] = str(bc)
                else:
                    print("bleu before unk replace: No Ref Provided")

                from nmt_chainer.utilities import replace_tgt_unk
                replace_tgt_unk.replace_unk(
                    dest_fn, src_fn, dest_fn + ".unk_replaced", dic,
                    remove_unk, normalize_unicode_unk,
                    attempt_to_relocate_unk_source)
                translation_infos[
                    "unk_replaced"] = dest_fn + ".unk_replaced"

                if ref is None:
                    # Without a reference there is no BLEU to report.
                    print("bleu after unk replace: No Ref Provided")
                    return None

                bc = bleu_computer.get_bc_from_files(
                    ref, dest_fn + ".unk_replaced")
                print("bleu after unk replace:", bc)
                translation_infos["post_unk_bleu"] = bc.bleu()
                translation_infos["post_unk_bleu_infos"] = str(bc)
                return -bc.bleu()

            hyper_param_search = config_eval.process.do_hyper_param_search
            if hyper_param_search is not None:
                # Search beam_width / nb_steps_ratio with optuna, using the
                # value returned by translate_closure as the objective.
                study_filename, study_name, n_trials = hyper_param_search
                n_trials = int(n_trials)
                import optuna

                def objective(trial):
                    nb_steps_ratio = trial.suggest_uniform(
                        'nb_steps_ratio', 0.9, 3.5)
                    beam_width = trial.suggest_int("beam_width", 2, 50)
                    return translate_closure(beam_width, nb_steps_ratio)

                study = optuna.create_study(study_name=study_name,
                                            storage="sqlite:///" +
                                            study_filename)
                study.optimize(objective, n_trials=n_trials)
                print(study.best_params)
                print(study.best_value)
                print(study.best_trial)

            else:  # single run with the configured settings
                translate_closure(beam_width, nb_steps_ratio)

    elif mode == "translate_attn":
        log.info("writing translation + attention as html to %s" % dest_fn)
        with cuda.get_device_from_id(gpu):
            assert len(encdec_list) == 1
            translations, attn_all = greedy_batch_translate(
                encdec_list[0],
                eos_idx,
                src_data,
                batch_size=mb_size,
                gpu=gpu,
                get_attention=True,
                nb_steps=nb_steps,
                use_chainerx=config_eval.process.use_chainerx)
        assert len(translations) == len(src_data)
        assert len(attn_all) == len(src_data)
        attn_vis = AttentionVisualizer()
        for num_t in six.moves.range(len(src_data)):
            src_idx_list = src_data[num_t]
            # The last element of each translation is left out of the plot.
            tgt_idx_list = translations[num_t][:-1]
            attn = attn_all[num_t]

            src_w = src_indexer.deconvert_swallow(
                src_idx_list, unk_tag="#S_UNK#") + ["SUM_ATTN"]
            tgt_w = tgt_indexer.deconvert_swallow(tgt_idx_list,
                                                  unk_tag="#T_UNK#")

            attn_vis.add_plot(src_w, tgt_w, attn)

        attn_vis.make_plot(dest_fn)

    elif mode == "align":
        import nmt_chainer.utilities.visualisation as visualisation
        assert tgt_data is not None
        assert len(tgt_data) == len(src_data)
        log.info("writing alignment as html to %s" % dest_fn)
        with cuda.get_device_from_id(gpu):
            assert len(encdec_list) == 1
            loss, attn_all = batch_align(
                encdec_list[0],
                eos_idx,
                list(six.moves.zip(src_data, tgt_data)),
                batch_size=mb_size,
                gpu=gpu,
                use_chainerx=config_eval.process.use_chainerx)

        assert len(attn_all) == len(src_data)
        plots_list = []
        for num_t in six.moves.range(len(src_data)):
            src_idx_list = src_data[num_t]
            tgt_idx_list = tgt_data[num_t]
            attn = attn_all[num_t]

            # One row per source token plus a final row holding, for each
            # target position, the sum of its attention weights.
            alignment = np.zeros((len(src_idx_list) + 1, len(tgt_idx_list)))
            sum_al = [0] * len(tgt_idx_list)
            for i in six.moves.range(len(src_idx_list)):
                for j in six.moves.range(len(tgt_idx_list)):
                    alignment[i, j] = attn[j][i]
                    sum_al[j] += alignment[i, j]
            for j in six.moves.range(len(tgt_idx_list)):
                alignment[len(src_idx_list), j] = sum_al[j]

            src_w = src_indexer.deconvert_swallow(
                src_idx_list, unk_tag="#S_UNK#") + ["SUM_ATTN"]
            tgt_w = tgt_indexer.deconvert_swallow(tgt_idx_list,
                                                  unk_tag="#T_UNK#")
            p1 = visualisation.make_alignment_figure(src_w, tgt_w, alignment)
            plots_list.append(p1)
        p_all = visualisation.Column(*plots_list)
        visualisation.output_file(dest_fn)
        visualisation.show(p_all)

    elif mode == "score_nbest":
        log.info("opening nbest file %s" % nbest_to_rescore)
        # Each line is "<source index> ||| <sentence> ...", with source
        # indices expected to be contiguous and non-decreasing.
        nbest_list = [[]]
        with io.open(nbest_to_rescore, 'rt', encoding="utf8") as nbest_f:
            for line in nbest_f:
                line = line.strip().split("|||")
                num_src = int(line[0].strip())
                if num_src >= len(nbest_list):
                    assert num_src == len(nbest_list)
                    if max_nb_ex is not None and num_src >= max_nb_ex:
                        break
                    nbest_list.append([])
                else:
                    assert num_src == len(nbest_list) - 1
                sentence = line[1].strip()
                nbest_list[-1].append(sentence.split(" "))

        log.info("found nbest lists for %i source sentences" % len(nbest_list))
        nbest_converted, make_data_infos = make_data.build_dataset_for_nbest_list_scoring(
            tgt_indexer, nbest_list)
        log.info("total %i sentences loaded" % make_data_infos.nb_ex)
        log.info("#tokens src: %i   of which %i (%f%%) are unknown" %
                 (make_data_infos.total_token, make_data_infos.total_count_unk,
                  float(make_data_infos.total_count_unk * 100) /
                  make_data_infos.total_token))
        if len(nbest_list) != len(src_data[:max_nb_ex]):
            log.warning("mismatch in lengths nbest vs src : %i != %i" %
                        (len(nbest_list), len(src_data[:max_nb_ex])))
            assert len(nbest_list) == len(src_data[:max_nb_ex])

        log.info("starting scoring")
        from nmt_chainer.utilities import utils
        res = []
        for num in six.moves.range(len(nbest_converted)):
            # Progress indicator on stderr.
            if num % 200 == 0:
                print(num, file=sys.stderr)
            elif num % 50 == 0:
                print("*", file=sys.stderr)

            res.append([])
            src, tgt_list = src_data[num], nbest_converted[num]
            src_batch, src_mask = utils.make_batch_src([src],
                                                       gpu=gpu,
                                                       volatile="on")

            assert len(encdec_list) == 1
            scorer = encdec_list[0].nbest_scorer(src_batch, src_mask)

            nb_batches = (len(tgt_list) + mb_size - 1) // mb_size
            for num_batch in six.moves.range(nb_batches):
                # Consecutive windows of mb_size candidates; together the
                # nb_batches windows cover the whole n-best list.
                batch_start = num_batch * mb_size
                tgt_batch, arg_sort = utils.make_batch_tgt(
                    tgt_list[batch_start:batch_start + mb_size],
                    eos_idx=eos_idx,
                    gpu=gpu,
                    volatile="on",
                    need_arg_sort=True)
                scores, attn = scorer(tgt_batch)
                scores, _ = scores
                scores = scores.data

                # Put the scores back in the order of the n-best file.
                assert len(arg_sort) == len(scores)
                de_sorted_scores = [None] * len(scores)
                for xpos in six.moves.range(len(arg_sort)):
                    original_pos = arg_sort[xpos]
                    de_sorted_scores[original_pos] = scores[xpos]
                res[-1] += de_sorted_scores
        print('', file=sys.stderr)
        log.info("writing scores to %s" % dest_fn)
        with io.open(dest_fn, "wt", encoding="utf8") as out:
            for num in six.moves.range(len(res)):
                for score in res[num]:
                    out.write("%i %f\n" % (num, score))

    time_end = time.perf_counter()
    translation_infos["loading_time"] = time_all_loaded - time_start
    translation_infos["translation_time"] = time_end - time_all_loaded
    translation_infos["total_time"] = time_end - time_start
    if dest_fn is not None:
        # Save a copy of the configuration completed with the run results.
        config_eval_session = config_eval.copy(readonly=False)
        config_eval_session.add_section("translation_infos",
                                        keep_at_bottom="metadata")
        config_eval_session["translation_infos"] = translation_infos
        config_eval_session.set_metadata_modified_time()
        save_eval_config_fn = dest_fn + ".eval.config.json"
        log.info("Saving eval config to %s" % save_eval_config_fn)
        config_eval_session.save_to(save_eval_config_fn)
Exemple #4
0
def beam_search_translate(
        encdec,
        eos_idx,
        src_data,
        beam_search_params: beam_search.BeamSearchParams = beam_search.
    BeamSearchParams(),
        #beam_width=20, beam_pruning_margin=None,
        nb_steps=50,
        gpu=None,

        #beam_score_coverage_penalty=None, beam_score_coverage_penalty_strength=0.2,
        need_attention=False,
        nb_steps_ratio=None,
        #beam_score_length_normalization='none', beam_score_length_normalization_strength=0.2,
        post_score_length_normalization='simple',
        post_score_length_normalization_strength=0.2,
        post_score_coverage_penalty='none',
        post_score_coverage_penalty_strength=0.2,
        groundhog=False,
        #force_finish=False,
        prob_space_combination=False,
        reverse_encdec=None,
        #use_unfinished_translation_if_none_found=False,
        nbest=None,
        constraints_fn_list: Optional[List[
            beam_search.BeamSearchConstraints]] = None,
        use_astar=False,
        astar_params: beam_search.AStarParams = beam_search.AStarParams(),
        use_chainerx=False,
        show_progress_bar=True):
    nb_ex = len(src_data)

    assert constraints_fn_list is None or len(constraints_fn_list) == nb_ex

    if show_progress_bar:
        range_creator = tqdm.trange
    else:
        range_creator = range

    for num_ex in range_creator(nb_ex):
        src_batch, src_mask = make_batch_src([src_data[num_ex]],
                                             gpu=gpu,
                                             use_chainerx=use_chainerx)

        assert len(src_mask) == 0
        if nb_steps_ratio is not None:
            nb_steps = int(len(src_data[num_ex]) * nb_steps_ratio) + 1

#         if isinstance(encdec, (tuple, list)):
#             assert len(encdec) == 1
#             encdec = encdec[0]
#
#         translations = encdec.beam_search(src_batch, src_mask, nb_steps = nb_steps, eos_idx = eos_idx,
#                                           beam_width = beam_width,
#                                           beam_opt = beam_opt, need_attention = need_attention,
#                                     groundhog = groundhog)

        if not isinstance(encdec, (tuple, list)):
            encdec = [encdec]

        if constraints_fn_list is not None:
            constraints_fn = constraints_fn_list[num_ex].get(
                "ph_constraint", None)
        else:
            constraints_fn = None
        translations = beam_search.ensemble_beam_search(
            encdec,
            src_batch,
            src_mask,
            nb_steps=nb_steps,
            eos_idx=eos_idx,
            beam_search_params=beam_search_params,
            #beam_width=beam_width,
            #beam_pruning_margin=beam_pruning_margin,
            #beam_score_length_normalization=beam_score_length_normalization,
            #beam_score_length_normalization_strength=beam_score_length_normalization_strength,
            #beam_score_coverage_penalty=beam_score_coverage_penalty,
            #beam_score_coverage_penalty_strength=beam_score_coverage_penalty_strength,
            need_attention=need_attention,
            #force_finish=force_finish,
            prob_space_combination=prob_space_combination,
            #use_unfinished_translation_if_none_found=use_unfinished_translation_if_none_found,
            constraints=constraints_fn,
            use_astar=use_astar,
            astar_params=astar_params,
            gpu=gpu)

        # TODO: This is a quick patch, but actually ensemble_beam_search probably should not return empty translations except when no translation found
        if len(translations) > 1:
            translations = [t for t in translations if len(t[0]) > 0]


#         print("nb_trans", len(translations), [score for _, score in translations])
#         translations.sort(key = itemgetter(1), reverse = True)

        if reverse_encdec is not None and len(translations) > 1:
            rescored_translations = []
            reverse_scores = reverse_rescore(reverse_encdec,
                                             src_batch,
                                             src_mask,
                                             eos_idx,
                                             [t[0] for t in translations],
                                             gpu,
                                             use_chainerx=use_chainerx)
            for num_t in six.moves.range(len(translations)):
                tr, sc, attn = translations[num_t]
                rescored_translations.append(
                    (tr, sc + reverse_scores[num_t], attn))
            translations = rescored_translations

        xp = encdec[0].xp

        if post_score_length_normalization == 'none' and post_score_coverage_penalty == 'none':
            ranking_criterion = operator.itemgetter(1)
        else:
            ONE_ON_DEVICE = beam_search.convert_array_if_needed(
                np.array(1.0, dtype=np.float32), xp, gpu)

            def ranking_criterion(x):
                length_normalization = 1
                if post_score_length_normalization == 'simple':
                    length_normalization = len(x[0]) + 1
                elif post_score_length_normalization == 'google':
                    length_normalization = pow(
                        (len(x[0]) + 5),
                        post_score_length_normalization_strength) / pow(
                            6, post_score_length_normalization_strength)

                dic_score = 0
                dic_score_computer = (constraints_fn_list[num_ex].get(
                    "dic_constraint", None) if constraints_fn_list is not None
                                      else None)
                if dic_score_computer is not None:
                    dic_score = dic_score_computer(x[0])

                coverage_penalty = 0
                if post_score_coverage_penalty == 'google':
                    assert len(src_data[num_ex]) == x[2][0].shape[0]

                    # log.info("sum={0}".format(sum(x[2])))
                    # log.info("min={0}".format(xp.minimum(sum(x[2]), xp.array(1.0))))
                    # log.info("log={0}".format(xp.log(xp.minimum(sum(x[2]), xp.array(1.0)))))
                    log_of_min_of_sum_over_j = xp.log(
                        xp.minimum(sum(x[2]), ONE_ON_DEVICE))
                    coverage_penalty = post_score_coverage_penalty_strength * xp.sum(
                        log_of_min_of_sum_over_j)
                    # log.info("cp={0}".format(coverage_penalty))
                    # cp = 0
                    # for i in six.moves.range(len(src_data[num_ex])):
                    #    attn_sum = 0
                    #    for j in six.moves.range(len(x[0])):
                    #        attn_sum += x[2][j][i]
                    #    #log.info("attn_sum={0}".format(attn_sum))
                    #    #log.info("min={0}".format(min(attn_sum, 1.0)))
                    #    #log.info("log={0}".format(math.log(min(attn_sum, 1.0))))
                    #    cp += math.log(min(attn_sum, 1.0))
                    # log.info("cp={0}".format(cp))
                    # cp *= post_score_coverage_penalty_strength

                    # slow = x[1]/length_normalization + cp
                    # opti = x[1]/length_normalization + coverage_penalty
                    # log.info("type={0}....{1}".format(type(slow), type(opti)))
                    # log.info("shape={0} size={1} dim={2} data={3} elem={4}".format(opti.shape, opti.size, opti.ndim, opti.data, opti.item(0)))
                    # test = '!!!'
                    # if "{0}".format(slow) == "{0}".format(opti):
                    #    test = ''
                    # log.info("score slow <=> optimized: {0} <=> {1} {2}".format(slow, opti, test))

                return x[
                    1] / length_normalization + coverage_penalty + dic_score

        translations.sort(key=ranking_criterion, reverse=True)

        if nbest is not None:
            yield translations[:nbest]
        else:
            yield [translations[0]]
Exemple #5
0
def greedy_batch_translate(encdec,
                           eos_idx,
                           src_data,
                           batch_size=80,
                           gpu=None,
                           get_attention=False,
                           nb_steps=50,
                           reverse_src=False,
                           reverse_tgt=False,
                           use_chainerx=False):
    """Greedily translate every sentence of ``src_data``, batch by batch.

    The whole call runs with Chainer's ``train`` config disabled and without
    back-propagation bookkeeping.

    Args:
        encdec: Translation model. When ``encdec.encdec_type()`` is ``"ff"``
            the work is delegated to ``encdec.greedy_batch_translate``;
            otherwise the model is called directly on each mini-batch.
        eos_idx: End-of-sentence index, forwarded to ``de_batch`` and used to
            keep a trailing EOS in place when ``reverse_tgt`` is set.
        src_data: Sequence of source sentences (sliceable, with ``len``).
        batch_size: Number of sentences per mini-batch.
        gpu: Forwarded to ``make_batch_src``.
        get_attention: When true, attention values are returned as well.
        nb_steps: Forwarded to the model as the number of decoding steps.
        reverse_src: When true, each source sentence is reversed before
            batching (not applied on the ``"ff"`` path).
        reverse_tgt: When true, each output is reversed, leaving a trailing
            EOS at the end (not applied on the ``"ff"`` path). Combining it
            with ``get_attention`` is not implemented.
        use_chainerx: Forwarded to ``make_batch_src``.

    Returns:
        The list of translations, or ``(translations, attentions)`` when
        ``get_attention`` is true. On the ``"ff"`` path the attentions are
        all-zero float32 arrays of shape ``(len(src), len(tgt))``.
    """
    with chainer.using_config("train", False), chainer.no_backprop_mode():
        if encdec.encdec_type() == "ff":
            result = encdec.greedy_batch_translate(src_data,
                                                   mb_size=batch_size,
                                                   nb_steps=nb_steps)
            if not get_attention:
                return result
            # This path exposes no attention here; placeholders are returned
            # instead.
            # NOTE(review): elsewhere attention appears to be indexed
            # [target][source]; confirm the (src, tgt) shape is intended.
            dummy_attention = [
                np.zeros((len(src), len(tgt)), dtype=np.float32)
                for src, tgt in six.moves.zip(src_data, result)
            ]
            return result, dummy_attention

        # Reject the unsupported combination before decoding anything rather
        # than after the whole corpus has been translated.
        if get_attention:
            assert not reverse_tgt, "not implemented"

        nb_ex = len(src_data)
        # Ceiling division: a partial last batch still counts as a batch.
        nb_batch = nb_ex // batch_size + (1 if nb_ex % batch_size != 0 else 0)
        res = []
        attn_all = []
        for i in six.moves.range(nb_batch):
            current_batch_raw_data = src_data[i * batch_size:(i + 1) *
                                              batch_size]

            if reverse_src:
                current_batch_raw_data = [
                    src_side[::-1] for src_side in current_batch_raw_data
                ]

            src_batch, src_mask = make_batch_src(current_batch_raw_data,
                                                 gpu=gpu,
                                                 use_chainerx=use_chainerx)
            sample_greedy, score, attn_list = encdec(
                src_batch,
                nb_steps,
                src_mask,
                use_best_for_sample=True,
                keep_attn_values=get_attention)
            res += de_batch(sample_greedy,
                            mask=None,
                            eos_idx=eos_idx,
                            is_variable=False)
            if get_attention:
                attn_all += de_batch(attn_list,
                                     mask=None,
                                     eos_idx=None,
                                     is_variable=True,
                                     raw=True,
                                     reverse=reverse_tgt)

        if reverse_tgt:
            new_res = []
            for t in res:
                # The length guard keeps an empty output from raising
                # IndexError on t[-1].
                if len(t) > 0 and t[-1] == eos_idx:
                    # Reverse the words but keep EOS as the final token.
                    new_res.append(t[:-1][::-1] + [t[-1]])
                else:
                    new_res.append(t[::-1])
            res = new_res

        if get_attention:
            return res, attn_all
        return res
Exemple #6
0
def beam_search_translate(encdec,
                          eos_idx,
                          src_data,
                          beam_width=20,
                          beam_pruning_margin=None,
                          nb_steps=50,
                          gpu=None,
                          beam_score_coverage_penalty=None,
                          beam_score_coverage_penalty_strength=0.2,
                          need_attention=False,
                          nb_steps_ratio=None,
                          beam_score_length_normalization='none',
                          beam_score_length_normalization_strength=0.2,
                          post_score_length_normalization='simple',
                          post_score_length_normalization_strength=0.2,
                          post_score_coverage_penalty='none',
                          post_score_coverage_penalty_strength=0.2,
                          groundhog=False,
                          force_finish=False,
                          prob_space_combination=False,
                          reverse_encdec=None,
                          use_unfinished_translation_if_none_found=False,
                          nbest=None):
    """Yield ranked beam-search translations, one list per source sentence.

    This is a generator. Each sentence of ``src_data`` is decoded on its own
    with ``beam_search.ensemble_beam_search``, optionally rescored with
    ``reverse_encdec``, and then sorted best-first by the post-score
    criterion.

    Args:
        encdec: A model, or a list/tuple of models to ensemble. A single
            model is wrapped in a list.
        eos_idx: Forwarded to the beam search and to ``reverse_rescore``.
        src_data: Sequence of source sentences.
        beam_width, beam_pruning_margin, beam_score_coverage_penalty,
        beam_score_coverage_penalty_strength,
        beam_score_length_normalization,
        beam_score_length_normalization_strength, need_attention,
        force_finish, prob_space_combination,
        use_unfinished_translation_if_none_found: Forwarded unchanged to
            ``ensemble_beam_search``.
        nb_steps: Number of decoding steps given to the beam search.
            Overridden per sentence when ``nb_steps_ratio`` is set.
        gpu: Forwarded to ``make_batch_src`` and ``reverse_rescore``.
        nb_steps_ratio: When not None, ``nb_steps`` becomes
            ``int(len(sentence) * nb_steps_ratio) + 1`` for each sentence.
        post_score_length_normalization: ``'simple'`` divides the score by
            ``len(translation) + 1``; ``'google'`` divides it by
            ``(len + 5) ** s / 6 ** s`` with ``s`` the strength below; any
            other value leaves the score undivided.
        post_score_length_normalization_strength: Exponent ``s`` above.
        post_score_coverage_penalty: ``'google'`` adds
            ``strength * sum_i log(min(sum_j attn[j][i], 1))`` to the score;
            any other value adds nothing.
        post_score_coverage_penalty_strength: Multiplier ``strength`` above.
        groundhog: Accepted but not used by this function.
        reverse_encdec: When not None and there is more than one candidate,
            the scores from ``reverse_rescore`` are added to the candidates.
        nbest: When not None, the ``nbest`` best candidates are yielded;
            otherwise only the single best one.

    Yields:
        A list of candidates, each indexed as
        ``(translation, score, attention)``, best first.
    """
    # ensemble_beam_search is always handed a sequence of models; wrapping
    # once is enough because the argument does not change between sentences.
    if not isinstance(encdec, (tuple, list)):
        encdec = [encdec]

    nb_ex = len(src_data)
    for num_ex in range(nb_ex):
        src_batch, src_mask = make_batch_src([src_data[num_ex]],
                                             gpu=gpu,
                                             volatile="on")
        # Presumably a one-sentence batch needs no padding mask -- TODO
        # confirm against make_batch_src.
        assert len(src_mask) == 0
        if nb_steps_ratio is not None:
            nb_steps = int(len(src_data[num_ex]) * nb_steps_ratio) + 1

        translations = beam_search.ensemble_beam_search(
            encdec,
            src_batch,
            src_mask,
            nb_steps=nb_steps,
            eos_idx=eos_idx,
            beam_width=beam_width,
            beam_pruning_margin=beam_pruning_margin,
            beam_score_length_normalization=beam_score_length_normalization,
            beam_score_length_normalization_strength=
            beam_score_length_normalization_strength,
            beam_score_coverage_penalty=beam_score_coverage_penalty,
            beam_score_coverage_penalty_strength=
            beam_score_coverage_penalty_strength,
            need_attention=need_attention,
            force_finish=force_finish,
            prob_space_combination=prob_space_combination,
            use_unfinished_translation_if_none_found=
            use_unfinished_translation_if_none_found)

        # TODO: This is a quick patch, but actually ensemble_beam_search probably should not return empty translations except when no translation found
        # NOTE(review): if every candidate is empty this leaves an empty
        # list, and the final ``translations[0]`` raises IndexError when
        # ``nbest`` is None.
        if len(translations) > 1:
            translations = [t for t in translations if len(t[0]) > 0]

        if reverse_encdec is not None and len(translations) > 1:
            reverse_scores = reverse_rescore(reverse_encdec, src_batch,
                                             src_mask, eos_idx,
                                             [t[0] for t in translations], gpu)
            # enumerate() replaces the Python-2-only xrange() index loop.
            translations = [
                (tr, sc + reverse_scores[num_t], attn)
                for num_t, (tr, sc, attn) in enumerate(translations)
            ]

        xp = encdec[0].xp

        if post_score_length_normalization == 'none' and post_score_coverage_penalty == 'none':
            # Nothing to adjust: rank on the raw score.
            ranking_criterion = operator.itemgetter(1)
        else:
            use_coverage_penalty = post_score_coverage_penalty == 'google'
            # Built once per sentence instead of once per sort key.
            one = xp.array(1.0) if use_coverage_penalty else None

            def ranking_criterion(x):
                length_normalization = 1
                if post_score_length_normalization == 'simple':
                    length_normalization = len(x[0]) + 1
                elif post_score_length_normalization == 'google':
                    length_normalization = pow(
                        (len(x[0]) + 5),
                        post_score_length_normalization_strength) / pow(
                            6, post_score_length_normalization_strength)

                coverage_penalty = 0
                if use_coverage_penalty:
                    assert len(src_data[num_ex]) == x[2][0].shape[0]

                    # Sum the attention vectors over the target steps, cap
                    # each source position at 1, then sum the logs. This is
                    # the vectorised form of
                    #   sum_i log(min(sum_j attn[j][i], 1.0)).
                    log_of_min_of_sum_over_j = xp.log(
                        xp.minimum(sum(x[2]), one))
                    coverage_penalty = post_score_coverage_penalty_strength * xp.sum(
                        log_of_min_of_sum_over_j)

                return x[1] / length_normalization + coverage_penalty

        translations.sort(key=ranking_criterion, reverse=True)

        if nbest is not None:
            yield translations[:nbest]
        else:
            yield [translations[0]]
Exemple #7
0
def greedy_batch_translate(encdec,
                           eos_idx,
                           src_data,
                           batch_size=80,
                           gpu=None,
                           get_attention=False,
                           nb_steps=50,
                           reverse_src=False,
                           reverse_tgt=False):
    """Greedily translate every sentence of ``src_data``, batch by batch.

    Args:
        encdec: Translation model, called on each mini-batch with
            ``mode="test"``.
        eos_idx: End-of-sentence index, forwarded to ``de_batch`` and used to
            keep a trailing EOS in place when ``reverse_tgt`` is set.
        src_data: Sequence of source sentences (sliceable, with ``len``).
        batch_size: Number of sentences per mini-batch.
        gpu: Forwarded to ``make_batch_src``.
        get_attention: When true, attention values are returned as well.
        nb_steps: Forwarded to the model as the number of decoding steps.
        reverse_src: When true, each source sentence is reversed before
            batching.
        reverse_tgt: When true, each output is reversed, leaving a trailing
            EOS at the end. Combining it with ``get_attention`` is not
            implemented.

    Returns:
        The list of translations, or ``(translations, attentions)`` when
        ``get_attention`` is true.
    """
    # Reject the unsupported combination before decoding anything rather
    # than after the whole corpus has been translated.
    if get_attention:
        assert not reverse_tgt, "not implemented"

    nb_ex = len(src_data)
    # Integer ceiling division. The previous "/" is true division on
    # Python 3 and handed range() a float.
    nb_batch, leftover = divmod(nb_ex, batch_size)
    if leftover != 0:
        nb_batch += 1

    res = []
    attn_all = []
    for i in range(nb_batch):
        current_batch_raw_data = src_data[i * batch_size:(i + 1) * batch_size]

        if reverse_src:
            current_batch_raw_data = [
                src_side[::-1] for src_side in current_batch_raw_data
            ]

        src_batch, src_mask = make_batch_src(current_batch_raw_data,
                                             gpu=gpu,
                                             volatile="on")
        sample_greedy, score, attn_list = encdec(
            src_batch,
            nb_steps,
            src_mask,
            use_best_for_sample=True,
            keep_attn_values=get_attention,
            mode="test")
        res += de_batch(sample_greedy,
                        mask=None,
                        eos_idx=eos_idx,
                        is_variable=False)
        if get_attention:
            attn_all += de_batch(attn_list,
                                 mask=None,
                                 eos_idx=None,
                                 is_variable=True,
                                 raw=True,
                                 reverse=reverse_tgt)

    if reverse_tgt:
        new_res = []
        for t in res:
            # The length guard keeps an empty output from raising IndexError
            # on t[-1].
            if len(t) > 0 and t[-1] == eos_idx:
                # Reverse the words but keep EOS as the final token.
                new_res.append(t[:-1][::-1] + [t[-1]])
            else:
                new_res.append(t[::-1])
        res = new_res

    if get_attention:
        return res, attn_all
    return res
Exemple #8
0
def do_eval(config_eval):
    src_fn = config_eval.process.src_fn
    tgt_fn = config_eval.output.tgt_fn
    mode = config_eval.method.mode
    gpu = config_eval.process.gpu
    dest_fn = config_eval.process.dest_fn
    mb_size = config_eval.process.mb_size
    nb_steps = config_eval.method.nb_steps
    nb_steps_ratio = config_eval.method.nb_steps_ratio
    max_nb_ex = config_eval.process.max_nb_ex
    nbest_to_rescore = config_eval.output.nbest_to_rescore
    nbest = config_eval.output.nbest

    beam_width = config_eval.method.beam_width
    beam_pruning_margin = config_eval.method.beam_pruning_margin
    beam_score_length_normalization = config_eval.method.beam_score_length_normalization
    beam_score_length_normalization_strength = config_eval.method.beam_score_length_normalization_strength
    beam_score_coverage_penalty = config_eval.beam_score_coverage_penalty
    beam_score_coverage_penalty_strength = config_eval.beam_score_coverage_penalty_strength
    post_score_length_normalization = config_eval.method.post_score_length_normalization
    post_score_length_normalization_strength = config_eval.method.post_score_length_normalization_strength
    groundhog = config_eval.method.groundhog
    tgt_unk_id = config_eval.output.tgt_unk_id
    force_finish = config_eval.method.force_finish
    prob_space_combination = config_eval.method.prob_space_combination
    generate_attention_html = config_eval.output.generate_attention_html
    rich_output_filename = config_eval.output.rich_output_filename

    ref = config_eval.output.ref
    dic = config_eval.output.dic
    normalize_unicode_unk = config_eval.output.normalize_unicode_unk
    attempt_to_relocate_unk_source = config_eval.output.attempt_to_relocate_unk_source
    remove_unk = config_eval.output.remove_unk

    post_score_coverage_penalty = config_eval.method.post_score_coverage_penalty
    post_score_coverage_penalty_strength = config_eval.method.post_score_coverage_penalty_strength

    time_start = time.clock()

    encdec_list, eos_idx, src_indexer, tgt_indexer, reverse_encdec, model_infos_list = create_encdec(config_eval)

    if config_eval.process.server is None:
        eval_dir_placeholder = "@eval@/"
        if dest_fn.startswith(eval_dir_placeholder):
            if config_eval.trained_model is not None:
                training_model_filename = config_eval.trained_model
            else:
                if len(config_eval.process.load_model_config) == 0:
                    log.error("Cannot detect value for $eval$ placeholder")
                    sys.exit(1)
                training_model_filename = config_eval.process.load_model_config[0]

            eval_dir = os.path.join(os.path.dirname(training_model_filename), "eval")
            dest_fn = os.path.join(eval_dir, dest_fn[len(eval_dir_placeholder):])
            log.info("$eval$ detected. dest_fn is: %s ", dest_fn)
            ensure_path(eval_dir)

        if src_fn is None:
            (dev_src_from_config, dev_tgt_from_config, test_src_from_config, test_tgt_from_config) = get_src_tgt_dev_from_config_eval(config_eval)
            if test_src_from_config is None:
                log.error("Could not find value for source text, either on command line or in config files")
                sys.exit(1)
            log.info("using files from config as src:%s", test_src_from_config)
            src_fn = test_src_from_config
            if ref is None:
                log.info("using files from config as ref:%s", test_tgt_from_config)
                ref = test_tgt_from_config

        log.info("opening source file %s" % src_fn)
        src_data, stats_src_pp = build_dataset_one_side_pp(src_fn, src_pp=src_indexer,
                                                           max_nb_ex=max_nb_ex)
        log.info("src data stats:\n%s", stats_src_pp.make_report())

    if dest_fn is not None:
        save_eval_config_fn = dest_fn + ".eval.init.config.json"
        log.info("Saving initial eval config to %s" % save_eval_config_fn)
        config_eval.save_to(save_eval_config_fn)

    translation_infos = OrderedNamespace()
#     log.info("%i sentences loaded" % make_data_infos.nb_ex)
#     log.info("#tokens src: %i   of which %i (%f%%) are unknown"%(make_data_infos.total_token,
#                                                                  make_data_infos.total_count_unk,
#                                                                  float(make_data_infos.total_count_unk * 100) /
#                                                                     make_data_infos.total_token))

    tgt_data = None
    if tgt_fn is not None:
        log.info("opening target file %s" % tgt_fn)
        tgt_data, stats_tgt_pp = build_dataset_one_side_pp(tgt_fn, src_pp=tgt_indexer,
                                                           max_nb_ex=max_nb_ex)
        log.info("tgt data stats:\n%s", stats_tgt_pp.make_report())
#         log.info("%i sentences loaded"%make_data_infos.nb_ex)
#         log.info("#tokens src: %i   of which %i (%f%%) are unknown"%(make_data_infos.total_token,
#                                                                  make_data_infos.total_count_unk,
#                                                                  float(make_data_infos.total_count_unk * 100) /
#                                                                     make_data_infos.total_token))

#     translations = greedy_batch_translate(encdec, eos_idx, src_data, batch_size = mb_size, gpu = args.gpu)

    translation_infos["src"] = src_fn
    translation_infos["tgt"] = tgt_fn
    translation_infos["ref"] = ref

    for num_model, model_infos in enumerate(model_infos_list):
        translation_infos["model%i" % num_model] = model_infos

    time_all_loaded = time.clock()

    if mode == "translate":
        log.info("writing translation of to %s" % dest_fn)
        with cuda.get_device(gpu):
            assert len(encdec_list) == 1
            translations = greedy_batch_translate(
                encdec_list[0], eos_idx, src_data, batch_size=mb_size, gpu=gpu, nb_steps=nb_steps)
        out = codecs.open(dest_fn, "w", encoding="utf8")
        for t in translations:
            if t[-1] == eos_idx:
                t = t[:-1]
            ct = tgt_indexer.deconvert(t, unk_tag="#T_UNK#")
#             ct = convert_idx_to_string(t, tgt_voc + ["#T_UNK#"])
            out.write(ct + "\n")

    elif mode == "beam_search" or mode == "eval_bleu":
        if config_eval.process.server is not None:
            from nmt_chainer.translation.server import do_start_server
            do_start_server(config_eval)
        else:
            translate_to_file_with_beam_search(dest_fn, gpu, encdec_list, eos_idx, src_data,
                                               beam_width=beam_width,
                                               beam_pruning_margin=beam_pruning_margin,
                                               beam_score_coverage_penalty=beam_score_coverage_penalty,
                                               beam_score_coverage_penalty_strength=beam_score_coverage_penalty_strength,
                                               nb_steps=nb_steps,
                                               nb_steps_ratio=nb_steps_ratio,
                                               beam_score_length_normalization=beam_score_length_normalization,
                                               beam_score_length_normalization_strength=beam_score_length_normalization_strength,
                                               post_score_length_normalization=post_score_length_normalization,
                                               post_score_length_normalization_strength=post_score_length_normalization_strength,
                                               post_score_coverage_penalty=post_score_coverage_penalty,
                                               post_score_coverage_penalty_strength=post_score_coverage_penalty_strength,
                                               groundhog=groundhog,
                                               tgt_unk_id=tgt_unk_id,
                                               tgt_indexer=tgt_indexer,
                                               force_finish=force_finish,
                                               prob_space_combination=prob_space_combination,
                                               reverse_encdec=reverse_encdec,
                                               generate_attention_html=generate_attention_html,
                                               src_indexer=src_indexer,
                                               rich_output_filename=rich_output_filename,
                                               use_unfinished_translation_if_none_found=True,
                                               unprocessed_output_filename=dest_fn + ".unprocessed",
                                               nbest=nbest)

            translation_infos["dest"] = dest_fn
            translation_infos["unprocessed"] = dest_fn + ".unprocessed"
            if mode == "eval_bleu":
                if ref is not None:
                    bc = bleu_computer.get_bc_from_files(ref, dest_fn)
                    print "bleu before unk replace:", bc
                    translation_infos["bleu"] = bc.bleu()
                    translation_infos["bleu_infos"] = str(bc)
                else:
                    print "bleu before unk replace: No Ref Provided"

                from nmt_chainer.utilities import replace_tgt_unk
                replace_tgt_unk.replace_unk(dest_fn, src_fn, dest_fn + ".unk_replaced", dic, remove_unk,
                                            normalize_unicode_unk,
                                            attempt_to_relocate_unk_source)
                translation_infos["unk_replaced"] = dest_fn + ".unk_replaced"

                if ref is not None:
                    bc = bleu_computer.get_bc_from_files(ref, dest_fn + ".unk_replaced")
                    print "bleu after unk replace:", bc
                    translation_infos["post_unk_bleu"] = bc.bleu()
                    translation_infos["post_unk_bleu_infos"] = str(bc)
                else:
                    print "bleu before unk replace: No Ref Provided"

    elif mode == "translate_attn":
        log.info("writing translation + attention as html to %s" % dest_fn)
        with cuda.get_device(gpu):
            assert len(encdec_list) == 1
            translations, attn_all = greedy_batch_translate(
                encdec_list[0], eos_idx, src_data, batch_size=mb_size, gpu=gpu,
                get_attention=True, nb_steps=nb_steps)
#         tgt_voc_with_unk = tgt_voc + ["#T_UNK#"]
#         src_voc_with_unk = src_voc + ["#S_UNK#"]
        assert len(translations) == len(src_data)
        assert len(attn_all) == len(src_data)
        attn_vis = AttentionVisualizer()
        for num_t in xrange(len(src_data)):
            src_idx_list = src_data[num_t]
            tgt_idx_list = translations[num_t][:-1]
            attn = attn_all[num_t]
#             assert len(attn) == len(tgt_idx_list)

            src_w = src_indexer.deconvert_swallow(src_idx_list, unk_tag="#S_UNK#") + ["SUM_ATTN"]
            tgt_w = tgt_indexer.deconvert_swallow(tgt_idx_list, unk_tag="#T_UNK#")
#             src_w = [src_voc_with_unk[idx] for idx in src_idx_list] + ["SUM_ATTN"]
#             tgt_w = [tgt_voc_with_unk[idx] for idx in tgt_idx_list]
#             for j in xrange(len(tgt_idx_list)):
#                 tgt_idx_list.append(tgt_voc_with_unk[t_and_attn[j][0]])
#
    #         print [src_voc_with_unk[idx] for idx in src_idx_list], tgt_idx_list

            attn_vis.add_plot(src_w, tgt_w, attn)

        attn_vis.make_plot(dest_fn)

    elif mode == "align":
        import nmt_chainer.utilities.visualisation as visualisation
        assert tgt_data is not None
        assert len(tgt_data) == len(src_data)
        log.info("writing alignment as html to %s" % dest_fn)
        with cuda.get_device(gpu):
            assert len(encdec_list) == 1
            loss, attn_all = batch_align(
                encdec_list[0], eos_idx, zip(src_data, tgt_data), batch_size=mb_size, gpu=gpu)
#         tgt_voc_with_unk = tgt_voc + ["#T_UNK#"]
#         src_voc_with_unk = src_voc + ["#S_UNK#"]

        assert len(attn_all) == len(src_data)
        plots_list = []
        for num_t in xrange(len(src_data)):
            src_idx_list = src_data[num_t]
            tgt_idx_list = tgt_data[num_t]
            attn = attn_all[num_t]
#             assert len(attn) == len(tgt_idx_list)

            alignment = np.zeros((len(src_idx_list) + 1, len(tgt_idx_list)))
            sum_al = [0] * len(tgt_idx_list)
            for i in xrange(len(src_idx_list)):
                for j in xrange(len(tgt_idx_list)):
                    alignment[i, j] = attn[j][i]
                    sum_al[j] += alignment[i, j]
            for j in xrange(len(tgt_idx_list)):
                alignment[len(src_idx_list), j] = sum_al[j]

            src_w = src_indexer.deconvert_swallow(src_idx_list, unk_tag="#S_UNK#") + ["SUM_ATTN"]
            tgt_w = tgt_indexer.deconvert_swallow(tgt_idx_list, unk_tag="#T_UNK#")
#             src_w = [src_voc_with_unk[idx] for idx in src_idx_list] + ["SUM_ATTN"]
#             tgt_w = [tgt_voc_with_unk[idx] for idx in tgt_idx_list]
#             for j in xrange(len(tgt_idx_list)):
#                 tgt_idx_list.append(tgt_voc_with_unk[t_and_attn[j][0]])
#
    #         print [src_voc_with_unk[idx] for idx in src_idx_list], tgt_idx_list
            p1 = visualisation.make_alignment_figure(
                src_w, tgt_w, alignment)
#             p2 = visualisation.make_alignment_figure(
#                             [src_voc_with_unk[idx] for idx in src_idx_list], tgt_idx_list, alignment)
            plots_list.append(p1)
        p_all = visualisation.Column(*plots_list)
        visualisation.output_file(dest_fn)
        visualisation.show(p_all)
#     for t in translations_with_attn:
#         for x, attn in t:
#             print x, attn


#         out.write(convert_idx_to_string([x for x, attn in t], tgt_voc + ["#T_UNK#"]) + "\n")

    elif mode == "score_nbest":
        log.info("opening nbest file %s" % nbest_to_rescore)
        nbest_f = codecs.open(nbest_to_rescore, encoding="utf8")
        nbest_list = [[]]
        for line in nbest_f:
            line = line.strip().split("|||")
            num_src = int(line[0].strip())
            if num_src >= len(nbest_list):
                assert num_src == len(nbest_list)
                if max_nb_ex is not None and num_src >= max_nb_ex:
                    break
                nbest_list.append([])
            else:
                assert num_src == len(nbest_list) - 1
            sentence = line[1].strip()
            nbest_list[-1].append(sentence.split(" "))

        log.info("found nbest lists for %i source sentences" % len(nbest_list))
        nbest_converted, make_data_infos = make_data.build_dataset_for_nbest_list_scoring(tgt_indexer, nbest_list)
        log.info("total %i sentences loaded" % make_data_infos.nb_ex)
        log.info("#tokens src: %i   of which %i (%f%%) are unknown" % (make_data_infos.total_token,
                                                                       make_data_infos.total_count_unk,
                                                                       float(make_data_infos.total_count_unk * 100) /
                                                                       make_data_infos.total_token))
        if len(nbest_list) != len(src_data[:max_nb_ex]):
            log.warn("mismatch in lengths nbest vs src : %i != %i" % (len(nbest_list), len(src_data[:max_nb_ex])))
            assert len(nbest_list) == len(src_data[:max_nb_ex])

        log.info("starting scoring")
        from nmt_chainer.utilities import utils
        res = []
        for num in xrange(len(nbest_converted)):
            if num % 200 == 0:
                print >>sys.stderr, num,
            elif num % 50 == 0:
                print >>sys.stderr, "*",

            res.append([])
            src, tgt_list = src_data[num], nbest_converted[num]
            src_batch, src_mask = utils.make_batch_src([src], gpu=gpu, volatile="on")

            assert len(encdec_list) == 1
            scorer = encdec_list[0].nbest_scorer(src_batch, src_mask)

            nb_batches = (len(tgt_list) + mb_size - 1) / mb_size
            for num_batch in xrange(nb_batches):
                tgt_batch, arg_sort = utils.make_batch_tgt(tgt_list[num_batch * nb_batches: (num_batch + 1) * nb_batches],
                                                           eos_idx=eos_idx, gpu=gpu, volatile="on", need_arg_sort=True)
                scores, attn = scorer(tgt_batch)
                scores, _ = scores
                scores = scores.data

                assert len(arg_sort) == len(scores)
                de_sorted_scores = [None] * len(scores)
                for xpos in xrange(len(arg_sort)):
                    original_pos = arg_sort[xpos]
                    de_sorted_scores[original_pos] = scores[xpos]
                res[-1] += de_sorted_scores
        print >>sys.stderr
        log.info("writing scores to %s" % dest_fn)
        out = codecs.open(dest_fn, "w", encoding="utf8")
        for num in xrange(len(res)):
            for score in res[num]:
                out.write("%i %f\n" % (num, score))

    time_end = time.clock()
    translation_infos["loading_time"] = time_all_loaded - time_start
    translation_infos["translation_time"] = time_end - time_all_loaded
    translation_infos["total_time"] = time_end - time_start
    if dest_fn is not None:
        config_eval_session = config_eval.copy(readonly=False)
        config_eval_session.add_section("translation_infos", keep_at_bottom="metadata")
        config_eval_session["translation_infos"] = translation_infos
        config_eval_session.set_metadata_modified_time()
        save_eval_config_fn = dest_fn + ".eval.config.json"
        log.info("Saving eval config to %s" % save_eval_config_fn)
        config_eval_session.save_to(save_eval_config_fn)