コード例 #1
0
 def test_quantization(self):
     """Smoke-test quantization of a transformer LM on dummy data."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         "test_quantization"
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_lm_data(work_dir)
         # tests both scalar and iterative PQ quantization
         _quantize_language_model(work_dir, "transformer_lm")
コード例 #2
0
 def test_fp16(self):
     """Train a small fconv translation model with --fp16 and run generation."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         "test_fp16"
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(work_dir, "fconv_iwslt_de_en", ["--fp16"])
         generate_main(work_dir)
コード例 #3
0
 def test_optimizers(self):
     """Train the same tiny LSTM model from scratch once per supported optimizer."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_optimizers'
     ) as work_dir:
         # Use just a bit of data and tiny model to keep this test runtime reasonable
         create_dummy_data(work_dir, num_examples=10, maxlen=5)
         preprocess_translation_data(work_dir)
         base_flags = [
             '--required-batch-size-multiple', '1',
             '--encoder-layers', '1',
             '--encoder-hidden-size', '32',
             '--decoder-layers', '1',
             '--optimizer',
         ]
         checkpoint = os.path.join(work_dir, 'checkpoint_last.pt')
         for opt_name in ['adafactor', 'adam', 'nag', 'adagrad', 'sgd', 'adadelta']:
             # remove the previous run's checkpoint so each optimizer starts fresh
             if os.path.exists(checkpoint):
                 os.remove(checkpoint)
             train_translation_model(work_dir, 'lstm', base_flags + [opt_name])
             generate_main(work_dir)
コード例 #4
0
 def test_levenshtein_transformer(self):
     """Train a Levenshtein transformer on the translation_lev task and decode iteratively."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         "test_levenshtein_transformer"
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir, ["--joined-dictionary"])
         train_flags = [
             "--apply-bert-init",
             "--early-exit", "6,6,6",
             "--criterion", "nat_loss",
         ]
         train_translation_model(
             work_dir, "levenshtein_transformer", train_flags, task="translation_lev"
         )
         decode_flags = [
             "--task", "translation_lev",
             "--iter-decode-max-iter", "9",
             "--iter-decode-eos-penalty", "0",
             "--print-step",
         ]
         generate_main(work_dir, decode_flags)
コード例 #5
0
 def test_iterative_nonautoregressive_transformer(self):
     """Train an iterative NAT model (stochastic approx + DAE) and decode iteratively."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_iterative_nonautoregressive_transformer'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir, ['--joined-dictionary'])
         train_translation_model(
             work_dir,
             'iterative_nonautoregressive_transformer',
             [
                 '--apply-bert-init', '--src-embedding-copy',
                 '--criterion', 'nat_loss', '--noise', 'full_mask',
                 '--stochastic-approx', '--dae-ratio', '0.5',
                 '--train-step', '3',
             ],
             task='translation_lev',
         )
         generate_main(work_dir, [
             '--task', 'translation_lev',
             '--iter-decode-max-iter', '9',
             '--iter-decode-eos-penalty', '0',
             '--print-step',
         ])
コード例 #6
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_generation(self):
     """Exercise decoding options: sampling variants, diverse beam, prefix, dropout."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_sampling'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(work_dir, 'fconv_iwslt_de_en')
         # three sampling variants share the same beam/nbest settings
         for variant in (
             ['--temperature', '2'],
             ['--sampling-topk', '3'],
             ['--sampling-topp', '0.2'],
         ):
             generate_main(
                 work_dir,
                 ['--sampling'] + variant + ['--beam', '2', '--nbest', '2'],
             )
         generate_main(work_dir, [
             '--diversity-rate', '0.5',
             '--beam', '6',
         ])
         # this flag combination is expected to be rejected
         with self.assertRaises(ValueError):
             generate_main(work_dir, [
                 '--diverse-beam-groups', '4',
                 '--match-source-len',
             ])
         generate_main(work_dir, ['--prefix-size', '2'])
         generate_main(work_dir, ['--retain-dropout'])
