예제 #1
0
def test_finetune_extra_model_args():
    """Check that extra model params given in the args reach the model config.

    Two runs of ``main`` are made on top of a shared set of cheap arguments:
    one with ``BART_TINY``, where every overridden dropout/layerdrop value is
    expected to show up on ``model.config``, and one with ``T5_TINY``, where
    passing ``encoder_layerdrop`` is expected to raise with a specific message.
    """
    shared_args: dict = CHEAP_ARGS.copy()
    data_dir = make_test_data_dir()
    shared_args.update(
        data_dir=data_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task="summarization",
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # Case 1: a model whose config includes the extra model args.
    overridden = (
        "encoder_layerdrop",
        "decoder_layerdrop",
        "dropout",
        "attention_dropout",
    )
    bart_args = {
        **shared_args,
        "model_name_or_path": BART_TINY,
        "output_dir": tempfile.mkdtemp(prefix="output_1_"),
        **dict.fromkeys(overridden, 0.5),
    }
    trained = main(argparse.Namespace(**bart_args))
    for p in overridden:
        actual = getattr(trained.config, p)
        assert actual == 0.5, f"failed to override the model config for param {p}"

    # Case 2: a model whose config doesn't include the extra model args.
    unsupported_param = "encoder_layerdrop"
    t5_args = {
        **shared_args,
        "model_name_or_path": T5_TINY,
        "output_dir": tempfile.mkdtemp(prefix="output_2_"),
        unsupported_param: 0.5,
    }
    t5_namespace = argparse.Namespace(**t5_args)
    with pytest.raises(Exception) as excinfo:
        main(t5_namespace)
    expected_message = f"model config doesn't have a `{unsupported_param}` attribute"
    assert str(excinfo.value) == expected_message
예제 #2
0
def test_finetune_lr_schedulers():
    """Exercise the ``--lr_scheduler`` command-line option end to end.

    Three checks are performed against a parser built the same way this test
    assumes ``finetune.py`` builds it:

    1. ``--help`` exits and its stdout contains
       ``lightning_base.arg_to_scheduler_metavar``.
    2. An unknown scheduler name exits and stderr reports an invalid choice.
    3. A supported scheduler (``"cosine"``) is accepted by ``main`` and is
       stored on ``model.hparams.lr_scheduler``.
    """
    args_d: dict = CHEAP_ARGS.copy()

    task = "summarization"
    tmp_dir = make_test_data_dir()

    model = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")

    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model,
        output_dir=output_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # emulate finetune.py
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    # --help test
    # parse_args expects a list of argument strings, so pass one explicitly
    # instead of relying on a dict/set being iterated into a list.
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStdout() as cs:
            parser.parse_args(["--help"])
        # Only reached when parse_args returned instead of exiting.
        raise AssertionError("--help is expected to sys.exit")
    assert excinfo.type is SystemExit
    expected = lightning_base.arg_to_scheduler_metavar
    assert expected in cs.out, "--help is expected to list the supported schedulers"

    # --lr_scheduler=non_existing_scheduler test
    unsupported_param = "non_existing_scheduler"
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStderr() as cs:
            parser.parse_args([f"--lr_scheduler={unsupported_param}"])
        # Only reached when parse_args returned instead of exiting.
        raise AssertionError("invalid argument is expected to sys.exit")
    assert excinfo.type is SystemExit
    expected = f"invalid choice: '{unsupported_param}'"
    assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

    # --lr_scheduler=existing_scheduler test
    supported_param = "cosine"
    args_d1 = args_d.copy()
    args_d1["lr_scheduler"] = supported_param
    args = argparse.Namespace(**args_d1)
    model = main(args)
    assert (
        model.hparams.lr_scheduler == supported_param
    ), f"lr_scheduler={supported_param} shouldn't fail"
예제 #3
0
def test_finetune(model):
    """Run a cheap fine-tuning pass for ``model`` and check frozen embeddings.

    The task is "translation" for the MBART/Marian/FSMT tiny models and
    "summarization" otherwise. After ``main`` returns, the test asserts that
    the input embeddings (and the model-specific positional/shared
    embeddings) do not require grad, and that a ``text_batch.json`` file with
    at least four entries was written to the module's output directory.

    Args:
        model: model name or path handed to ``main`` as ``model_name_or_path``.
            NOTE(review): presumably supplied by a parametrize decorator that
            is outside this view -- confirm.
    """
    hparams: dict = CHEAP_ARGS.copy()
    if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY]:
        task, label_smoothing = "translation", 0.1
    else:
        task, label_smoothing = "summarization", 0
    hparams["label_smoothing"] = label_smoothing

    data_dir = make_test_data_dir()
    run_dir = tempfile.mkdtemp(prefix="output_")
    hparams.update(
        data_dir=data_dir,
        model_name_or_path=model,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        output_dir=run_dir,
        do_predict=True,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )
    assert "n_train" in hparams
    module = main(argparse.Namespace(**hparams))

    word_embeds = module.model.get_input_embeddings()
    assert not word_embeds.weight.requires_grad

    if model == T5_TINY:
        head = module.model.lm_head
        assert not head.weight.requires_grad
        assert (head.weight == word_embeds.weight).all().item()
    elif model == FSMT_TINY:
        inner = module.model.model
        assert not inner.decoder.embed_positions.weight.requires_grad
        assert not inner.decoder.embed_tokens.weight.requires_grad
        # FSMT keeps separate encoder/decoder token embeddings.
        assert inner.decoder.embed_tokens != inner.encoder.embed_tokens
    else:
        inner = module.model.model
        assert not inner.decoder.embed_positions.weight.requires_grad
        assert not inner.shared.weight.requires_grad
        # Encoder, decoder and the shared table are expected to be the same.
        assert inner.decoder.embed_tokens == inner.encoder.embed_tokens
        assert inner.decoder.embed_tokens == inner.shared

    batch_path = module.output_dir / "text_batch.json"
    example_batch = load_json(batch_path)
    assert isinstance(example_batch, dict)
    assert len(example_batch) >= 4
    def test_train_mbart_cc25_enro_script(self):
        """Train via the arguments of ``train_mbart_cc25_enro.sh`` and check metrics.

        The shell script is read, everything after ``finetune.py`` is turned
        into an argument list (with placeholders substituted), extra test
        arguments are appended, and ``main`` is run with ``sys.argv`` patched.
        The saved metrics are then checked for validation count, generation
        time and BLEU thresholds, and the saved lightning checkpoint is loaded
        to verify an expected float32 weight is present.
        """
        env_vars_to_replace = {
            "$MAX_LEN": 64,
            "$BS": 64,
            "$GAS": 1,
            "$ENRO_DIR": self.data_dir,
            "facebook/mbart-large-cc25": MARIAN_MODEL,
            # "val_check_interval=0.25": "val_check_interval=1.0",
            "--learning_rate=3e-5": "--learning_rate 3e-4",
            "--num_train_epochs 6": "--num_train_epochs 1",
        }

        # Clean up bash script: keep only the arguments that follow
        # "finetune.py", join continuation lines and drop the "$@" passthrough.
        script_path = self.test_file_dir / "train_mbart_cc25_enro.sh"
        # Use a context manager so the file handle is closed deterministically.
        with script_path.open() as script_file:
            bash_script = script_file.read().split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()

        # bash_script = bash_script.replace("--fp16 ", "")
        extra_args = f"""
            --output_dir {output_dir}
            --tokenizer_name Helsinki-NLP/opus-mt-en-ro
            --sortish_sampler
            --do_predict
            --gpus 1
            --freeze_encoder
            --n_train 40000
            --n_val 500
            --n_test 500
            --fp16_opt_level O1
            --num_sanity_val_steps 0
            --eval_beams 2
        """.split()
        # XXX: args.gpus > 1 : handle multigpu in the future

        testargs = ["finetune.py"] + bash_script.split() + extra_args
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(
                parser, os.getcwd())
            args = parser.parse_args()
            model = main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        self.assertEqual(len(metrics["val"]),
                         (args.max_epochs / args.val_check_interval))
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"],
                          float)

        self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
        # model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
        self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)

        # test learning requirements:

        # 1. BLEU improves over the course of training by more than 2 pts
        self.assertGreater(
            last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"],
            2)

        # 2. BLEU finishes above 17
        self.assertGreater(last_step_stats["val_avg_bleu"], 17)

        # 3. test BLEU and val BLEU within ~1.1 pt.
        self.assertLess(
            abs(metrics["val"][-1]["val_avg_bleu"] -
                metrics["test"][-1]["test_avg_bleu"]), 1.1)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][expected_key].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) ==  desired_n_evals
            assert len(metrics["test"]) == 1
    def test_train_mbart_cc25_enro_script(self):
        """Train via the arguments of ``train_mbart_cc25_enro.sh`` on bundled test data.

        The shell script is read, everything after ``finetune.py`` is turned
        into an argument list (with placeholders substituted and fp16 flags
        removed), test overrides are appended, and ``main`` is run with
        ``sys.argv`` patched and ``do_predict`` forced off. The saved metrics
        are checked for validation count, generation time and BLEU
        improvement, and the saved lightning checkpoint is loaded to verify an
        expected float32 weight is present.
        """
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 4,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "facebook/mbart-large-cc25": MODEL_NAME,
            # Download is 120MB in previous test.
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script: keep only the arguments that follow
        # "finetune.py", join continuation lines and drop the "$@" passthrough.
        # Use a context manager so the file handle is closed deterministically.
        with Path("examples/seq2seq/train_mbart_cc25_enro.sh").open() as script_file:
            bash_script = script_file.read().split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()

        bash_script = bash_script.replace("--fp16 ", "")
        testargs = (["finetune.py"] + bash_script.split() + [
            f"--output_dir={output_dir}",
            "--gpus=1",
            "--learning_rate=3e-1",
            "--warmup_steps=0",
            "--val_check_interval=1.0",
            "--tokenizer_name=facebook/mbart-large-en-ro",
        ])
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(
                parser, os.getcwd())
            args = parser.parse_args()
            args.do_predict = False
            # assert args.gpus == gpus THIS BREAKS for multigpu
            model = main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        assert (len(
            metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
                )  # +1 accounts for val_sanity_check

        assert last_step_stats["val_avg_gen_time"] >= 0.01

        assert first_step_stats["val_avg_bleu"] < last_step_stats[
            "val_avg_bleu"]  # model learned nothing
        assert 1.0 >= last_step_stats[
            "val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"],
                          float)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][expected_key].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        # (do_predict is forced to False above, so this branch is currently skipped.)
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) ==  desired_n_evals
            assert len(metrics["test"]) == 1