Beispiel #1
0
    def test_1(self):
        """Run two identically configured beam searches on a small toy model.

        Builds an ``EncoderDecoder`` with fixed small dimensions, starts two
        ``beam_search_translate`` generators with the same parameters over the
        same two source sentences, and pulls the first two results from each.

        TODO: not much point to this test now that beam_opt distinction is
        removed (the two searches used to differ only by ``beam_opt=False`` vs
        ``beam_opt=True``).
        NOTE(review): the fetched results are not compared here.
        """
        import nmt_chainer.translation.evaluation as evaluation

        # Positional EncoderDecoder dimensions: Vi, Ei, Hi, Vo, Eo, Ho, Ha, Hl.
        dims = (29, 37, 13, 53, 7, 12, 19, 33)
        encdec = nmt_chainer.models.encoder_decoder.EncoderDecoder(*dims)
        # The last target-vocabulary index (Vo - 1) serves as EOS.
        eos_idx = dims[3] - 1
        src_data = [[2, 3, 3, 4, 4, 5], [1, 3, 8, 9, 2]]

        import nmt_chainer.translation.beam_search as beam_search

        def start_search():
            # Each search gets its own freshly built BeamSearchParams.
            return evaluation.beam_search_translate(
                encdec, eos_idx, src_data,
                beam_search.BeamSearchParams(beam_width=10),
                nb_steps=15, gpu=None, need_attention=False)

        gen_a = start_search()
        gen_b = start_search()
        # Interleave: first result of each generator, then the second of each.
        res1a, res1b = next(gen_a), next(gen_b)
        res2a, res2b = next(gen_a), next(gen_b)
Beispiel #2
0
            def translate_closure(beam_width, nb_steps_ratio):
                """Translate ``src_data`` into ``dest_fn`` and optionally score it.

                Builds a ``BeamSearchParams`` from the enclosing scope's beam
                settings (using the given ``beam_width``), translates with
                ``translate_to_file_with_beam_search`` and records the produced
                file names in ``translation_infos``.

                In ``eval_bleu`` / ``astar_eval_bleu`` mode, it also runs
                unknown-word replacement on the output and, when a reference
                file ``ref`` is available, computes BLEU before and after that
                replacement (stored in ``translation_infos``).

                Args:
                    beam_width: beam width used for this translation run.
                    nb_steps_ratio: forwarded unchanged to
                        ``translate_to_file_with_beam_search``.

                Returns:
                    The negated BLEU score of the unk-replaced output when in an
                    evaluation mode with a reference (presumably negated so a
                    caller can minimize it); ``None`` otherwise, including in
                    an evaluation mode with no reference.
                """
                beam_search_params = beam_search.BeamSearchParams(
                    beam_width=beam_width,
                    beam_pruning_margin=beam_pruning_margin,
                    beam_score_coverage_penalty=beam_score_coverage_penalty,
                    beam_score_coverage_penalty_strength=
                    beam_score_coverage_penalty_strength,
                    beam_score_length_normalization=
                    beam_score_length_normalization,
                    beam_score_length_normalization_strength=
                    beam_score_length_normalization_strength,
                    force_finish=force_finish,
                    use_unfinished_translation_if_none_found=True,
                    always_consider_eos_and_placeholders=
                    always_consider_eos_and_placeholders)

                unprocessed_fn = dest_fn + ".unprocessed"
                translate_to_file_with_beam_search(
                    dest_fn,
                    gpu,
                    encdec_list,
                    eos_idx,
                    src_data,
                    beam_search_params=beam_search_params,
                    nb_steps=nb_steps,
                    nb_steps_ratio=nb_steps_ratio,
                    post_score_length_normalization=
                    post_score_length_normalization,
                    post_score_length_normalization_strength=
                    post_score_length_normalization_strength,
                    post_score_coverage_penalty=post_score_coverage_penalty,
                    post_score_coverage_penalty_strength=
                    post_score_coverage_penalty_strength,
                    groundhog=groundhog,
                    tgt_unk_id=tgt_unk_id,
                    tgt_indexer=tgt_indexer,
                    prob_space_combination=prob_space_combination,
                    reverse_encdec=reverse_encdec,
                    generate_attention_html=generate_attention_html,
                    src_indexer=src_indexer,
                    rich_output_filename=rich_output_filename,
                    unprocessed_output_filename=unprocessed_fn,
                    nbest=nbest,
                    constraints_fn_list=constraints_list,
                    use_astar=mode in ("astar_search", "astar_eval_bleu"),
                    astar_params=astar_params,
                    use_chainerx=config_eval.process.use_chainerx)

                translation_infos["dest"] = dest_fn
                translation_infos["unprocessed"] = unprocessed_fn
                if mode not in ("eval_bleu", "astar_eval_bleu"):
                    return None

                if ref is not None:
                    bc = bleu_computer.get_bc_from_files(ref, dest_fn)
                    print("bleu before unk replace:", bc)
                    translation_infos["bleu"] = bc.bleu()
                    translation_infos["bleu_infos"] = str(bc)
                else:
                    print("bleu before unk replace: No Ref Provided")

                from nmt_chainer.utilities import replace_tgt_unk
                unk_replaced_fn = dest_fn + ".unk_replaced"
                replace_tgt_unk.replace_unk(
                    dest_fn, src_fn, unk_replaced_fn, dic,
                    remove_unk, normalize_unicode_unk,
                    attempt_to_relocate_unk_source)
                translation_infos["unk_replaced"] = unk_replaced_fn

                if ref is None:
                    print("bleu after unk replace: No Ref Provided")
                    # Without a reference there is no BLEU object to score;
                    # returning early avoids referencing an unbound ``bc``.
                    return None

                bc = bleu_computer.get_bc_from_files(ref, unk_replaced_fn)
                print("bleu after unk replace:", bc)
                post_unk_bleu = bc.bleu()
                translation_infos["post_unk_bleu"] = post_unk_bleu
                translation_infos["post_unk_bleu_infos"] = str(bc)
                return -post_unk_bleu