コード例 #7
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_update_freq(self):
     """Train with gradient accumulation (--update-freq 3) and run generation."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_update_freq'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(
             work_dir, 'fconv_iwslt_de_en', ['--update-freq', '3']
         )
         generate_main(work_dir)
コード例 #8
0
 def test_mixture_of_experts(self):
     """Train a hard-mixture-of-experts translation model and decode with one expert."""
     # flags shared between training and generation
     moe_flags = [
         '--task', 'translation_moe',
         '--user-dir', 'examples/translation_moe/src',
         '--method', 'hMoElp',
         '--mean-pool-gating-network',
         '--num-experts', '3',
     ]
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_moe'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(work_dir, 'transformer_iwslt_de_en', moe_flags + [
             '--encoder-layers', '2',
             '--decoder-layers', '2',
             '--encoder-embed-dim', '8',
             '--decoder-embed-dim', '8',
         ])
         generate_main(work_dir, moe_flags + ['--gen-expert', '0'])
コード例 #9
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_raw(self):
     """Round-trip preprocess/train/generate using the raw (text) dataset impl."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_fconv_raw'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir, ['--dataset-impl', 'raw'])
         train_translation_model(
             work_dir, 'fconv_iwslt_de_en', ['--dataset-impl', 'raw']
         )
         generate_main(work_dir, ['--dataset-impl', 'raw'])
コード例 #10
0
 def test_transformer_pointer_generator(self):
     """Train a pointer-generator summarization model (with validation) and generate."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_transformer_pointer_generator'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_summarization_data(work_dir)
         train_flags = [
             '--user-dir', 'examples/pointer_generator/src',
             '--encoder-layers', '2',
             '--decoder-layers', '2',
             '--encoder-embed-dim', '8',
             '--decoder-embed-dim', '8',
             '--alignment-layer', '-1',
             '--alignment-heads', '1',
             '--source-position-markers', '0',
         ]
         train_translation_model(
             work_dir,
             'transformer_pointer_generator',
             train_flags,
             run_validation=True,
         )
         generate_main(work_dir)
コード例 #11
0
    def test_fconv_self_att_wp(self):
        """Train an fconv self-attention WP model, then reuse it as a pretrained fusion model."""
        with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
            'test_fconv_self_att_wp'
        ) as work_dir:
            create_dummy_data(work_dir)
            preprocess_translation_data(work_dir)
            config = [
                '--encoder-layers', '[(128, 3)] * 2',
                '--decoder-layers', '[(128, 3)] * 2',
                '--decoder-attention', 'True',
                '--encoder-attention', 'False',
                '--gated-attention', 'True',
                '--self-attention', 'True',
                '--project-input', 'True',
                '--encoder-embed-dim', '8',
                '--decoder-embed-dim', '8',
                '--decoder-out-embed-dim', '8',
                '--multihead-self-attention-nheads', '2',
            ]
            train_translation_model(work_dir, 'fconv_self_att_wp', config)
            generate_main(work_dir)

            # fusion model: rename the first run's checkpoint and train on top of it
            pretrained = os.path.join(work_dir, 'pretrained.pt')
            os.rename(os.path.join(work_dir, 'checkpoint_last.pt'), pretrained)
            config += [
                '--pretrained', 'True',
                '--pretrained-checkpoint', pretrained,
                '--save-dir', os.path.join(work_dir, 'fusion_model'),
            ]
            train_translation_model(work_dir, 'fconv_self_att_wp', config)
コード例 #12
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_multilingual_transformer(self):
     """Train/generate a multilingual transformer for every langtok flag combination."""
     # test with all combinations of encoder/decoder lang tokens
     encoder_langtok_flags = [[], ['--encoder-langtok', 'src'], ['--encoder-langtok', 'tgt']]
     decoder_langtok_flags = [[], ['--decoder-langtok']]
     with contextlib.redirect_stdout(StringIO()):
         for i, enc_ltok_flag in enumerate(encoder_langtok_flags):
             for j, dec_ltok_flag in enumerate(decoder_langtok_flags):
                 langtok_flags = enc_ltok_flag + dec_ltok_flag
                 with tempfile.TemporaryDirectory(
                     f'test_multilingual_transformer_{i}_{j}'
                 ) as work_dir:
                     create_dummy_data(work_dir)
                     preprocess_translation_data(work_dir)
                     train_translation_model(
                         work_dir,
                         arch='multilingual_transformer',
                         task='multilingual_translation',
                         extra_flags=[
                             '--encoder-layers', '2',
                             '--decoder-layers', '2',
                             '--encoder-embed-dim', '8',
                             '--decoder-embed-dim', '8',
                         ] + langtok_flags,
                         lang_flags=['--lang-pairs', 'in-out,out-in'],
                         run_validation=True,
                         extra_valid_flags=langtok_flags,
                     )
                     generate_main(
                         work_dir,
                         extra_flags=[
                             '--task', 'multilingual_translation',
                             '--lang-pairs', 'in-out,out-in',
                             '--source-lang', 'in',
                             '--target-lang', 'out',
                         ] + langtok_flags,
                     )
コード例 #13
0
 def test_max_positions(self):
     """Over-long examples must fail unless --skip-invalid-size-inputs-valid-test is set."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_max_positions'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         # without the skip flag, training should raise on over-length targets
         with self.assertRaises(Exception) as context:
             train_translation_model(
                 work_dir,
                 'fconv_iwslt_de_en',
                 ['--max-target-positions', '5'],
             )
         self.assertTrue(
             'skip this example with --skip-invalid-size-inputs-valid-test'
             in str(context.exception))
         # with the skip flag, the same configuration trains successfully
         train_translation_model(
             work_dir,
             'fconv_iwslt_de_en',
             [
                 '--max-target-positions', '5',
                 '--skip-invalid-size-inputs-valid-test',
             ],
         )
         with self.assertRaises(Exception) as context:
             generate_main(work_dir)
         generate_main(work_dir, ['--skip-invalid-size-inputs-valid-test'])
