def _test_distiller_cli(self, updates, check_contents=True):
    # Distillation-specific defaults; `updates` lets individual tests override them.
    default_updates = dict(
        label_smoothing=0.0,
        early_stopping_patience=-1,
        train_batch_size=1,
        eval_batch_size=2,
        max_epochs=2,
        alpha_mlm=0.2,
        alpha_ce=0.8,
        do_predict=True,
        model_name_or_path="sshleifer/tinier_bart",
        teacher=CHEAP_ARGS["model_name_or_path"],
        val_check_interval=0.5,
        alpha_encoder_loss=0.4,
    )
    default_updates.update(updates)
    args_d: dict = CHEAP_ARGS.copy()
    tmp_dir = make_test_data_dir()
    output_dir = tempfile.mkdtemp(prefix="output_")

    args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
    model = distill_main(argparse.Namespace(**args_d))
    if not check_contents:
        return model

    # Check that training produced a checkpoint and the prediction artifacts.
    contents = os.listdir(output_dir)
    contents = {os.path.basename(p) for p in contents}
    ckpt_files = [p for p in contents if p.endswith("ckpt")]
    assert len(ckpt_files) > 0
    self.assertIn("test_generations.txt", contents)
    self.assertIn("test_results.txt", contents)

    # Check the saved metrics: sane generation time, a float val metric, and the expected
    # number of eval runs (one per val_check_interval per epoch, plus the sanity check).
    metrics = load_json(model.metrics_save_path)
    last_step_stats = metrics["val"][-1]
    self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
    self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
    self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
    desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
    self.assertEqual(len(metrics["val"]), desired_n_evals)
    self.assertEqual(len(metrics["test"]), 1)
    return model
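# Illustrative sketch only: one way a concrete test might call the `_test_distiller_cli` helper
# above, overriding a few of its defaults. The test name and the particular student-size
# overrides are assumptions for illustration, not taken from this file (the corresponding flags
# do appear in `debug_args` below).
def test_distill_cli_sketch(self):
    updates = dict(
        student_encoder_layers=2,
        student_decoder_layers=1,
        max_epochs=1,
    )
    self._test_distiller_cli(updates)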
def test_opus_mt_distill_script(self):
    data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
    env_vars_to_replace = {
        "--fp16_opt_level=O1": "",
        "$MAX_LEN": 128,
        "$BS": 16,
        "$GAS": 1,
        "$ENRO_DIR": data_dir,
        "$m": "sshleifer/student_marian_en_ro_6_1",
        "val_check_interval=0.25": "val_check_interval=1.0",
    }

    # Clean up the bash launcher script: keep only the arguments passed to distillation.py,
    # drop line continuations, "$@", and fp16 flags, then substitute the placeholder values.
    bash_script = (
        (self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
    )
    bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
    bash_script = bash_script.replace("--fp16 ", " ")
    for k, v in env_vars_to_replace.items():
        bash_script = bash_script.replace(k, str(v))
    output_dir = self.get_auto_remove_tmp_dir()
    bash_script = bash_script.replace("--fp16", "")
    epochs = 6
    testargs = (
        ["distillation.py"]
        + bash_script.split()
        + [
            f"--output_dir={output_dir}",
            "--gpus=1",
            "--learning_rate=1e-3",
            f"--num_train_epochs={epochs}",
            "--warmup_steps=10",
            "--val_check_interval=1.0",
            "--do_predict",
        ]
    )
    with patch.object(sys, "argv", testargs):
        parser = argparse.ArgumentParser()
        parser = pl.Trainer.add_argparse_args(parser)
        parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
        args = parser.parse_args()
        # assert args.gpus == gpus  -- this breaks for multi-GPU runs
        model = distill_main(args)

    # Check metrics
    metrics = load_json(model.metrics_save_path)
    first_step_stats = metrics["val"][0]
    last_step_stats = metrics["val"][-1]
    # The extra eval beyond max_epochs / val_check_interval comes from the val_sanity_check run.
    assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval)
    assert last_step_stats["val_avg_gen_time"] >= 0.01
    assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # if this fails, the model learned nothing
    assert 1.0 >= last_step_stats["val_avg_gen_time"]  # if this fails, generate is hanging (maybe a bad config was saved)
    assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)

    # Check that the lightning ckpt can be loaded and has a reasonable state dict.
    contents = os.listdir(output_dir)
    ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
    full_path = os.path.join(args.output_dir, ckpt_path)
    ckpt = torch.load(full_path, map_location="cpu")
    expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
    assert expected_key in ckpt["state_dict"]
    assert ckpt["state_dict"][expected_key].dtype == torch.float32

    # TODO: turn on args.do_predict when the PL bug is fixed.
    if args.do_predict:
        contents = {os.path.basename(p) for p in contents}
        assert "test_generations.txt" in contents
        assert "test_results.txt" in contents
        # assert len(metrics["val"]) == desired_n_evals
        assert len(metrics["test"]) == 1
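# A minimal standalone sketch (an assumption, not part of the original suite) that isolates the
# bash-script cleanup used in `test_opus_mt_distill_script` above: take everything after the
# `distillation.py` invocation, drop line continuations, "$@", and fp16 flags, substitute the
# placeholder values, and return an argv-style list that argparse can consume.
def _bash_script_to_argv_sketch(script_path, replacements):
    text = open(script_path).read().split("distillation.py")[1].strip()
    text = text.replace("\\\n", "").strip().replace('"$@"', "").replace("--fp16", "")
    for placeholder, value in replacements.items():
        text = text.replace(placeholder, str(value))
    return text.split()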
debug_args = """ --data_dir=some_data \ --src_lang=en_XX \ --tgt_lang=ro_RO \ --model_name_or_path IGNORED \ --learning_rate=3e-4 \ --train_batch_size=4 \ --eval_batch_size=4 \ --teacher Helsinki-NLP/opus-mt-en-ro \ --tokenizer_name Helsinki-NLP/opus-mt-en-ro \ --warmup_steps 500 \ --student_decoder_layers 2 --student_encoder_layers 2 \ --freeze_embeds \ --alpha_hid=3. --length_penalty=0.5 \ --gradient_accumulation_steps=2 \ --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \ --output_dir=debug \ --num_train_epochs 3 \ --gpus 0 \ --do_train \ --do_predict \ --val_check_interval 0.2 \ --sortish_sampler \ """.strip().split() # --fp16 \ parser = argparse.ArgumentParser() parser = BartTranslationDistiller.add_model_specific_args(parser, os.getcwd()) args = parser.parse_args(debug_args) distill_main(args)