import argparse
import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest
import pytorch_lightning as pl
import torch

# Project-local fixtures and helpers (CHEAP_ARGS, BART_TINY, T5_TINY, MBART_TINY,
# MARIAN_TINY, FSMT_TINY, MARIAN_MODEL, MODEL_NAME, make_test_data_dir, load_json,
# main, SummarizationModule, lightning_base, CaptureStdout, CaptureStderr) are
# assumed to come from the surrounding examples/seq2seq test module.


def test_finetune_extra_model_args():
    args_d: dict = CHEAP_ARGS.copy()

    task = "summarization"
    tmp_dir = make_test_data_dir()

    args_d.update(
        data_dir=tmp_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # test models whose config includes the extra_model_args
    model = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")
    args_d1 = args_d.copy()
    args_d1.update(
        model_name_or_path=model,
        output_dir=output_dir,
    )
    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
    for p in extra_model_params:
        args_d1[p] = 0.5

    args = argparse.Namespace(**args_d1)
    model = main(args)

    for p in extra_model_params:
        assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"

    # test models whose config doesn't include the extra_model_args
    model = T5_TINY
    output_dir = tempfile.mkdtemp(prefix="output_2_")
    args_d2 = args_d.copy()
    args_d2.update(
        model_name_or_path=model,
        output_dir=output_dir,
    )
    unsupported_param = "encoder_layerdrop"
    args_d2[unsupported_param] = 0.5

    args = argparse.Namespace(**args_d2)
    with pytest.raises(Exception) as excinfo:
        model = main(args)
    assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
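
# The override-and-raise behavior exercised above boils down to a pattern like
# this minimal sketch (hypothetical helper; the real logic lives in finetune.py's
# main()). It copies each user-supplied extra param onto the model config, and
# fails with exactly the message the test asserts when the config lacks the attribute:
def apply_extra_model_params(config, hparams, extra_model_params):
    """Copy each supplied extra param onto the model config, or fail loudly."""
    for p in extra_model_params:
        if getattr(hparams, p, None) is None:
            continue  # param not supplied on the command line
        if not hasattr(config, p):
            raise Exception(f"model config doesn't have a `{p}` attribute")
        setattr(config, p, getattr(hparams, p))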
def test_finetune_lr_schedulers():
    args_d: dict = CHEAP_ARGS.copy()

    task = "summarization"
    tmp_dir = make_test_data_dir()

    model = BART_TINY
    output_dir = tempfile.mkdtemp(prefix="output_1_")

    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model,
        output_dir=output_dir,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        do_predict=False,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )

    # emulate finetune.py
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = ["--help"]

    # --help test
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStdout() as cs:
            args = parser.parse_args(args)
        assert False, "--help is expected to sys.exit"
    assert excinfo.type == SystemExit
    expected = lightning_base.arg_to_scheduler_metavar
    assert expected in cs.out, "--help is expected to list the supported schedulers"

    # --lr_scheduler=non_existing_scheduler test
    unsupported_param = "non_existing_scheduler"
    args = [f"--lr_scheduler={unsupported_param}"]
    with pytest.raises(SystemExit) as excinfo:
        with CaptureStderr() as cs:
            args = parser.parse_args(args)
        assert False, "invalid argument is expected to sys.exit"
    assert excinfo.type == SystemExit
    expected = f"invalid choice: '{unsupported_param}'"
    assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

    # --lr_scheduler=existing_scheduler test
    supported_param = "cosine"
    args_d1 = args_d.copy()
    args_d1["lr_scheduler"] = supported_param
    args = argparse.Namespace(**args_d1)
    model = main(args)
    assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
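
# CaptureStdout / CaptureStderr are imported from the shared test utilities; the
# following is a minimal stand-in with the same `.out` / `.err` interface (assumed
# from the usage above), in case this test is run without those utilities.
# Exceptions raised inside the block (e.g. SystemExit from argparse) still
# propagate, so pytest.raises works as in the test:
import io


class CaptureStdout:
    def __enter__(self):
        self._buf, self._old = io.StringIO(), sys.stdout
        sys.stdout = self._buf
        return self

    def __exit__(self, *exc):
        sys.stdout = self._old
        self.out = self._buf.getvalue()


class CaptureStderr:
    def __enter__(self):
        self._buf, self._old = io.StringIO(), sys.stderr
        sys.stderr = self._buf
        return self

    def __exit__(self, *exc):
        sys.stderr = self._old
        self.err = self._buf.getvalue()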
# The parametrize decorator is restored from the test body's branching (pytest
# needs it to supply the `model` argument); the model list is assumed from the
# tiny fixtures referenced below.
@pytest.mark.parametrize("model", [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY])
def test_finetune(model):
    args_d: dict = CHEAP_ARGS.copy()
    task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
    args_d["label_smoothing"] = 0.1 if task == "translation" else 0

    tmp_dir = make_test_data_dir()
    output_dir = tempfile.mkdtemp(prefix="output_")
    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        output_dir=output_dir,
        do_predict=True,
        task=task,
        src_lang="en_XX",
        tgt_lang="ro_RO",
        freeze_encoder=True,
        freeze_embeds=True,
    )
    assert "n_train" in args_d
    args = argparse.Namespace(**args_d)
    module = main(args)

    input_embeds = module.model.get_input_embeddings()
    assert not input_embeds.weight.requires_grad
    if model == T5_TINY:
        lm_head = module.model.lm_head
        assert not lm_head.weight.requires_grad
        assert (lm_head.weight == input_embeds.weight).all().item()
    elif model == FSMT_TINY:
        fsmt = module.model.model
        embed_pos = fsmt.decoder.embed_positions
        assert not embed_pos.weight.requires_grad
        assert not fsmt.decoder.embed_tokens.weight.requires_grad
        # check that embeds are not the same
        assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
    else:
        bart = module.model.model
        embed_pos = bart.decoder.embed_positions
        assert not embed_pos.weight.requires_grad
        assert not bart.shared.weight.requires_grad
        # check that embeds are the same
        assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
        assert bart.decoder.embed_tokens == bart.shared

    example_batch = load_json(module.output_dir / "text_batch.json")
    assert isinstance(example_batch, dict)
    assert len(example_batch) >= 4
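
# The requires_grad assertions above depend on freeze helpers along these lines
# (a sketch matching the behavior the test checks; the real helper lives in the
# examples' utils module and is applied via --freeze_encoder / --freeze_embeds):
def freeze_params(module):
    """Disable gradient updates for every parameter of `module`."""
    for par in module.parameters():
        par.requires_grad = False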
def test_train_mbart_cc25_enro_script(self):
    env_vars_to_replace = {
        "$MAX_LEN": 64,
        "$BS": 64,
        "$GAS": 1,
        "$ENRO_DIR": self.data_dir,
        "facebook/mbart-large-cc25": MARIAN_MODEL,
        # "val_check_interval=0.25": "val_check_interval=1.0",
        "--learning_rate=3e-5": "--learning_rate 3e-4",
        "--num_train_epochs 6": "--num_train_epochs 1",
    }

    # Clean up bash script
    bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
    for k, v in env_vars_to_replace.items():
        bash_script = bash_script.replace(k, str(v))
    output_dir = self.get_auto_remove_tmp_dir()

    # bash_script = bash_script.replace("--fp16 ", "")
    args = f"""
        --output_dir {output_dir}
        --tokenizer_name Helsinki-NLP/opus-mt-en-ro
        --sortish_sampler
        --do_predict
        --gpus 1
        --freeze_encoder
        --n_train 40000
        --n_val 500
        --n_test 500
        --fp16_opt_level O1
        --num_sanity_val_steps 0
        --eval_beams 2
    """.split()
    # XXX: args.gpus > 1 : handle multigpu in the future
    testargs = ["finetune.py"] + bash_script.split() + args
    with patch.object(sys, "argv", testargs):
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
        args = parser.parse_args()
        model = main(args)

    # Check metrics
    metrics = load_json(model.metrics_save_path)
    first_step_stats = metrics["val"][0]
    last_step_stats = metrics["val"][-1]
    self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
    self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
    # model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
    self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)

    # test learning requirements:
    # 1. BLEU improves over the course of training by more than 2 pts
    self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)
    # 2. BLEU finishes above 17
    self.assertGreater(last_step_stats["val_avg_bleu"], 17)
    # 3. test BLEU and val BLEU within ~1.1 pt.
    self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)

    # check lightning ckpt can be loaded and has a reasonable statedict
    contents = os.listdir(output_dir)
    ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
    full_path = os.path.join(args.output_dir, ckpt_path)
    ckpt = torch.load(full_path, map_location="cpu")
    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
    assert expected_key in ckpt["state_dict"]
    assert ckpt["state_dict"][expected_key].dtype == torch.float32

    # TODO: turn on args.do_predict when PL bug fixed.
    if args.do_predict:
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.txt" in contents
        # assert len(metrics["val"]) == desired_n_evals
        assert len(metrics["test"]) == 1
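
# load_json is a tiny helper from the shared utils; a minimal equivalent, assuming
# its interface is simply "take a path, return the parsed object":
import json


def load_json(path):
    with open(path) as f:
        return json.load(f)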
# A lighter variant of the script test above: strips --fp16 / --fp16_opt_level,
# swaps in a smaller model, and runs a single quick pass over the data.
def test_train_mbart_cc25_enro_script(self):
    data_dir = "examples/seq2seq/test_data/wmt_en_ro"
    env_vars_to_replace = {
        "--fp16_opt_level=O1": "",
        "$MAX_LEN": 128,
        "$BS": 4,
        "$GAS": 1,
        "$ENRO_DIR": data_dir,
        "facebook/mbart-large-cc25": MODEL_NAME,  # Download is 120MB in previous test.
        "val_check_interval=0.25": "val_check_interval=1.0",
    }

    # Clean up bash script
    bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
    for k, v in env_vars_to_replace.items():
        bash_script = bash_script.replace(k, str(v))
    output_dir = self.get_auto_remove_tmp_dir()

    bash_script = bash_script.replace("--fp16 ", "")
    testargs = (
        ["finetune.py"]
        + bash_script.split()
        + [
            f"--output_dir={output_dir}",
            "--gpus=1",
            "--learning_rate=3e-1",
            "--warmup_steps=0",
            "--val_check_interval=1.0",
            "--tokenizer_name=facebook/mbart-large-en-ro",
        ]
    )
    with patch.object(sys, "argv", testargs):
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
        args = parser.parse_args()
        args.do_predict = False
        # assert args.gpus == gpus THIS BREAKS for multigpu
        model = main(args)

    # Check metrics
    metrics = load_json(model.metrics_save_path)
    first_step_stats = metrics["val"][0]
    last_step_stats = metrics["val"][-1]
    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check

    assert last_step_stats["val_avg_gen_time"] >= 0.01
    # BLEU should improve; if it doesn't, the model learned nothing
    assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]
    # generate should not hang; if it does, maybe a bad config was saved
    assert last_step_stats["val_avg_gen_time"] <= 1.0
    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

    # check lightning ckpt can be loaded and has a reasonable statedict
    contents = os.listdir(output_dir)
    ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
    full_path = os.path.join(args.output_dir, ckpt_path)
    ckpt = torch.load(full_path, map_location="cpu")
    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
    assert expected_key in ckpt["state_dict"]
    assert ckpt["state_dict"][expected_key].dtype == torch.float32

    # TODO: turn on args.do_predict when PL bug fixed.
    if args.do_predict:
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.txt" in contents
        # assert len(metrics["val"]) == desired_n_evals
        assert len(metrics["test"]) == 1
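
# get_auto_remove_tmp_dir comes from the shared TestCasePlus base class; a
# minimal standalone stand-in, assuming its behavior is "create a temp dir and
# make sure it gets cleaned up" (the real helper ties cleanup to test teardown):
import atexit
import shutil


def get_auto_remove_tmp_dir():
    tmp_dir = tempfile.mkdtemp()
    atexit.register(shutil.rmtree, tmp_dir, ignore_errors=True)
    return tmp_dir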