コード例 #14
0
 def test_roberta_masked_lm(self):
     """Smoke-test masked-LM training with a 2-layer roberta_base config."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         "test_roberta_mlm"
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_lm_data(work_dir)
         train_masked_lm(
             work_dir, "roberta_base", extra_flags=["--encoder-layers", "2"]
         )
コード例 #15
0
 def test_training_lm_plasma(self):
     """Train a transformer LM using a plasma view at self.path (with validation)."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         "test_transformer_lm"
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_lm_data(work_dir)
         train_language_model(
             work_dir,
             "transformer_lm",
             ["--use-plasma-view", "--plasma-path", self.path],
             run_validation=True,
         )
コード例 #16
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_dynamicconv(self):
     """Train a lightconv model configured with dynamic convolutions and generate."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_dynamicconv'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(
             work_dir,
             'lightconv_iwslt_de_en',
             [
                 '--encoder-conv-type', 'dynamic',
                 '--decoder-conv-type', 'dynamic',
                 '--encoder-embed-dim', '8',
                 '--decoder-embed-dim', '8',
             ],
         )
         generate_main(work_dir)
コード例 #17
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_transformer(self):
     """Train a tiny transformer (with validation) and run generation."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_transformer'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(
             work_dir,
             'transformer_iwslt_de_en',
             [
                 '--encoder-layers', '2',
                 '--decoder-layers', '2',
                 '--encoder-embed-dim', '8',
                 '--decoder-embed-dim', '8',
             ],
             run_validation=True,
         )
         generate_main(work_dir)
