# Example no. 1
    def _test_distiller_cli(self, updates, check_contents=True):
        """Run ``distill_main`` with CHEAP_ARGS plus *updates* and sanity-check its output.

        Args:
            updates: dict of argument overrides applied on top of the cheap defaults.
            check_contents: when False, skip all output/metric assertions and just
                return the trained model.

        Returns:
            The model object produced by ``distill_main``.
        """
        base_overrides = dict(
            label_smoothing=0.0,
            early_stopping_patience=-1,
            train_batch_size=1,
            eval_batch_size=2,
            max_epochs=2,
            alpha_mlm=0.2,
            alpha_ce=0.8,
            do_predict=True,
            model_name_or_path="sshleifer/tinier_bart",
            teacher=CHEAP_ARGS["model_name_or_path"],
            val_check_interval=0.5,
            alpha_encoder_loss=0.4,
        )
        base_overrides.update(updates)

        cli_args: dict = CHEAP_ARGS.copy()
        cli_args.update(
            data_dir=make_test_data_dir(),
            output_dir=tempfile.mkdtemp(prefix="output_"),
            **base_overrides,
        )
        model = distill_main(argparse.Namespace(**cli_args))
        if not check_contents:
            return model

        produced = {os.path.basename(p) for p in os.listdir(cli_args["output_dir"])}
        # at least one lightning checkpoint must have been written
        assert any(name.endswith("ckpt") for name in produced)
        self.assertIn("test_generations.txt", produced)
        self.assertIn("test_results.txt", produced)

        metrics = load_json(model.metrics_save_path)
        final_val = metrics["val"][-1]
        # generation time must be measurable but bounded (catches hangs/misconfig)
        self.assertGreaterEqual(final_val["val_avg_gen_time"], 0.01)
        self.assertGreaterEqual(1.0, final_val["val_avg_gen_time"])
        self.assertIsInstance(final_val[f"val_avg_{model.val_metric}"], float)
        # one eval per val_check_interval per epoch, +1 for the sanity check
        expected_evals = int(
            cli_args["max_epochs"] * (1 / cli_args["val_check_interval"]) + 1)
        self.assertEqual(len(metrics["val"]), expected_evals)
        self.assertEqual(len(metrics["test"]), 1)
        return model
    def test_opus_mt_distill_script(self):
        """End-to-end run of the no-teacher Marian distillation bash script.

        Reads distil_marian_no_teacher.sh, rewrites it into an argv list,
        runs distill_main, then checks the saved metrics and the emitted
        lightning checkpoint.
        """
        data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
        # Substitutions applied to the raw script text. NOTE: order matters —
        # these are plain str.replace calls, so e.g. "--fp16_opt_level=O1"
        # must be stripped before the later bare "--fp16" replacements run.
        env_vars_to_replace = {
            "--fp16_opt_level=O1": "",
            "$MAX_LEN": 128,
            "$BS": 16,
            "$GAS": 1,
            "$ENRO_DIR": data_dir,
            "$m": "sshleifer/student_marian_en_ro_6_1",
            "val_check_interval=0.25": "val_check_interval=1.0",
        }

        # Clean up bash script: keep only the args after "distillation.py",
        # join the line continuations, and drop the "$@" passthrough.
        bash_script = ((self.test_file_dir / "distil_marian_no_teacher.sh"
                        ).open().read().split("distillation.py")[1].strip())
        bash_script = bash_script.replace("\\\n",
                                          "").strip().replace('"$@"', "")
        bash_script = bash_script.replace("--fp16 ", " ")

        for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
        output_dir = self.get_auto_remove_tmp_dir()
        # catch a trailing "--fp16" not followed by a space
        bash_script = bash_script.replace("--fp16", "")
        epochs = 6
        # Script args first, then explicit overrides (argparse keeps the last
        # occurrence of a repeated flag, so these win).
        testargs = (["distillation.py"] + bash_script.split() + [
            f"--output_dir={output_dir}",
            "--gpus=1",
            "--learning_rate=1e-3",
            f"--num_train_epochs={epochs}",
            "--warmup_steps=10",
            "--val_check_interval=1.0",
            "--do_predict",
        ])
        with patch.object(sys, "argv", testargs):
            parser = argparse.ArgumentParser()
            parser = pl.Trainer.add_argparse_args(parser)
            parser = BartSummarizationDistiller.add_model_specific_args(
                parser, os.getcwd())
            args = parser.parse_args()
            # assert args.gpus == gpus THIS BREAKS for multigpu

            model = distill_main(args)

        # Check metrics
        metrics = load_json(model.metrics_save_path)
        first_step_stats = metrics["val"][0]
        last_step_stats = metrics["val"][-1]
        assert len(
            metrics["val"]) >= (args.max_epochs / args.val_check_interval
                                )  # +1 accounts for val_sanity_check

        assert last_step_stats["val_avg_gen_time"] >= 0.01

        assert first_step_stats["val_avg_bleu"] < last_step_stats[
            "val_avg_bleu"]  # bleu must improve; fails if the model learned nothing
        assert 1.0 >= last_step_stats[
            "val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"],
                          float)

        # check lightning ckpt can be loaded and has a reasonable statedict
        contents = os.listdir(output_dir)
        ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
        full_path = os.path.join(args.output_dir, ckpt_path)
        ckpt = torch.load(full_path, map_location="cpu")
        # student keeps cross-attention layer-norm weights, stored in fp32
        expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
        assert expected_key in ckpt["state_dict"]
        assert ckpt["state_dict"][
            "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32

        # TODO: turn on args.do_predict when PL bug fixed.
        if args.do_predict:
            contents = {os.path.basename(p) for p in contents}
            assert "test_generations.txt" in contents
            assert "test_results.txt" in contents
            # assert len(metrics["val"]) ==  desired_n_evals
            assert len(metrics["test"]) == 1
# Example no. 3
    # NOTE(review): this chunk appears to be the tail of a debug/__main__
    # helper whose enclosing definition is outside this view — confirm scope.
    # Hard-coded CLI for a local debug run of the translation distiller
    # (en_XX -> ro_RO with an OPUS-MT teacher, CPU only: --gpus 0).
    debug_args = """
--data_dir=some_data \
--src_lang=en_XX \
--tgt_lang=ro_RO \
--model_name_or_path IGNORED \
--learning_rate=3e-4 \
--train_batch_size=4 \
--eval_batch_size=4 \
--teacher Helsinki-NLP/opus-mt-en-ro \
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
--warmup_steps 500 \
--student_decoder_layers 2 --student_encoder_layers 2 \
--freeze_embeds \
--alpha_hid=3. --length_penalty=0.5 \
--gradient_accumulation_steps=2 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--output_dir=debug \
--num_train_epochs 3 \
--gpus 0 \
--do_train \
--do_predict \
--val_check_interval 0.2 \
--sortish_sampler \
    """.strip().split()
    # --fp16 \  (left disabled for the debug run)
    # Build the parser with the distiller's model-specific flags and run it
    # on the canned argv above instead of sys.argv.
    parser = argparse.ArgumentParser()
    parser = BartTranslationDistiller.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args(debug_args)

    distill_main(args)