def test_finetune_lr_schedulers():
    """Exercise the ``--lr_scheduler`` option of the fine-tuning argument parser.

    Three scenarios are checked in order:

    1. ``--help`` exits and its stdout contains
       ``lightning_base.arg_to_scheduler_metavar``.
    2. An unknown scheduler name makes argparse exit and report
       ``invalid choice`` on stderr.
    3. A supported scheduler (``cosine``) handed to ``main`` is present in the
       ``hparams`` of the returned model.
    """
    import shutil  # local import: only needed to clean up the output directory

    args_d: dict = CHEAP_ARGS.copy()
    tmp_dir = make_test_data_dir()

    # emulate finetune.py
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    # --help test
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStdout() as captured_out:
            # parse_args expects a sequence of argument strings.
            parser.parse_args(["--help"])
        # Only reached when parse_args returned instead of exiting. An explicit
        # raise (rather than `assert False`) also survives `python -O`.
        raise AssertionError("--help is expected to sys.exit")
    assert excinfo.type == SystemExit
    expected = lightning_base.arg_to_scheduler_metavar
    assert expected in captured_out.out, "--help is expected to list the supported schedulers"

    # --lr_scheduler=non_existing_scheduler test
    unsupported_param = "non_existing_scheduler"
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStderr() as captured_err:
            parser.parse_args([f"--lr_scheduler={unsupported_param}"])
        raise AssertionError("invalid argument is expected to sys.exit")
    assert excinfo.type == SystemExit
    expected = f"invalid choice: '{unsupported_param}'"
    assert expected in captured_err.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

    # --lr_scheduler=existing_scheduler test
    supported_param = "cosine"
    output_dir = tempfile.mkdtemp(prefix="output_1_")
    try:
        args_d.update(
            data_dir=tmp_dir,
            model_name_or_path=BART_TINY,
            output_dir=output_dir,
            tokenizer_name=None,
            train_batch_size=2,
            eval_batch_size=2,
            do_predict=False,
            task="summarization",
            src_lang="en_XX",
            tgt_lang="ro_RO",
            freeze_encoder=True,
            freeze_embeds=True,
        )
        run_args_d = args_d.copy()
        run_args_d["lr_scheduler"] = supported_param
        trained_model = main(argparse.Namespace(**run_args_d))
        assert (
            trained_model.hparams.lr_scheduler == supported_param
        ), f"lr_scheduler={supported_param} shouldn't fail"
    finally:
        # mkdtemp never deletes what it creates; remove the directory ourselves
        # so repeated test runs do not accumulate output directories.
        shutil.rmtree(output_dir, ignore_errors=True)
def test_run_eval_search(self, model):
    """Run ``run_search`` on a tiny en/de corpus and check its printed report.

    Args:
        model: Model identifier passed as the first positional CLI argument.
            When it equals ``T5_TINY`` the task is ``translation_en_to_de``;
            otherwise it is ``summarization``.

    The test writes the English sentences as input and the German ones as
    references, invokes ``run_search`` with a patched ``sys.argv`` and then
    asserts that stdout contains the search table header, the model name,
    ``"Best score args"`` and the task's metric names, that it does not
    contain ``"Info"``, and that the output file was created.
    """
    input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
    output_file_name = input_file_name.parent / "utest_output.txt"
    assert not output_file_name.exists()

    text = {
        "en": [
            "Machine learning is great, isn't it?",
            "I like to eat bananas",
            "Tomorrow is another great day!",
        ],
        "de": [
            "Maschinelles Lernen ist großartig, oder?",
            "Ich esse gerne Bananen",
            "Morgen ist wieder ein toller Tag!",
        ],
    }

    tmp_dir = Path(self.get_auto_remove_tmp_dir())
    score_path = str(tmp_dir / "scores.json")
    reference_path = str(tmp_dir / "val.target")
    _dump_articles(input_file_name, text["en"])
    _dump_articles(reference_path, text["de"])

    task = "translation_en_to_de" if model == T5_TINY else "summarization"

    # Build argv as an explicit list. Splitting an interpolated string on
    # whitespace would tear apart any path that contains a space.
    testargs = [
        "run_eval_search.py",
        model,
        str(input_file_name),
        str(output_file_name),
        "--score_path",
        score_path,
        "--reference_path",
        reference_path,
        "--task",
        task,
        "--search",
        "num_beams=1:2 length_penalty=0.9:1.0",
    ]

    with patch.object(sys, "argv", testargs):
        with CaptureStdout() as cs:
            run_search()

    expected_strings = [" num_beams | length_penalty", model, "Best score args"]
    if "translation" in task:
        expected_strings.append("bleu")
    else:
        expected_strings.extend(ROUGE_KEYS)
    un_expected_strings = ["Info"]

    for wanted in expected_strings:
        assert wanted in cs.out, f"{wanted!r} is missing from the run_search output"
    for unwanted in un_expected_strings:
        assert unwanted not in cs.out, f"{unwanted!r} should not appear in the run_search output"

    assert output_file_name.exists()
    output_file_name.unlink()