コード例 #18
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_eval_bleu(self):
     """Train while computing BLEU on the validation set during training."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_eval_bleu'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(work_dir, 'fconv_iwslt_de_en', [
             '--eval-bleu',
             '--eval-bleu-print-samples',
             '--eval-bleu-remove-bpe',
             '--eval-bleu-detok', 'space',
             '--eval-bleu-args', '{"beam": 4, "min_len": 10}',
         ])
コード例 #19
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only):
     """Pretrain a legacy masked LM, then fine-tune a translation model from it.

     Args:
         learned_pos_emb: if truthy, use learned positional embeddings in both
             the pretrained LM and the downstream translation model.
         encoder_only: if truthy, initialize only the encoder from the
             pretrained checkpoint.
     """
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         "test_mlm"
     ) as data_dir:
         create_dummy_data(data_dir)
         preprocess_lm_data(data_dir)
         train_legacy_masked_language_model(
             data_dir,
             arch="masked_lm",
             extra_args=('--encoder-learned-pos',) if learned_pos_emb else ()
         )
         with tempfile.TemporaryDirectory("test_mlm_translation") as translation_dir:
             create_dummy_data(translation_dir)
             preprocess_translation_data(
                 translation_dir, extra_flags=["--joined-dictionary"]
             )
             flags = [
                 "--decoder-layers", "1",
                 "--decoder-embed-dim", "32",
                 "--decoder-attention-heads", "1",
                 "--decoder-ffn-embed-dim", "32",
                 "--encoder-layers", "1",
                 "--encoder-embed-dim", "32",
                 "--encoder-attention-heads", "1",
                 "--encoder-ffn-embed-dim", "32",
                 "--pretrained-xlm-checkpoint", "{}/checkpoint_last.pt".format(data_dir),
                 "--activation-fn", "gelu",
                 "--max-source-positions", "500",
                 "--max-target-positions", "500",
             ]
             if learned_pos_emb:
                 flags += ["--encoder-learned-pos", "--decoder-learned-pos"]
             if encoder_only:
                 flags += ['--init-encoder-only']
             # Train transformer with data_dir/checkpoint_last.pt
             train_translation_model(
                 translation_dir,
                 arch="transformer_from_pretrained_xlm",
                 extra_flags=flags,
                 task="translation_from_pretrained_xlm",
             )
コード例 #20
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_lstm(self):
     """Train a small Wiseman LSTM translation model and run generation."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_lstm'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(
             work_dir,
             'lstm_wiseman_iwslt_de_en',
             [
                 '--encoder-layers', '2',
                 '--decoder-layers', '2',
                 '--encoder-embed-dim', '8',
                 '--decoder-embed-dim', '8',
                 '--decoder-out-embed-dim', '8',
             ],
         )
         generate_main(work_dir)
コード例 #21
0
 def test_transformer_relative_positional_embeddings(self):
     """Train a transformer with relative positional embeddings (with validation)."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_transformer'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(
             work_dir,
             'transformer_rel_pos_embeddings',
             [
                 '--encoder-layers', '2',
                 '--decoder-layers', '2',
                 '--encoder-embed-dim', '8',
                 '--decoder-embed-dim', '8',
                 '--max-relative-pos', '5',
             ],
             run_validation=True,
         )
         generate_main(work_dir)
コード例 #22
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_lstm_lm_residuals(self):
     """Train an LSTM LM with residual connections, then evaluate and sample from it."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_lstm_lm_residuals'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_lm_data(work_dir)
         train_language_model(
             work_dir,
             'lstm_lm',
             ['--add-bos-token', '--residuals'],
             run_validation=True,
         )
         eval_lm_main(work_dir)
         generate_main(work_dir, [
             '--task', 'language_modeling',
             '--sample-break-mode', 'eos',
             '--tokens-per-sample', '500',
         ])
