Example #1
    def add_model_specific_args(parser, root_dir):
        SummarizationModule.add_model_specific_args(parser, root_dir)
        parser.add_argument("--teacher",
                            default="facebook/bart-large-cnn",
                            type=str)
        parser.add_argument("--alpha_ce", default=0.8, type=float)
        parser.add_argument("--alpha_mlm", default=0.2, type=float)
        # parser.add_argument("--alpha_cos", default=0.0, type=float)
        parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
        parser.add_argument("--alpha_hid",
                            default=0.0,
                            type=float,
                            required=False)
        parser.add_argument("--student_decoder_layers",
                            default=12,
                            type=int,
                            required=False)
        parser.add_argument("--student_encoder_layers",
                            default=12,
                            type=int,
                            required=False)
        parser.add_argument("--no_teacher", action="store_true", default=False)
        parser.add_argument("--length_penalty", type=float, default=-1)

        return parser
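For reference, Examples #2 and #3 below show how this hook is consumed: the model-specific arguments are layered on top of the PyTorch Lightning Trainer arguments before parsing. Here is a minimal sketch of that wiring, assuming the examples/seq2seq layout in which SummarizationModule and main live in finetune.py (an assumption about the surrounding repo); the distillation class defining the method above would be swapped in at the same spot.

# Minimal usage sketch (not part of the original file); import path is assumed.
import argparse
import os

import pytorch_lightning as pl

from finetune import SummarizationModule, main  # assumed module layout (examples/seq2seq)

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)  # generic Trainer flags (--gpus, --max_epochs, ...)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())  # model-specific flags as defined above
args = parser.parse_args()
model = main(args)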
Example #2
def test_finetune_lr_schedulers():
    args_d: dict = CHEAP_ARGS.copy()

    task = "summarization"
    tmp_dir = make_test_data_dir()

    model = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")

    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model,
        output_dir=output_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # emulate finetune.py
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = {"--help": True}

    # --help test
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStdout() as cs:
            args = parser.parse_args(args)
        assert False, "--help is expected to sys.exit"
    assert excinfo.type == SystemExit
    expected = lightning_base.arg_to_scheduler_metavar
    assert expected in cs.out, "--help is expected to list the supported schedulers"

    # --lr_scheduler=non_existing_scheduler test
    unsupported_param = "non_existing_scheduler"
    args = {f"--lr_scheduler={unsupported_param}"}
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStderr() as cs:
            args = parser.parse_args(args)
        assert False, "invalid argument is expected to sys.exit"
    assert excinfo.type == SystemExit
    expected = f"invalid choice: '{unsupported_param}'"
    assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

    # --lr_scheduler=existing_scheduler test
    supported_param = "cosine"
    args_d1 = args_d.copy()
    args_d1["lr_scheduler"] = supported_param
    args = argparse.Namespace(**args_d1)
    model = main(args)
    assert getattr(
        model.hparams, "lr_scheduler"
    ) == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
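A side note on the argument containers used in this test: argparse's parse_args() accepts any iterable of argument strings, so the dict {"--help": True} contributes only its key and the one-element set contributes its single member. A small standalone illustration (the parser and its choices here are made up for the demo, not taken from the test):

# Standalone demo: parse_args() converts its argument to a list, so a dict
# contributes its keys and a set contributes its members.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--lr_scheduler", choices=["linear", "cosine"], default="linear")

print(p.parse_args({"--lr_scheduler=cosine"}))    # Namespace(lr_scheduler='cosine')
print(p.parse_args(["--lr_scheduler", "cosine"]))  # equivalent list form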
Example #3
    def add_model_specific_args(parser, root_dir):
        SummarizationModule.add_model_specific_args(parser, root_dir)
        add_distill_args(parser)
        return parser

    def test_train_mbart_cc25_enro_script(self):
        env_vars_to_replace = {
            "$MAX_LEN": 64,
            "$BS": 64,
            "$GAS": 1,
            "$ENRO_DIR": self.data_dir,
            "facebook/mbart-large-cc25": MARIAN_MODEL,
            # "val_check_interval=0.25": "val_check_interval=1.0",
            "--learning_rate=3e-5": "--learning_rate 3e-4",
            "--num_train_epochs 6": "--num_train_epochs 1",
        }

        # Clean up bash script
        bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read()
        bash_script = bash_script.split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()

        # bash_script = bash_script.replace("--fp16 ", "")
        args = f"""
            --output_dir {output_dir}
            --tokenizer_name Helsinki-NLP/opus-mt-en-ro
            --sortish_sampler
            --do_predict
            --gpus 1
            --freeze_encoder
            --n_train 40000
            --n_val 500
            --n_test 500
            --fp16_opt_level O1
            --num_sanity_val_steps 0
            --eval_beams 2
        """.split()
        # XXX: args.gpus > 1 : handle multigpu in the future

        testargs = ["finetune.py"] + bash_script.split() + args
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(
                parser, os.getcwd())
            args = parser.parse_args()
            model = main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        self.assertEqual(len(metrics["val"]),
                         (args.max_epochs / args.val_check_interval))
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"],
                          float)

        self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
        # guard against the model hanging on generate (e.g. a bad config was saved)
        self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)

        # test learning requirements:

        # 1. BLEU improves over the course of training by more than 2 pts
        self.assertGreater(
            last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"],
            2)

        # 2. BLEU finishes above 17
        self.assertGreater(last_step_stats["val_avg_bleu"], 17)

        # 3. test BLEU and val BLEU within ~1.1 pt.
        self.assertLess(
            abs(metrics["val"][-1]["val_avg_bleu"] -
                metrics["test"][-1]["test_avg_bleu"]), 1.1)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][
            "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) ==  desired_n_evals
            assert len(metrics["test"]) == 1
    def test_train_mbart_cc25_enro_script(self):
        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 4,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "facebook/mbart-large-cc25": MODEL_NAME,
            # Download is 120MB in previous test.
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script
        bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read()
        bash_script = bash_script.split("finetune.py")[1].strip()
        bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()

        bash_script = bash_script.replace("--fp16 ", "")
        testargs = (["finetune.py"] + bash_script.split() + [
            f"--output_dir={output_dir}",
            "--gpus=1",
            "--learning_rate=3e-1",
            "--warmup_steps=0",
            "--val_check_interval=1.0",
            "--tokenizer_name=facebook/mbart-large-en-ro",
        ])
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = SummarizationModule.add_model_specific_args(
                parser, os.getcwd())
            args = parser.parse_args()
            args.do_predict = False
            # assert args.gpus == gpus THIS BREAKS for multigpu
            model = main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        # +1 accounts for the val_sanity_check pass
        assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1

        assert last_step_stats["val_avg_gen_time"] >= 0.01

        # fails if the model learned nothing
        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]
        # guard against the model hanging on generate (e.g. a bad config was saved)
        assert last_step_stats["val_avg_gen_time"] <= 1.0
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"],
                          float)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][
            "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) ==  desired_n_evals
            assert len(metrics["test"]) == 1