Example no. 1
import json
import shutil
from pathlib import Path

# QASparseXP / GlueSparseXP are project-specific experiment classes, assumed
# to be importable from the surrounding package.

def main(checkpoint_list_file, task):

    checkpoint_list_file = Path(checkpoint_list_file)
    source_points = []
    # One checkpoint path per line; blank lines and "#" comments are skipped.
    lines = checkpoint_list_file.read_text().splitlines()
    for s in lines:
        if len(s) == 0 or s[0] == "#":
            continue
        if checkpoint_list_file.suffix == ".jsonl":
            # JSONL entries store the checkpoint path under meta.checkpoint.path.
            j = json.loads(s)
            s = j["meta"]["checkpoint"]["path"]
        else:
            s = s.strip()
        # Sanity check: the path must belong to the requested task.
        assert (task + "_") in s
        source_points.append(s)

    print(f"SOURCE POINTS: {len(source_points)}")
    for source_point in source_points:
        print(source_point)

    for source_point in source_points:
        src_path = Path(source_point)
        key = src_path.parent.name
        assert key.startswith(("hp_", "aws_", "large_"))
        dest_key = "fine_tuned_" + key
        dest_path = src_path.parent.parent.parent / f"{task}_test_final_fine_tune" / dest_key
        dest_path = dest_path.resolve()
        if dest_path.exists():
            print("SKIPPING", dest_path.name)
            continue
        print("PROCESSING", dest_path.name)

        # Work on a scratch copy so the source checkpoint stays untouched.
        tmp_path = Path("tmp_finetune/").resolve()
        if tmp_path.exists():
            shutil.rmtree(tmp_path)
        shutil.copytree(src_path, tmp_path)

        # Drop trainer state so the final fine-tune starts from a clean slate.
        files_to_remove = ["trainer_state.json", "training_args.bin", "scheduler.pt"]
        for file_to_remove in files_to_remove:
            file_to_remove_ = tmp_path / file_to_remove
            if file_to_remove_.exists():
                file_to_remove_.unlink()

        dest_path.mkdir(exist_ok=True, parents=True)
        # Record where this fine-tuned model came from.
        with (dest_path / "source.txt").open("w") as f:
            f.write(str(src_path))

        with open(src_path / "sparse_args.json") as f:
            sparse_args = json.load(f)
            teacher = sparse_args["distil_teacher_name_or_path"]

        if task == "squad":
            QASparseXP.final_finetune(str(tmp_path), str(dest_path), teacher=teacher)
        elif task in ["mnli"]:
            GlueSparseXP.final_finetune(str(tmp_path), str(dest_path), teacher=teacher)
        else:
            raise Exception(f"Unknown task {task}")
Example no. 2
    # Module-level imports assumed by this method: json, shutil,
    # tempfile.TemporaryDirectory, and from transformers
    # BertForQuestionAnswering and TFBertForQuestionAnswering;
    # optimize_model is the project's head-pruning optimizer.
    def copy_model_files(self):
        modified = False

        src_path = self.checkpoint_path

        d = None
        try:
            if not (self.git_path / "tf_model.h5").exists() or not (
                    self.git_path / "pytorch_model.bin").exists():
                if task.startswith("squad"):
                    d = TemporaryDirectory()
                    model = QASparseXP.compile_model(src_path,
                                                     dest_path=d.name)
                    model = optimize_model(model, "heads")
                    model.save_pretrained(d.name)
                    src_path = d.name
                else:
                    raise Exception(f"Unknown task {task}")

            if not (self.git_path / "tf_model.h5").exists():
                with TemporaryDirectory() as d2:
                    if task.startswith("squad"):
                        QASparseXP.final_fine_tune_bertarize(
                            src_path, d2, remove_head_pruning=True)
                    else:
                        raise Exception(f"Unknown task {task}")

                    tf_model = TFBertForQuestionAnswering.from_pretrained(
                        d2, from_pt=True)
                    tf_model.save_pretrained(self.git_path)
                    modified = True

            if not (self.git_path / "pytorch_model.bin").exists():
                model = BertForQuestionAnswering.from_pretrained(src_path)
                model.save_pretrained(self.git_path)
                modified = True

            # Tokenizer files are copied verbatim from the source checkpoint.
            FILES = ("special_tokens_map.json", "tokenizer_config.json", "vocab.txt")
            for file in FILES:
                if not (self.git_path / file).exists():
                    shutil.copyfile(str(Path(src_path) / file),
                                    str(self.git_path / file))
                    modified = True

        finally:
            if d is not None:
                d.cleanup()

        # Reload the config; it may have been changed by compilation / optimization (pruned_heads, gelu_patch, layer_norm_patch).
        with (self.git_path / "config.json").open() as f:
            self.checkpoint_info["config"] = json.load(f)

        return modified
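
copy_model_files() is responsible for a fixed set of files in git_path. A small helper along these lines (a sketch, not part of the original code) can report which of them are present:

from pathlib import Path

# Files the method creates or copies into git_path.
EXPECTED = ("tf_model.h5", "pytorch_model.bin", "config.json",
            "special_tokens_map.json", "tokenizer_config.json", "vocab.txt")

def check_layout(git_path):
    # Return a {filename: exists} map for the files the method manages.
    git_path = Path(git_path)
    return {name: (git_path / name).exists() for name in EXPECTED}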
Example no. 3
import json
import shutil
from pathlib import Path

import click


def final_finetune(checkpoint, model, teacher, dest, task, overwrite):
    # Exactly one of --checkpoint / --model must be given.
    input_info = (checkpoint is not None) + (model is not None)
    if input_info == 2:
        # UsageError prints the message to the user along with the usage string.
        raise click.UsageError(
            "You should specify a checkpoint path with --checkpoint or a hub model name with --model, not both."
        )
    if input_info == 0:
        raise click.UsageError(
            "You should specify a checkpoint path with --checkpoint or a hub model name with --model."
        )

    if checkpoint is not None:
        src_path = Path(checkpoint)
        dest_key = src_path.parent.name

    if model is not None:
        dest_key = model.replace("/", "_")
        src_path = model

    task_rewrite = "squadv2" if task == "squad_v2" else task
    dest_path = (Path(dest).resolve()
                 / f"{task_rewrite}_test_final_fine_tune"
                 / ("fine_tuned_" + dest_key))

    if dest_path.exists():
        if not overwrite:
            raise click.ClickException(
                f"Destination path {dest_path} already exists")
        shutil.rmtree(dest_path)
    print("PROCESSING", dest_path)

    dest_path.mkdir(exist_ok=True, parents=True)
    # Record where this fine-tuned model came from.
    with (dest_path / "source.txt").open("w") as f:
        f.write(str(src_path))

    if checkpoint is not None:
        # Work on a scratch copy so the source checkpoint stays untouched.
        tmp_path = Path("tmp_finetune/").resolve()
        if tmp_path.exists():
            shutil.rmtree(tmp_path)
        shutil.copytree(src_path, tmp_path)
        src_path = tmp_path

        # Drop trainer state so the final fine-tune starts from a clean slate.
        files_to_remove = [
            "trainer_state.json", "training_args.bin", "scheduler.pt"
        ]
        for file_to_remove in files_to_remove:
            file_to_remove_ = tmp_path / file_to_remove
            if file_to_remove_.exists():
                file_to_remove_.unlink()

        if teacher is None:
            # Fall back to the teacher recorded in the checkpoint's sparse_args.json.
            with open(src_path / "sparse_args.json") as f:
                sparse_args = json.load(f)
                teacher = sparse_args["distil_teacher_name_or_path"]

    if model is not None and teacher is None:
        raise click.ClickException("Please specify a teacher with --teacher")

    if "squad" in task:
        QASparseXP.final_finetune(str(src_path),
                                  str(dest_path),
                                  task,
                                  teacher=teacher)
    elif task in ["mnli"]:
        GlueSparseXP.final_finetune(str(src_path),
                                    str(dest_path),
                                    task,
                                    teacher=teacher)
    else:
        raise Exception(f"Unknown task {task}")
Example no. 4
    # Module-level imports assumed by this method: json, shutil,
    # pathlib.Path, tempfile.TemporaryDirectory, and the transformers
    # Auto* / TFAuto* classes used below.
    def copy_model_files(self, force=False):
        modified = False

        src_path = self.checkpoint_path

        d = None
        try:
            if force or not (self.git_path / "tf_model.h5").exists() or not (
                    self.git_path / "pytorch_model.bin").exists():
                d = TemporaryDirectory()
                if self.task in self.QA_TASKS:
                    model = QASparseXP.compile_model(src_path,
                                                     dest_path=d.name)
                elif self.task in self.GLUE_TASKS:
                    model = GlueSparseXP.compile_model(src_path,
                                                       dest_path=d.name)
                elif self.task in self.SUMMARIZATION_TASKS:
                    model = SummarizationSparseXP.compile_model(
                        src_path, dest_path=d.name)
                else:
                    raise Exception(f"Unknown task {self.task}")

                model = optimize_model(model, "heads")
                model.save_pretrained(d.name)
                src_path = d.name
            if force or not (self.git_path / "tf_model.h5").exists():
                with TemporaryDirectory() as d2:
                    if self.task in self.QA_TASKS:
                        QASparseXP.final_fine_tune_bertarize(
                            src_path, d2, remove_head_pruning=True)
                        tf_model = TFAutoModelForQuestionAnswering.from_pretrained(
                            d2, from_pt=True)
                    elif self.task in self.GLUE_TASKS:
                        GlueSparseXP.final_fine_tune_bertarize(
                            src_path, d2, remove_head_pruning=True)
                        tf_model = TFAutoModelForSequenceClassification.from_pretrained(
                            d2, from_pt=True)
                    elif self.task in self.SUMMARIZATION_TASKS:
                        SummarizationSparseXP.final_fine_tune_bertarize(
                            src_path, d2, remove_head_pruning=True)
                        tf_model = TFAutoModelForSeq2SeqLM.from_pretrained(
                            d2, from_pt=True)
                    else:
                        raise Exception(f"Unknown task {self.task}")

                    tf_model.save_pretrained(self.git_path)
                    modified = True

            if force or not (self.git_path / "pytorch_model.bin").exists():
                if self.task in self.QA_TASKS:
                    model = AutoModelForQuestionAnswering.from_pretrained(
                        src_path)
                elif self.task in self.GLUE_TASKS:
                    model = AutoModelForSequenceClassification.from_pretrained(
                        src_path)
                elif self.task in self.SUMMARIZATION_TASKS:
                    model = AutoModelForSeq2SeqLM.from_pretrained(src_path)
                else:
                    raise Exception(f"Unknown task {self.task}")
                model.save_pretrained(self.git_path)
                modified = True

            src_path = Path(src_path)
            # get_copy_list() yields (files, destination directory) pairs.
            to_copy = self.get_copy_list()

            for files, dest in to_copy:
                dest.mkdir(exist_ok=True)
                for file in files:
                    if force or not (dest / file).exists():
                        shutil.copyfile(str(src_path / file), str(dest / file))
                        modified = True
        finally:
            if d is not None:
                d.cleanup()

        # Reload the config; it may have been changed by compilation / optimization (pruned_heads, gelu_patch, layer_norm_patch).
        with (self.git_path / "config.json").open() as f:
            self.checkpoint_info["config"] = json.load(f)

        return modified
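
The loop over to_copy implies that get_copy_list() returns (files, destination directory) pairs. A sketch of that contract, with the tokenizer file names borrowed from Example no. 2 (the real list is task-dependent, so treat the contents as an assumption):

from pathlib import Path

def get_copy_list(self):
    # Each entry pairs an iterable of file names with the directory they
    # should be copied into; the caller creates the directory and copies
    # any files that are missing (or all of them when force=True).
    tokenizer_files = ("special_tokens_map.json", "tokenizer_config.json", "vocab.txt")
    return [(tokenizer_files, Path(self.git_path))]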