コード例 #23
0
 def test_fp16_multigpu(self):
     """Run fp16 training on up to two GPUs and check the log file is written."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         "test_fp16"
     ) as work_dir:
         log_path = os.path.join(work_dir, "train.log")
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(
             work_dir,
             "fconv_iwslt_de_en",
             ["--fp16", "--log-file", log_path],
             world_size=min(torch.cuda.device_count(), 2),
         )
         generate_main(work_dir)
         assert os.path.exists(log_path)
コード例 #24
0
    def test_resume_training(self):
        """Train, resume from a mid-run checkpoint, and compare logged metrics."""
        base_flags = [
            "--fp16",
            "--log-format", "json",
            "--max-update", "10",
            "--save-interval-updates", "2",
            "--log-interval", "1",
            "--log-file",
        ]
        world_size = min(torch.cuda.device_count(), 2)
        arch = "fconv_iwslt_de_en"
        with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
            "test_fp16"
        ) as data_dir:
            log = os.path.join(data_dir, "train.log")
            create_dummy_data(data_dir)
            preprocess_translation_data(data_dir)
            train_translation_model(
                data_dir,
                arch,
                base_flags + [log],
                world_size=world_size,
            )
            log2 = os.path.join(data_dir, "resume.log")
            restore_file = os.path.join(data_dir, "checkpoint_1_2.pt")
            assert os.path.exists(
                restore_file
            ), f"{restore_file} not written. Choices: {os.listdir(data_dir)}"
            # second run resumes from the mid-run checkpoint into a separate log
            train_translation_model(
                data_dir,
                arch,
                base_flags + [log2, "--restore-file", restore_file],
                world_size=world_size,
            )

            l1 = self.parse_logs(log)
            l2 = self.parse_logs(log2)
            # the resumed run's first logged step should be update 3
            assert int(l2[0]["num_updates"]) == 3, f"{l1}\n\n {l2}"
            # final metrics of the resumed run must match training from scratch
            for k in [
                "train_loss",
                "train_num_updates",
                "train_ppl",
                "train_gnorm",
            ]:
                from_scratch, resumed = l1[-1][k], l2[-1][k]
                assert (from_scratch == resumed
                        ), f"difference at {k} {from_scratch} != {resumed}"
コード例 #25
0
ファイル: test_binaries_gpu.py プロジェクト: scheiblr/fairseq
 def test_fsdp_checkpoint_generate(self):
     """Train with the fully-sharded DDP backend on up to two GPUs, then generate."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         "test_fsdp_sharded"
     ) as work_dir:
         log_path = os.path.join(work_dir, "train.log")
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(
             work_dir,
             "fconv_iwslt_de_en",
             ["--log-file", log_path, "--ddp-backend", "fully_sharded"],
             world_size=min(torch.cuda.device_count(), 2),
         )
         generate_main(work_dir)
         assert os.path.exists(log_path)
コード例 #26
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_lstm_bidirectional(self):
     """Train an LSTM with a bidirectional encoder and run generation."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_lstm_bidirectional'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir)
         train_translation_model(work_dir, 'lstm', [
             '--encoder-layers', '2',
             '--encoder-bidirectional',
             '--encoder-hidden-size', '16',
             '--encoder-embed-dim', '8',
             '--decoder-embed-dim', '8',
             '--decoder-out-embed-dim', '8',
             '--decoder-layers', '2',
         ])
         generate_main(work_dir)
コード例 #27
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_insertion_transformer(self):
     """Train an insertion transformer on translation_lev and decode iteratively."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_insertion_transformer'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_translation_data(work_dir, ['--joined-dictionary'])
         train_translation_model(
             work_dir,
             'insertion_transformer',
             [
                 '--apply-bert-init', '--criterion', 'nat_loss', '--noise',
                 'random_mask'
             ],
             task='translation_lev',
         )
         generate_main(work_dir, [
             '--task', 'translation_lev',
             '--iter-decode-max-iter', '9',
             '--iter-decode-eos-penalty', '0',
             '--print-step',
         ])
