def test_multilingual_hybrid(self):
    """
    Train the `multilingual_hybrid_transformer_rnn` architecture on the
    `pytorch_translate_multilingual_task` task.

    Two language pairs, xh-en and zu-en, are declared via `--lang-pairs`.
    Each pair's train and tune files are supplied through the repeatable
    `--multilingual-train-text-file` / `--multilingual-eval-text-file` flags,
    formatted as "<pair>:<source path>,<target path>". Other important flags:
    `--task`, `--arch`.
    """
    # (file-name infix, --lang-pairs key, source lang, target lang)
    lang_pairs = (("xhen", "xh-en", "xh", "en"), ("zuen", "zu-en", "zu", "en"))
    with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
        "test_multilingual_hybrid"
    ) as data_dir:
        create_dummy_multilingual_data(data_dir)

        # For each pair, emit the train spec first, then the eval (tune) spec.
        text_file_args = []
        for infix, key, src, tgt in lang_pairs:
            for flag, split in (
                ("--multilingual-train-text-file", "train"),
                ("--multilingual-eval-text-file", "tune"),
            ):
                src_path = os.path.join(data_dir, f"{split}.{infix}.{src}")
                tgt_path = os.path.join(data_dir, f"{split}.{infix}.{tgt}")
                text_file_args.extend([flag, f"{key}:{src_path},{tgt_path}"])

        # Argument validation still requires the single-pair text-file
        # flags, so pass each of them as an empty string.
        empty_single_pair_args = []
        for flag in (
            "--train-source-text-file",
            "--train-target-text-file",
            "--eval-source-text-file",
            "--eval-target-text-file",
        ):
            empty_single_pair_args.extend([flag, ""])

        train_translation_model(
            data_dir,
            [
                "--task", "pytorch_translate_multilingual_task",
                "--arch", "multilingual_hybrid_transformer_rnn",
                "--encoder-embed-dim", "8",
                "--encoder-ffn-embed-dim", "16",
                "--encoder-attention-heads", "4",
                "--encoder-layers", "3",
                "--decoder-embed-dim", "8",
                "--decoder-attention-heads", "4",
                "--decoder-layers", "2",
                "--decoder-lstm-units", "16",
                "--decoder-out-embed-dim", "8",
                "--lang-pairs", ",".join(key for _, key, _, _ in lang_pairs),
                *text_file_args,
                *empty_single_pair_args,
            ],
            # fairseq's MultilingualTranslationTask expects a mandatory
            # positional data-directory argument.
            set_empty_data_positional_arg=True,
            set_lang_args=False,
        )
def test_multilingual(self):
    """
    Train a bidirectional-encoder LSTM `rnn` model on the
    `pytorch_translate_multilingual` task with three translation directions
    (xh->en, zu->en, en->xh). Then run generation on each direction's tune set.

    Each direction is described by a repeated group of `--multiling-*` flags:
    source/target language followed by the train and eval text files.
    """
    # (file-name infix, source lang, target lang); shared by training and
    # generation so the two stay in sync.
    directions = (("xhen", "xh", "en"), ("zuen", "zu", "en"), ("xhen", "en", "xh"))
    with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
        "test_multilingual"
    ) as data_dir:
        create_dummy_multilingual_data(data_dir)

        def data_file(split, infix, lang):
            # e.g. data_file("train", "xhen", "xh") -> <data_dir>/train.xhen.xh
            return os.path.join(data_dir, f"{split}.{infix}.{lang}")

        lang_args = []
        for lang in ("xh", "zu", "en"):
            lang_args.extend(["--multiling-encoder-lang", lang])
        for lang in ("xh", "en"):
            lang_args.extend(["--multiling-decoder-lang", lang])

        direction_args = []
        for infix, src, tgt in directions:
            direction_args.extend(
                [
                    "--multiling-source-lang", src,
                    "--multiling-target-lang", tgt,
                    "--multiling-train-source-text-file",
                    data_file("train", infix, src),
                    "--multiling-train-target-text-file",
                    data_file("train", infix, tgt),
                    "--multiling-eval-source-text-file",
                    data_file("tune", infix, src),
                    "--multiling-eval-target-text-file",
                    data_file("tune", infix, tgt),
                ]
            )

        train_translation_model(
            data_dir,
            [
                "--task", "pytorch_translate_multilingual",
                "--arch", "rnn",
                "--cell-type", "lstm",
                "--sequence-lstm",
                "--reverse-source",
                "--encoder-bidirectional",
                "--encoder-layers", "2",
                "--encoder-embed-dim", "8",
                "--encoder-hidden-dim", "16",
                "--decoder-layers", "2",
                "--decoder-embed-dim", "8",
                "--decoder-hidden-dim", "16",
                "--decoder-out-embed-dim", "8",
                "--attention-type", "dot",
                *lang_args,
                *direction_args,
                # Argument validation still requires the single-pair
                # text-file flags, so pass them as empty strings.
                "--train-source-text-file", "",
                "--train-target-text-file", "",
                "--eval-source-text-file", "",
                "--eval-target-text-file", "",
            ],
        )

        for infix, src, tgt in directions:
            generate_main(
                data_dir,
                [
                    "--task", "pytorch_translate_multilingual",
                    "--multiling-source-lang", src,
                    "--multiling-target-lang", tgt,
                    "--source-vocab-file",
                    os.path.join(data_dir, f"dictionary-src-{src}.txt"),
                    "--target-vocab-file",
                    os.path.join(data_dir, f"dictionary-trg-{tgt}.txt"),
                    "--source-text-file", data_file("tune", infix, src),
                    "--target-text-file", data_file("tune", infix, tgt),
                ],
            )