예제 #1
0
 def test_multilingual_hybrid(self):
     """
     Tests multilingual translation training with the
     `multilingual_hybrid_transformer_rnn` architecture. Important flags:
     `--task`, `--arch`, `--lang-pairs`, and the repeatable
     `--multilingual-train-text-file` / `--multilingual-eval-text-file`
     (one occurrence per language pair, each formatted as
     "<src>-<tgt>:<source_path>,<target_path>").

     Only training is exercised here; no generation step is run.

     NOTE(review): vocabulary flags (`--source-vocabulary`,
     `--target-vocabulary`, `--vocabulary`) are not passed explicitly;
     presumably train_translation_model or the task derives vocabularies
     from the data -- confirm.
     """

     def _pair_text_files(data_dir, lang_pair, split):
         # Build "<lang_pair>:<src_path>,<tgt_path>" for one split, using the
         # "<split>.<src><tgt>.<lang>" file naming this test relies on
         # (e.g. "train.xhen.xh").
         src, tgt = lang_pair.split("-")
         prefix = f"{split}.{src}{tgt}"
         src_path = os.path.join(data_dir, f"{prefix}.{src}")
         tgt_path = os.path.join(data_dir, f"{prefix}.{tgt}")
         return f"{lang_pair}:{src_path},{tgt_path}"

     # Single source of truth so `--lang-pairs` and the per-pair text-file
     # flags cannot drift apart.
     lang_pairs = ("xh-en", "zu-en")

     with contextlib.redirect_stdout(StringIO()):
         with tempfile.TemporaryDirectory(
                 "test_multilingual_hybrid") as data_dir:
             create_dummy_multilingual_data(data_dir)
             # One train + one eval flag per language pair, in pair order.
             pair_file_args = []
             for pair in lang_pairs:
                 pair_file_args += [
                     "--multilingual-train-text-file",
                     _pair_text_files(data_dir, pair, "train"),
                     "--multilingual-eval-text-file",
                     _pair_text_files(data_dir, pair, "tune"),
                 ]
             train_translation_model(
                 data_dir,
                 [
                     "--task",
                     "pytorch_translate_multilingual_task",
                     "--arch",
                     "multilingual_hybrid_transformer_rnn",
                     "--encoder-embed-dim",
                     "8",
                     "--encoder-ffn-embed-dim",
                     "16",
                     "--encoder-attention-heads",
                     "4",
                     "--encoder-layers",
                     "3",
                     "--decoder-embed-dim",
                     "8",
                     "--decoder-attention-heads",
                     "4",
                     "--decoder-layers",
                     "2",
                     "--decoder-lstm-units",
                     "16",
                     "--decoder-out-embed-dim",
                     "8",
                     "--lang-pairs",
                     ",".join(lang_pairs),
                     *pair_file_args,
                     # set these to empty to satisfy argument validation
                     "--train-source-text-file",
                     "",
                     "--train-target-text-file",
                     "",
                     "--eval-source-text-file",
                     "",
                     "--eval-target-text-file",
                     "",
                 ],
                 # fairseq MultilingualTranslationTask expects a mandatory
                 # data directory positional argument
                 set_empty_data_positional_arg=True,
                 set_lang_args=False,
             )
예제 #2
0
 def test_multilingual(self):
     """
     Trains a multilingual LSTM model on three directions (xh->en, zu->en,
     en->xh) and then runs generation on the tune split of each direction.

     Each direction is a (file tag, source lang, target lang) triple; data
     files inside data_dir are named "<split>.<tag>.<lang>".
     """
     directions = [
         ("xhen", "xh", "en"),
         ("zuen", "zu", "en"),
         ("xhen", "en", "xh"),
     ]
     # Kept explicit rather than derived from `directions` so the emitted
     # argument order stays exactly as written here.
     encoder_langs = ["xh", "zu", "en"]
     decoder_langs = ["xh", "en"]

     model_args = [
         "--task", "pytorch_translate_multilingual",
         "--arch", "rnn",
         "--cell-type", "lstm",
         "--sequence-lstm",
         "--reverse-source",
         "--encoder-bidirectional",
         "--encoder-layers", "2",
         "--encoder-embed-dim", "8",
         "--encoder-hidden-dim", "16",
         "--decoder-layers", "2",
         "--decoder-embed-dim", "8",
         "--decoder-hidden-dim", "16",
         "--decoder-out-embed-dim", "8",
         "--attention-type", "dot",
     ]

     with contextlib.redirect_stdout(StringIO()):
         with tempfile.TemporaryDirectory("test_multilingual") as data_dir:
             create_dummy_multilingual_data(data_dir)

             def in_dir(file_name):
                 return os.path.join(data_dir, file_name)

             lang_args = []
             for lang in encoder_langs:
                 lang_args += ["--multiling-encoder-lang", lang]
             for lang in decoder_langs:
                 lang_args += ["--multiling-decoder-lang", lang]

             corpus_args = []
             for tag, src, tgt in directions:
                 corpus_args += [
                     "--multiling-source-lang", src,
                     "--multiling-target-lang", tgt,
                     "--multiling-train-source-text-file",
                     in_dir(f"train.{tag}.{src}"),
                     "--multiling-train-target-text-file",
                     in_dir(f"train.{tag}.{tgt}"),
                     "--multiling-eval-source-text-file",
                     in_dir(f"tune.{tag}.{src}"),
                     "--multiling-eval-target-text-file",
                     in_dir(f"tune.{tag}.{tgt}"),
                 ]

             # set these to empty to satisfy argument validation
             blank_single_pair_args = [
                 "--train-source-text-file", "",
                 "--train-target-text-file", "",
                 "--eval-source-text-file", "",
                 "--eval-target-text-file", "",
             ]

             train_translation_model(
                 data_dir,
                 model_args + lang_args + corpus_args + blank_single_pair_args,
             )

             for tag, src, tgt in directions:
                 generation_args = [
                     "--task", "pytorch_translate_multilingual",
                     "--multiling-source-lang", src,
                     "--multiling-target-lang", tgt,
                     "--source-vocab-file",
                     in_dir(f"dictionary-src-{src}.txt"),
                     "--target-vocab-file",
                     in_dir(f"dictionary-trg-{tgt}.txt"),
                     "--source-text-file", in_dir(f"tune.{tag}.{src}"),
                     "--target-text-file", in_dir(f"tune.{tag}.{tgt}"),
                 ]
                 generate_main(data_dir, generation_args)