コード例 #1
0
def test_finetune_lr_schedulers():
    """Check ``--lr_scheduler`` handling of the finetune.py argument parser.

    Three scenarios are exercised against a parser built the same way
    finetune.py builds it (Trainer args + SummarizationModule args):

    1. ``--help`` exits and its stdout lists the supported schedulers
       (``lightning_base.arg_to_scheduler_metavar``).
    2. An unknown scheduler name makes argparse exit with an
       ``invalid choice`` message on stderr.
    3. A supported scheduler (``"cosine"``) runs ``main`` end to end and is
       recorded in the returned model's ``hparams``.
    """
    args_d: dict = CHEAP_ARGS.copy()

    task = "summarization"
    tmp_dir = make_test_data_dir()

    model_name = BART_TINY
    # NOTE(review): this directory is never removed by the test — consider cleanup.
    output_dir = tempfile.mkdtemp(prefix="output_1_")

    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model_name,
        output_dir=output_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # emulate finetune.py
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    # --help test: argparse prints usage to stdout and raises SystemExit.
    # parse_args takes a list of argv strings.
    with pytest.raises(SystemExit):
        with CaptureStdout() as cs:
            parser.parse_args(["--help"])
        # Only reached if parse_args did not exit.
        assert False, "--help is expected to sys.exit"
    expected = lightning_base.arg_to_scheduler_metavar
    assert expected in cs.out, "--help is expected to list the supported schedulers"

    # --lr_scheduler=non_existing_scheduler test: argparse reports the error
    # on stderr and raises SystemExit.
    unsupported_param = "non_existing_scheduler"
    with pytest.raises(SystemExit):
        with CaptureStderr() as cs:
            parser.parse_args([f"--lr_scheduler={unsupported_param}"])
        # Only reached if parse_args did not exit.
        assert False, "invalid argument is expected to sys.exit"
    expected = f"invalid choice: '{unsupported_param}'"
    assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

    # --lr_scheduler=existing_scheduler test: bypass argparse and feed a
    # Namespace straight into main(), as finetune.py would after parsing.
    supported_param = "cosine"
    args_d1 = args_d.copy()
    args_d1["lr_scheduler"] = supported_param
    args = argparse.Namespace(**args_d1)
    model = main(args)
    assert model.hparams.lr_scheduler == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
コード例 #2
0
    def test_run_eval_search(self, model):
        """Run ``run_eval_search.py`` on a tiny dataset and check its report.

        Writes three English source sentences and their German references into
        a temporary directory, then invokes ``run_search()`` with a patched
        ``sys.argv``. The invocation searches over ``num_beams`` x
        ``length_penalty``.

        The captured stdout must contain:
        - the results-table header,
        - the model name,
        - ``"Best score args"``,
        - the task's metric keys (``bleu`` for translation, ``ROUGE_KEYS``
          otherwise).

        It must not contain ``"Info"``. The generated output file must exist;
        it is removed afterwards.

        Args:
            model: model identifier handed to the script. ``T5_TINY`` selects
                the ``translation_en_to_de`` task; any other value selects
                ``summarization``.
        """
        # One temporary directory holds every file this test creates.
        tmp_dir = Path(self.get_auto_remove_tmp_dir())
        input_file_name = tmp_dir / "utest_input.source"
        output_file_name = tmp_dir / "utest_output.txt"
        score_path = str(tmp_dir / "scores.json")
        reference_path = str(tmp_dir / "val.target")
        assert not output_file_name.exists()

        text = {
            "en": [
                "Machine learning is great, isn't it?",
                "I like to eat bananas",
                "Tomorrow is another great day!",
            ],
            "de": [
                "Maschinelles Lernen ist großartig, oder?",
                "Ich esse gerne Bananen",
                "Morgen ist wieder ein toller Tag!",
            ],
        }

        _dump_articles(input_file_name, text["en"])
        _dump_articles(reference_path, text["de"])
        task = "translation_en_to_de" if model == T5_TINY else "summarization"

        # Build argv as an explicit list so paths containing whitespace stay
        # single arguments. Splitting a formatted string would tear them apart.
        testargs = [
            "run_eval_search.py",
            str(model),
            str(input_file_name),
            str(output_file_name),
            "--score_path",
            score_path,
            "--reference_path",
            reference_path,
            "--task",
            task,
            "--search",
            "num_beams=1:2 length_penalty=0.9:1.0",
        ]

        expected_strings = [" num_beams | length_penalty", model, "Best score args"]
        if "translation" in task:
            expected_strings.append("bleu")
        else:
            expected_strings.extend(ROUGE_KEYS)
        un_expected_strings = ["Info"]

        with patch.object(sys, "argv", testargs):
            with CaptureStdout() as cs:
                run_search()

        for w in expected_strings:
            assert w in cs.out, f"expected {w!r} in run_eval_search output"
        for w in un_expected_strings:
            assert w not in cs.out, f"did not expect {w!r} in run_eval_search output"
        assert output_file_name.exists()
        output_file_name.unlink()