def add_model_specific_args(parser, root_dir):
    """Extend ``parser`` with the base summarization options plus the ones below.

    ``SummarizationModule.add_model_specific_args`` is invoked first, then the
    options declared in this function are registered: ``--teacher``, the
    ``--alpha_*`` values, the student encoder/decoder layer counts,
    ``--no_teacher`` and ``--length_penalty``.

    Args:
        parser: Argument parser to extend; it is modified in place.
        root_dir: Passed through unchanged to
            ``SummarizationModule.add_model_specific_args``.

    Returns:
        The ``parser`` object that was passed in.
    """
    SummarizationModule.add_model_specific_args(parser, root_dir)

    # (flag, keyword arguments) pairs, registered in exactly this order.
    # NOTE: "--alpha_cos" (default=0.0, type=float) is currently disabled and
    # therefore not part of the table.
    option_table = (
        ("--teacher", {"default": "facebook/bart-large-cnn", "type": str}),
        ("--alpha_ce", {"default": 0.8, "type": float}),
        ("--alpha_mlm", {"default": 0.2, "type": float}),
        ("--alpha_encoder_loss", {"default": 0.0, "type": float}),
        ("--alpha_hid", {"default": 0.0, "type": float, "required": False}),
        ("--student_decoder_layers", {"default": 12, "type": int, "required": False}),
        ("--student_encoder_layers", {"default": 12, "type": int, "required": False}),
        ("--no_teacher", {"action": "store_true", "default": False}),
        ("--length_penalty", {"type": float, "default": -1}),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    return parser
def test_finetune_lr_schedulers():
    """Check how the fine-tuning argument parser handles ``--lr_scheduler``.

    Three scenarios are exercised against a parser built the same way as in
    ``finetune.py``:

    1. ``--help`` exits and its stdout contains
       ``lightning_base.arg_to_scheduler_metavar``.
    2. An unknown scheduler name exits and stderr reports an invalid choice.
    3. A supported scheduler (``"cosine"``) is passed to ``main`` and the
       returned model exposes it as ``hparams.lr_scheduler``.
    """
    args_d: dict = CHEAP_ARGS.copy()

    task = "summarization"
    tmp_dir = make_test_data_dir()

    model_name = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")

    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model_name,
        output_dir=output_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # emulate finetune.py
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    # --help test
    # parse_args expects a sequence of strings, so pass an explicit list
    # instead of relying on a dict/set being coerced by argparse.
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStdout() as cs:
            parser.parse_args(["--help"])
        # Only reached if parse_args returned instead of exiting.
        raise AssertionError("--help is expected to sys.exit")
    assert excinfo.type is SystemExit
    expected = lightning_base.arg_to_scheduler_metavar
    assert expected in cs.out, "--help is expected to list the supported schedulers"

    # --lr_scheduler=non_existing_scheduler test
    unsupported_param = "non_existing_scheduler"
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStderr() as cs:
            parser.parse_args([f"--lr_scheduler={unsupported_param}"])
        # Only reached if parse_args accepted the bogus scheduler.
        raise AssertionError("invalid argument is expected to sys.exit")
    assert excinfo.type is SystemExit
    expected = f"invalid choice: '{unsupported_param}'"
    assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

    # --lr_scheduler=existing_scheduler test
    supported_param = "cosine"
    args_d1 = args_d.copy()
    args_d1["lr_scheduler"] = supported_param
    args = argparse.Namespace(**args_d1)
    model = main(args)
    assert model.hparams.lr_scheduler == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
def add_model_specific_args(parser, root_dir):
    """Register the base summarization options and the distillation options.

    Args:
        parser: Argument parser to extend; it is modified in place.
        root_dir: Passed through unchanged to
            ``SummarizationModule.add_model_specific_args``.

    Returns:
        The ``parser`` object that was passed in.
    """
    # Base options first, then the extra ones contributed by add_distill_args.
    SummarizationModule.add_model_specific_args(parser, root_dir)
    add_distill_args(parser)

    return parser
def test_train_mbart_cc25_enro_script(self):
    """Run fine-tuning with the arguments of ``train_mbart_cc25_enro.sh`` and check the outcome.

    The command line is built from the text that follows ``finetune.py`` in
    the shell script, with placeholders substituted and extra test-specific
    flags appended. After ``main`` returns, the saved metrics, the BLEU
    learning requirements and the Lightning checkpoint are verified.
    """
    env_vars_to_replace = {
        "$MAX_LEN": 64,
        "$BS": 64,
        "$GAS": 1,
        "$ENRO_DIR": self.data_dir,
        "facebook/mbart-large-cc25": MARIAN_MODEL,
        # "val_check_interval=0.25": "val_check_interval=1.0",
        "--learning_rate=3e-5": "--learning_rate 3e-4",
        "--num_train_epochs 6": "--num_train_epochs 1",
    }

    # Clean up bash script
    # Read inside a `with` block so the file handle is closed deterministically.
    with (self.test_file_dir / "train_mbart_cc25_enro.sh").open() as script_file:
        bash_script = script_file.read().split("finetune.py")[1].strip()
    # Drop line continuations and the `"$@"` pass-through of extra shell arguments.
    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
    for k, v in env_vars_to_replace.items():
        bash_script = bash_script.replace(k, str(v))
    output_dir = self.get_auto_remove_tmp_dir()

    # bash_script = bash_script.replace("--fp16 ", "")
    extra_cli_args = f"""
        --output_dir {output_dir}
        --tokenizer_name Helsinki-NLP/opus-mt-en-ro
        --sortish_sampler
        --do_predict
        --gpus 1
        --freeze_encoder
        --n_train 40000
        --n_val 500
        --n_test 500
        --fp16_opt_level O1
        --num_sanity_val_steps 0
        --eval_beams 2
    """.split()
    # XXX: args.gpus > 1 : handle multigpu in the future

    testargs = ["finetune.py"] + bash_script.split() + extra_cli_args
    with patch.object(sys, "argv", testargs):
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
        args = parser.parse_args()
        model = main(args)

    # Check metrics
    metrics = load_json(model.metrics_save_path)
    first_step_stats = metrics["val"][0]
    last_step_stats = metrics["val"][-1]
    self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

    self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
    # model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
    self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)

    # test learning requirements:

    # 1. BLEU improves over the course of training by more than 2 pts
    self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)

    # 2. BLEU finishes above 17
    self.assertGreater(last_step_stats["val_avg_bleu"], 17)

    # 3. test BLEU and val BLEU within ~1.1 pt.
    self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)

    # check lightning ckpt can be loaded and has a reasonable statedict
    contents = os.listdir(output_dir)
    ckpt_files = [x for x in contents if x.endswith(".ckpt")]
    # Fail with a readable message rather than an IndexError when nothing was saved.
    self.assertTrue(ckpt_files, f"no .ckpt file found in {output_dir}")
    full_path = os.path.join(args.output_dir, ckpt_files[0])
    ckpt = torch.load(full_path, map_location="cpu")
    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
    assert expected_key in ckpt["state_dict"]
    assert ckpt["state_dict"][expected_key].dtype == torch.float32

    # TODO: turn on args.do_predict when PL bug fixed.
    if args.do_predict:
        produced_files = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in produced_files
        assert "test_results.txt" in produced_files
        # assert len(metrics["val"]) == desired_n_evals
        assert len(metrics["test"]) == 1