コード例 #28
0
ファイル: test_binaries.py プロジェクト: yf1291/nlp4
 def test_fconv_lm(self):
     """Train an fconv language model, evaluate it, then sample from it."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         'test_fconv_lm'
     ) as work_dir:
         create_dummy_data(work_dir)
         preprocess_lm_data(work_dir)
         train_language_model(work_dir, 'fconv_lm', [
             '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
             '--decoder-embed-dim', '280',
             '--optimizer', 'nag',
             '--lr', '0.1',
         ])
         eval_lm_main(work_dir)
         generate_main(work_dir, [
             '--task', 'language_modeling',
             '--sample-break-mode', 'eos',
             '--tokens-per-sample', '500',
         ])
コード例 #29
0
ファイル: test_binaries.py プロジェクト: zye1996/fairseq
 def test_multilingual_translation_latent_depth(self):
     """Train/generate latent-depth multilingual models: encoder-only, decoder-only, both."""
     # test with latent depth in encoder, decoder, or both
     encoder_latent_layer = [[], ['--encoder-latent-layer']]
     decoder_latent_layer = [[], ['--decoder-latent-layer']]
     with contextlib.redirect_stdout(StringIO()):
         for i, enc_ll_flag in enumerate(encoder_latent_layer):
             for j, dec_ll_flag in enumerate(decoder_latent_layer):
                 # skip the combination with latent depth nowhere
                 if i == 0 and j == 0:
                     continue
                 latent_flags = enc_ll_flag + dec_ll_flag
                 with tempfile.TemporaryDirectory(
                     f'test_multilingual_translation_latent_depth_{i}_{j}'
                 ) as data_dir:
                     create_dummy_data(data_dir)
                     preprocess_translation_data(
                         data_dir,
                         extra_flags=['--joined-dictionary']
                     )
                     train_translation_model(
                         data_dir,
                         arch='latent_multilingual_transformer',
                         task='multilingual_translation_latent_depth',
                         extra_flags=[
                             '--user-dir', 'examples/latent_depth/src',
                             '--encoder-layers', '2',
                             '--decoder-layers', '2',
                             '--encoder-embed-dim', '8',
                             '--decoder-embed-dim', '8',
                             '--share-encoders',
                             '--share-decoders',
                             '--sparsity-weight', '0.1',
                         ] + latent_flags,
                         lang_flags=['--lang-pairs', 'in-out,out-in'],
                         run_validation=True,
                         extra_valid_flags=[
                             '--user-dir', 'examples/latent_depth/src',
                         ] + latent_flags,
                     )
                     generate_main(
                         data_dir,
                         extra_flags=[
                             '--user-dir', 'examples/latent_depth/src',
                             '--task', 'multilingual_translation_latent_depth',
                             '--lang-pairs', 'in-out,out-in',
                             '--source-lang', 'in',
                             '--target-lang', 'out',
                         ] + latent_flags,
                     )
コード例 #30
0
 def test_flat_grads(self):
     """Adafactor with flat fp16 grads must fail; --fp16-no-flatten-grads makes it pass."""
     with contextlib.redirect_stdout(StringIO()), tempfile.TemporaryDirectory(
         "test_flat_grads"
     ) as work_dir:
         # Use just a bit of data and tiny model to keep this test runtime reasonable
         create_dummy_data(work_dir, num_examples=10, maxlen=5)
         preprocess_translation_data(work_dir)
         base_flags = [
             "--required-batch-size-multiple", "1",
             "--encoder-layers", "1",
             "--encoder-hidden-size", "32",
             "--decoder-layers", "1",
             "--optimizer", "adafactor",
             "--fp16",
         ]
         with self.assertRaises(RuntimeError):
             # adafactor isn't compatible with flat grads, which
             # are used by default with --fp16
             train_translation_model(work_dir, "lstm", base_flags)
         # but it should pass once we set --fp16-no-flatten-grads
         train_translation_model(
             work_dir, "lstm", base_flags + ["--fp16-no-flatten-grads"]
         )