Code Example #1
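The methods below read like an excerpt of a converter class (plausibly TatoebaConverter from HuggingFace Transformers' Marian/Tatoeba conversion script); the class statement, the imports, and several project-local helpers are not part of the excerpt. A minimal reconstruction of that missing context follows, with the class name as an assumption:

import os
from pathlib import Path
from typing import Tuple

import pandas as pd

# Project-local names the excerpt calls but does not define:
# remove_prefix, remove_suffix, convert_all_sentencepiece_models,
# l2front_matter, get_system_metadata, DEFAULT_REPO, LANG_CODE_URL,
# LANG_CODE_PATH, ISO_PATH, GROUP_MEMBERS, FRONT_MATTER_TEMPLATE,
# plus class methods not shown: make_tatoeba_registry, download_metadata,
# get_two_letter_code, resolve_lang_code.


class TatoebaConverter:  # class name assumed; the excerpt omits this line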
    def convert_models(self, tatoeba_ids, dry_run=False):
        entries_to_convert = [x for x in self.registry if x[0] in tatoeba_ids]
        converted_paths = convert_all_sentencepiece_models(entries_to_convert, dest_dir=self.model_card_dir)

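        # Rename each converted directory from three-letter to two-letter
        # language codes and write its model card.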
        for path in converted_paths:
            long_pair = remove_prefix(path.name, "opus-mt-").split("-")  # e.g. heb-eng
            assert len(long_pair) == 2
            new_p_src = self.get_two_letter_code(long_pair[0])
            new_p_tgt = self.get_two_letter_code(long_pair[1])
            hf_model_id = f"opus-mt-{new_p_src}-{new_p_tgt}"
            new_path = path.parent.joinpath(hf_model_id)  # opus-mt-he-en
            os.rename(str(path), str(new_path))
            self.write_model_card(hf_model_id, dry_run=dry_run)

    def __init__(self, save_dir="marian_converted"):
        assert Path(DEFAULT_REPO).exists(), (
            "need git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git"
        )
        reg = self.make_tatoeba_registry()
        self.download_metadata()
        self.registry = reg
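        # View the registry as a DataFrame; each language-pair id must be
        # unique.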
        reg_df = pd.DataFrame(
            reg, columns=["id", "prepro", "url_model", "url_test_set"])
        assert reg_df.id.value_counts().max() == 1
        reg_df = reg_df.set_index("id")
        reg_df["src"] = reg_df.reset_index().id.apply(
            lambda x: x.split("-")[0]).values
        reg_df["tgt"] = reg_df.reset_index().id.apply(
            lambda x: x.split("-")[1]).values

        released_cols = [
            "url_base",
            "pair",  # (ISO639-3/ISO639-5 codes),
            "short_pair",  # (reduced codes),
            "chrF2_score",
            "bleu",
            "brevity_penalty",
            "ref_len",
            "src_name",
            "tgt_name",
        ]

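        # Read Tatoeba-Challenge's released-models listing (tab-separated,
        # no header row); the final row is dropped, presumably a trailing or
        # incomplete line.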
        released = pd.read_csv("Tatoeba-Challenge/models/released-models.txt",
                               sep="\t",
                               header=None).iloc[:-1]
        released.columns = released_cols
        released["fname"] = released["url_base"].apply(lambda x: remove_suffix(
            remove_prefix(
                x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"),
            ".zip"))

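        # Flag "2m" releases and parse the release date embedded in the
        # filename, e.g. "2m-2020-05-17" -> "2020-05-17".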
        released["2m"] = released.fname.str.startswith("2m")
        released["date"] = pd.to_datetime(released["fname"].apply(
            lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))

        released["base_ext"] = released.url_base.apply(lambda x: Path(x).name)
        reg_df["base_ext"] = reg_df.url_model.apply(lambda x: Path(x).name)

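        # Join the registry against the released listing on the archive
        # filename and the language-pair id.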
        metadata_new = reg_df.reset_index().merge(
            released.rename(columns={"pair": "id"}), on=["base_ext", "id"])

        metadata_renamer = {
            "src": "src_alpha3",
            "tgt": "tgt_alpha3",
            "id": "long_pair",
            "date": "train_date"
        }
        metadata_new = metadata_new.rename(columns=metadata_renamer)

        metadata_new["src_alpha2"] = metadata_new.short_pair.apply(
            lambda x: x.split("-")[0])
        metadata_new["tgt_alpha2"] = metadata_new.short_pair.apply(
            lambda x: x.split("-")[1])
        DROP_COLS_BOTH = ["url_base", "base_ext", "fname"]

        metadata_new = metadata_new.drop(columns=DROP_COLS_BOTH)
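        # isin([]) is always False: a placeholder meaning no pair currently
        # prefers the older release.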
        metadata_new["prefer_old"] = metadata_new.long_pair.isin([])
        self.metadata = metadata_new
        assert (
            self.metadata.short_pair.value_counts().max() == 1
        ), "Multiple metadata entries for a short pair"
        self.metadata = self.metadata.set_index("short_pair")

        # wget.download(LANG_CODE_URL)
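        # Build the ISO 639-3 -> ISO 639-1 mapping, presumably consumed by
        # get_two_letter_code() in convert_models above.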
        mapper = pd.read_csv(LANG_CODE_PATH)
        mapper.columns = ["a3", "a2", "ref"]
        self.iso_table = pd.read_csv(
            ISO_PATH, sep="\t").rename(columns=lambda x: x.lower())
        more_3_to_2 = self.iso_table.set_index("id").part1.dropna().to_dict()
        more_3_to_2.update(mapper.set_index("a3").a2.to_dict())
        self.alpha3_to_alpha2 = more_3_to_2
        self.model_card_dir = Path(save_dir)
        self.constituents = GROUP_MEMBERS

    def write_model_card(
        self,
        hf_model_id: str,
        repo_root=DEFAULT_REPO,
        dry_run=False,
    ) -> Tuple[str, dict]:
        """Copy the most recent model's readme section from opus, and add metadata.
        upload command: aws s3 sync model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
        """
        short_pair = remove_prefix(hf_model_id, "opus-mt-")
        extra_metadata = self.metadata.loc[short_pair].drop("2m")
        extra_metadata["short_pair"] = short_pair
        lang_tags, src_multilingual, tgt_multilingual = self.resolve_lang_code(
            extra_metadata)
        opus_name = f"{extra_metadata.src_alpha3}-{extra_metadata.tgt_alpha3}"
        # opus_name: str = self.convert_hf_name_to_opus_name(hf_model_name)

        assert repo_root in ("OPUS-MT-train", "Tatoeba-Challenge")
        opus_readme_path = Path(repo_root).joinpath("models", opus_name,
                                                    "README.md")
        assert opus_readme_path.exists(), f"Readme file {opus_readme_path} not found"

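        # Each side of the OPUS name may hold several language codes joined
        # by "+"; split them out.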
        opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]

        readme_url = f"https://github.com/Helsinki-NLP/{repo_root}/tree/master/models/{opus_name}/README.md"

        s, t = ",".join(opus_src), ",".join(opus_tgt)

        metadata = {
            "hf_name": short_pair,
            "source_languages": s,
            "target_languages": t,
            "opus_readme_url": readme_url,
            "original_repo": repo_root,
            "tags": ["translation"],
            "languages": lang_tags,
        }
        lang_tags = l2front_matter(lang_tags)
        metadata["src_constituents"] = self.constituents[s]
        metadata["tgt_constituents"] = self.constituents[t]
        metadata["src_multilingual"] = src_multilingual
        metadata["tgt_multilingual"] = tgt_multilingual

        metadata.update(extra_metadata)
        metadata.update(get_system_metadata(repo_root))

        # combine with Tatoeba markdown

        extra_markdown = f"### {short_pair}\n\n* source group: {metadata['src_name']} \n* target group: {metadata['tgt_name']} \n*  OPUS readme: [{opus_name}]({readme_url})\n"

        content = opus_readme_path.read_text()
        content = content.split(
            "\n# "
        )[-1]  # Get the lowest level 1 header in the README -- the most recent model.
        splat = content.split("*")[2:]

        content = "*".join(splat)
        # BETTER FRONT MATTER LOGIC

        content = (FRONT_MATTER_TEMPLATE.format(lang_tags) + extra_markdown +
                   "\n* " + content.replace("download", "download original weights"))

        items = "\n\n".join([f"- {k}: {v}" for k, v in metadata.items()])
        sec3 = "\n### System Info: \n" + items
        content += sec3
        if dry_run:
            return content, metadata
        sub_dir = self.model_card_dir / hf_model_id
        sub_dir.mkdir(parents=True, exist_ok=True)
        dest = sub_dir / "README.md"
        dest.open("w").write(content)
        pd.Series(metadata).to_json(sub_dir / "metadata.json")
        return content, metadata
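

# A hypothetical usage sketch: the TatoebaConverter name and a local
# Tatoeba-Challenge checkout are assumptions carried over from the
# reconstruction at the top of this example.
converter = TatoebaConverter(save_dir="marian_converted")
# dry_run=True still converts and renames the model directory, but skips
# writing README.md and metadata.json.
converter.convert_models(["heb-eng"], dry_run=True)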