def test_train_mbart_cc25_enro_script(self):
    """Run fine-tuning with the arguments of ``train_mbart_cc25_enro.sh`` on the bundled test data.

    The command line is built from the text that follows ``finetune.py`` in
    the shell script, with placeholders substituted, ``--fp16`` removed and
    test-specific overrides appended. ``do_predict`` is forced off before
    ``main`` runs; afterwards the validation metrics and the saved Lightning
    checkpoint are verified.
    """
    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
    env_vars_to_replace = {
        "--fp16_opt_level=O1": "",
        "$MAX_LEN": 128,
        "$BS": 4,
        "$GAS": 1,
        "$ENRO_DIR": data_dir,
        "facebook/mbart-large-cc25": MODEL_NAME,
        # Download is 120MB in previous test.
        "val_check_interval=0.25": "val_check_interval=1.0",
    }

    # Clean up bash script
    # Read inside a `with` block so the file handle is closed deterministically.
    with Path("examples/seq2seq/train_mbart_cc25_enro.sh").open() as script_file:
        bash_script = script_file.read().split("finetune.py")[1].strip()
    # Drop line continuations and the `"$@"` pass-through of extra shell arguments.
    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
    for k, v in env_vars_to_replace.items():
        bash_script = bash_script.replace(k, str(v))
    output_dir = self.get_auto_remove_tmp_dir()

    bash_script = bash_script.replace("--fp16 ", "")
    testargs = (
        ["finetune.py"]
        + bash_script.split()
        + [
            f"--output_dir={output_dir}",
            "--gpus=1",
            "--learning_rate=3e-1",
            "--warmup_steps=0",
            "--val_check_interval=1.0",
            "--tokenizer_name=facebook/mbart-large-en-ro",
        ]
    )
    with patch.object(sys, "argv", testargs):
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
        args = parser.parse_args()
        args.do_predict = False
        # assert args.gpus == gpus THIS BREAKS for multigpu
        model = main(args)

    # Check metrics
    metrics = load_json(model.metrics_save_path)
    first_step_stats = metrics["val"][0]
    last_step_stats = metrics["val"][-1]
    # +1 accounts for val_sanity_check
    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1

    assert last_step_stats["val_avg_gen_time"] >= 0.01

    # model learned nothing
    assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]
    # model hanging on generate. Maybe bad config was saved.
    assert 1.0 >= last_step_stats["val_avg_gen_time"]
    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

    # check lightning ckpt can be loaded and has a reasonable statedict
    contents = os.listdir(output_dir)
    ckpt_files = [x for x in contents if x.endswith(".ckpt")]
    # Fail with a readable message rather than an IndexError when nothing was saved.
    assert ckpt_files, f"no .ckpt file found in {output_dir}"
    full_path = os.path.join(args.output_dir, ckpt_files[0])
    ckpt = torch.load(full_path, map_location="cpu")
    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
    assert expected_key in ckpt["state_dict"]
    assert ckpt["state_dict"][expected_key].dtype == torch.float32

    # TODO: turn on args.do_predict when PL bug fixed.
    if args.do_predict:
        produced_files = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in produced_files
        assert "test_results.txt" in produced_files
        # assert len(metrics["val"]) == desired_n_evals
        assert len(metrics["test"]) == 1