Code example #1
 def test_parse_numberize(self):
     src_dataset = data.InMemoryNumpyDataset()
     trg_dataset = data.InMemoryNumpyDataset()
     for _ in range(2):
         src_dataset.parse(
             self.src_txt_numberized,
             self.d,
             reverse_order=True,
             append_eos=False,
             already_numberized=True,
         )
         trg_dataset.parse(
             self.trg_txt_numberized,
             self.d,
             reverse_order=False,
             append_eos=True,
             already_numberized=True,
         )
         self.assertEqual(self.num_sentences, len(src_dataset))
         self.assertEqual(self.num_sentences, len(trg_dataset))
         for i in range(self.num_sentences):
             self.assertListEqual(self.src_ref[i], src_dataset[i].tolist())
             self.assertListEqual(
                 self.trg_ref[i] + [self.d.eos_index], trg_dataset[i].tolist()
             )
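For orientation, here is a minimal sketch of the parse flow this test exercises, assuming pytorch_translate is installed; the module aliases, "dict.src.txt", and "corpus.src.txt" are hypothetical placeholders, not paths from the test suite.

from pytorch_translate import data
from pytorch_translate import dictionary as pytorch_translate_dictionary

# Hypothetical paths; Dictionary.load follows the fairseq-style loader.
d = pytorch_translate_dictionary.Dictionary.load("dict.src.txt")
dataset = data.InMemoryNumpyDataset()
dataset.parse(
    "corpus.src.txt",          # one whitespace-tokenized sentence per line
    d,
    reverse_order=False,
    append_eos=True,           # appends d.eos_index to every sentence
    already_numberized=False,  # tokens are words, not pre-converted IDs
)
print(len(dataset))            # number of parsed sentences
print(dataset[0].tolist())     # token IDs of the first sentence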
Code example #2
 def test_parse_multiling(self):
     prepend_dataset = data.InMemoryNumpyDataset()
     append_dataset = data.InMemoryNumpyDataset()
     corpora = [
         data.MultilingualCorpusConfig(
             dialect_id=10, data_file=self.trg_txt, dict=self.d, oversampling=1
         ),
         data.MultilingualCorpusConfig(
             dialect_id=11, data_file=self.trg_txt, dict=self.d, oversampling=1
         ),
     ]
     lang1 = corpora[0].dialect_id
     lang2 = corpora[1].dialect_id
     prepend_dataset.parse_multilingual(
         corpora, reverse_order=False, append_eos=False, prepend_language_id=True
     )
     append_dataset.parse_multilingual(
         corpora, reverse_order=False, append_eos=False, prepend_language_id=False
     )
     self.assertEqual(2 * self.num_sentences, len(prepend_dataset))
     self.assertEqual(2 * self.num_sentences, len(append_dataset))
     for i in range(self.num_sentences):
         self.assertListEqual([lang1] + self.trg_ref[i], prepend_dataset[i].tolist())
         self.assertListEqual(self.trg_ref[i] + [lang1], append_dataset[i].tolist())
         self.assertListEqual(
             [lang2] + self.trg_ref[i],
             prepend_dataset[i + self.num_sentences].tolist(),
         )
         self.assertListEqual(
             self.trg_ref[i] + [lang2],
             append_dataset[i + self.num_sentences].tolist(),
         )
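Both corpora above read the same file, so parse_multilingual lays the copies out back to back: sentences from corpora[0] occupy indices [0, N) and corpora[1] follows at [N, 2N), which is what the i + self.num_sentences lookups verify. A minimal sketch of that layout, reusing the hypothetical d and a placeholder corpus path:

corpora = [
    data.MultilingualCorpusConfig(
        dialect_id=10, data_file="corpus.txt", dict=d, oversampling=1
    ),
    data.MultilingualCorpusConfig(
        dialect_id=11, data_file="corpus.txt", dict=d, oversampling=1
    ),
]
dataset = data.InMemoryNumpyDataset()
dataset.parse_multilingual(
    corpora, reverse_order=False, append_eos=False, prepend_language_id=True
)
# With prepend_language_id=True each sentence starts with its dialect_id;
# with prepend_language_id=False the ID is appended at the end instead.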
Code example #3
File: preprocess.py  Project: flexpad/translate
def binarize_text_file(
    text_file: str,
    dictionary: Dictionary,
    output_path: str,
    append_eos: bool,
    reverse_order: bool,
    use_char_data: bool = False,
    char_dictionary: Optional[Dictionary] = None,
    already_numberized: bool = False,
) -> str:
    output_path = maybe_generate_temp_file_path(output_path)
    if use_char_data:
        dataset = char_data.InMemoryNumpyWordCharDataset()
        dataset.parse(
            path=text_file,
            word_dict=dictionary,
            char_dict=char_dictionary,
            reverse_order=reverse_order,
            append_eos=append_eos,
        )
    else:
        dataset = pytorch_translate_data.InMemoryNumpyDataset()
        dataset.parse(
            path=text_file,
            dictionary=dictionary,
            reverse_order=reverse_order,
            append_eos=append_eos,
            already_numberized=already_numberized,
        )
    dataset.save(output_path)
    return output_path
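A hypothetical call site for the word-level path (use_char_data=False); the file names and source_dict are placeholders:

output = binarize_text_file(
    text_file="train.src.txt",
    dictionary=source_dict,      # a previously built Dictionary
    output_path="train.src.npz",
    append_eos=False,
    reverse_order=True,          # mirrors the source-side settings in the test above
    already_numberized=False,
)
print(output)  # "train.src.npz"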
Code example #4
def binarize_text_file(
    text_file: str,
    dictionary: Dictionary,
    output_path: str,
    append_eos: bool,
    reverse_order: bool,
    use_char_data: bool = False,
    char_dictionary: Optional[Dictionary] = None,
) -> str:
    if not output_path:
        fd, output_path = tempfile.mkstemp()
        # We only need the file name.
        os.close(fd)

    # numpy silently appends this suffix if it is not present, so this ensures
    # that the correct path is returned
    if not output_path.endswith(".npz"):
        output_path += ".npz"

    if use_char_data:
        dataset = char_data.InMemoryNumpyWordCharDataset()
        dataset.parse(
            path=text_file,
            word_dict=dictionary,
            char_dict=char_dictionary,
            reverse_order=reverse_order,
            append_eos=append_eos,
        )
    else:
        dataset = pytorch_translate_data.InMemoryNumpyDataset()
        dataset.parse(text_file, dictionary, reverse_order, append_eos)
    dataset.save(output_path)

    return output_path
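The .npz suffix guard mirrors numpy's documented behavior: np.savez appends ".npz" to a file name that lacks it, so normalizing output_path up front keeps the returned path pointing at the file numpy actually wrote. A quick illustration:

import numpy as np

np.savez("/tmp/binarized_example", tokens=np.arange(5))
# numpy wrote /tmp/binarized_example.npz, not /tmp/binarized_example.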
Code example #5
File: preprocess.py  Project: qingerVT/translate
def binarize_text_file_multilingual(
    corpus_configs: List[pytorch_translate_data.MultilingualCorpusConfig],
    output_path: str,
    append_eos: bool,
    reverse_order: bool,
    prepend_language_id: bool,
    use_char_data: bool = False,
    embed_bytes: bool = False,
    already_numberized: bool = False,
) -> str:
    output_path = maybe_generate_temp_file_path(output_path)
    if use_char_data:
        dataset = char_data.InMemoryNumpyWordCharDataset()
        dataset.parse_multilingual(
            corpus_configs,
            reverse_order=reverse_order,
            append_eos=append_eos,
            embed_bytes=embed_bytes,
            prepend_language_id=prepend_language_id,
            already_numberized=already_numberized,
        )
    else:
        dataset = pytorch_translate_data.InMemoryNumpyDataset()
        dataset.parse_multilingual(
            corpus_configs,
            append_eos=append_eos,
            reverse_order=reverse_order,
            prepend_language_id=prepend_language_id,
            already_numberized=already_numberized,
        )
    dataset.save(output_path)
    return output_path
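A hypothetical call site for the multilingual variant, reusing a corpora list built as in the earlier examples; the output path and flag values are placeholders:

output = binarize_text_file_multilingual(
    corpus_configs=corpora,       # list of MultilingualCorpusConfig
    output_path="train.multi.npz",
    append_eos=True,
    reverse_order=False,
    prepend_language_id=True,     # tag each sentence with its dialect_id
)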
Code example #6
def binarize_text_file_multilingual(
    corpus_configs: List[pytorch_translate_data.MultilingualCorpusConfig],
    output_path: str,
    append_eos: bool,
    reverse_order: bool,
    prepend_language_id: bool,
) -> str:
    output_path = maybe_generate_temp_file_path(output_path)
    dataset = pytorch_translate_data.InMemoryNumpyDataset()
    dataset.parse_multilingual(
        corpus_configs,
        reverse_order=reverse_order,
        append_eos=append_eos,
        prepend_language_id=prepend_language_id,
    )
    dataset.save(output_path)
    return output_path
Code example #7
File: test_data.py  Project: planb-hakone/translate
 def test_parse_oversampling(self):
     dataset = data.InMemoryNumpyDataset()
     factors = [(1, 0), (3, 2), (4, 4)]
     for o1, o2 in factors:
         corpora = [
             data.MultilingualCorpusConfig(
                 dialect_id=None,
                 data_file=self.trg_txt,
                 dict=self.d,
                 oversampling=o1,
             ),
             data.MultilingualCorpusConfig(
                 dialect_id=None,
                 data_file=self.trg_txt,
                 dict=self.d,
                 oversampling=o2,
             ),
         ]
         dataset.parse_multilingual(corpora)
         self.assertEqual((o1 + o2) * self.num_sentences, len(dataset))
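As the assertion shows, each corpus contributes oversampling copies of its sentences, and an oversampling factor of 0 drops that corpus entirely, as in the (1, 0) case. A minimal single-corpus sketch with placeholder names:

corpora = [
    data.MultilingualCorpusConfig(
        dialect_id=None, data_file="corpus.txt", dict=d, oversampling=3
    )
]
dataset = data.InMemoryNumpyDataset()
dataset.parse_multilingual(corpora)
assert len(dataset) == 3 * num_sentences  # num_sentences: line count of corpus.txt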
Code example #8
def binarize_text_file(
    text_file: str,
    dictionary: pytorch_translate_dictionary.Dictionary,
    output_path: str,
    append_eos: bool,
    reverse_order: bool,
) -> str:
    if not output_path:
        fd, output_path = tempfile.mkstemp()
        # We only need the file name.
        os.close(fd)

    # numpy silently appends this suffix if it is not present, so this ensures
    # that the correct path is returned
    if not output_path.endswith(".npz"):
        output_path += ".npz"

    dataset = pytorch_translate_data.InMemoryNumpyDataset()
    dataset.parse(text_file, dictionary, reverse_order, append_eos)
    dataset.save(output_path)

    return output_path
Code example #9
def _generate_score(models, args, task, dataset, optimize=True):
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    if not args.quiet:
        print("| loading model(s) from {}".format(", ".join(
            args.path.split(":"))))

    # Optimize ensemble for generation
    if optimize:
        for model in models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
                need_attn=True,
            )

    translator = build_sequence_generator(args, task, models)
    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Keep track of translations, initialized with empty translations
    # and zero probability scores
    translated_sentences = [""] * len(dataset)
    translated_scores = [0.0] * len(dataset)

    collect_output_hypos = getattr(args, "output_hypos_binary_path", False)
    if collect_output_hypos:
        output_hypos_token_arrays = [None] * len(dataset)

    # Generate and compute BLEU score
    dst_dict = task.target_dictionary
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())
    itr = get_eval_itr(args, models, task, dataset)

    oracle_scorer = None
    if args.report_oracle_bleu:
        oracle_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    rescoring_model = rescoring.setup_rescoring(args)
    rescoring_scorer = None
    if rescoring_model:
        rescoring_scorer = bleu.Scorer(dst_dict.pad(), dst_dict.eos(), dst_dict.unk())

    num_sentences = 0
    translation_samples = []
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t,
            maxlen_a=args.max_len_a,
            maxlen_b=args.max_len_b,
            cuda=use_cuda,
            timer=gen_timer,
            prefix_size=1 if pytorch_translate_data.is_multilingual(args) else 0,
        )

        for trans_info in _iter_translations(
            args, task, dataset, translations, align_dict, rescoring_model
        ):
            scorer.add(trans_info.target_tokens, trans_info.hypo_tokens)
            if oracle_scorer is not None:
                oracle_scorer.add(trans_info.target_tokens, trans_info.best_hypo_tokens)
            if rescoring_scorer is not None:
                rescoring_scorer.add(
                    trans_info.target_tokens, trans_info.hypo_tokens_after_rescoring
                )

            translated_sentences[trans_info.sample_id] = trans_info.hypo_str
            translated_scores[trans_info.sample_id] = trans_info.hypo_score
            if collect_output_hypos:
                output_hypos_token_arrays[trans_info.sample_id] = (
                    trans_info.best_hypo_tokens
                )
            translation_samples.append(
                collections.OrderedDict(
                    {
                        "sample_id": trans_info.sample_id.item(),
                        "src_str": trans_info.src_str,
                        "target_str": trans_info.target_str,
                        "hypo_str": trans_info.hypo_str,
                    }
                )
            )
            wps_meter.update(trans_info.src_tokens.size(0))
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += 1

    # If applicable, save collected hypothesis tokens to binary output file
    if collect_output_hypos:
        output_dataset = pytorch_translate_data.InMemoryNumpyDataset()
        output_dataset.load_from_sequences(output_hypos_token_arrays)
        output_dataset.save(args.output_hypos_binary_path)

    # If applicable, save the translations to the output file,
    # e.g., for external evaluation
    if getattr(args, "translation_output_file", False):
        with open(args.translation_output_file, "w") as out_file:
            for hypo_str in translated_sentences:
                print(hypo_str, file=out_file)

    if getattr(args, "translation_probs_file", False):
        with open(args.translation_probs_file, "w") as out_file:
            for hypo_score in translated_scores:
                print(np.exp(hypo_score), file=out_file)

    if oracle_scorer is not None:
        print(f"| Oracle BLEU (best hypo in beam): {oracle_scorer.result_string()}")

    if rescoring_scorer is not None:
        print(
            f"| Rescoring BLEU (top hypo in beam after rescoring): "
            f"{rescoring_scorer.result_string()}"
        )

    return scorer, num_sentences, gen_timer, translation_samples
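Since the probs file above writes np.exp(hypo_score), the per-sentence scores appear to be log-probabilities, and the exponential converts each to a probability in (0, 1]. For instance:

import numpy as np

log_prob = -0.223         # a hypothetical hypothesis score
print(np.exp(log_prob))   # ~0.800, the corresponding probability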