Example #1
0
def test_model():
    """Smoke-test every ASR/TTS model that is not flagged as invalid.

    For each task, every model name returned by the downloader is checked
    against its "valid" column. Models whose first "valid" entry is the
    string "false" are skipped. The rest are handed to the task-specific
    test helper.
    """
    downloader = ModelDownloader()

    for task in ("asr", "tts"):
        for name in downloader.query(task=task):
            is_valid = downloader.query("valid", name=name)[0] != "false"
            if not is_valid:
                continue
            print(f"#### Test {name} ####")

            if task == "tts":
                _tts(name)
            elif task == "asr":
                _asr(name)
            else:
                raise NotImplementedError(f"task={task}")
Example #2
0
def test_model():
    """Smoke-test every valid ASR/TTS model, one corpus at a time.

    Models are grouped by corpus so that the ``downloads`` cache directory can
    be wiped after each corpus, keeping disk usage bounded. Models whose
    first "valid" entry is the string "false" are skipped.
    """
    d = ModelDownloader()
    tasks = ["asr", "tts"]

    for task in tasks:
        # Iterating a bare set gives an arbitrary order. Sorting makes the
        # corpus order (and therefore any failure) reproducible between runs.
        # key=str keeps the sort from failing if a corpus value is not a str.
        for corpus in sorted(set(d.query("corpus", task=task)), key=str):
            for model_name in d.query(task=task, corpus=corpus):
                if d.query("valid", name=model_name)[0] == "false":
                    continue
                print(f"#### Test {model_name} ####")

                if task == "asr":
                    _asr(model_name)
                elif task == "tts":
                    _tts(model_name)
                else:
                    raise NotImplementedError(f"task={task}")

            # NOTE(kan-bayashi): remove and recreate cache dir to reduce the disk usage.
            # Nothing in this function creates "downloads" before the first
            # removal (e.g. if every model of the first corpus is skipped), so
            # only remove it when present instead of raising FileNotFoundError.
            if os.path.isdir("downloads"):
                shutil.rmtree("downloads")
            os.makedirs("downloads", exist_ok=True)
Example #3
0
def create_Readme_file(repo_name, model_name):
    """Render ``TEMPLATE_Readme.md`` into ``<repo_name>/README.md`` for a model.

    Each ``<add_...>`` placeholder in the template is replaced with metadata
    looked up for *model_name* through ``ModelDownloader.query``.

    Args:
        repo_name: Directory that receives ``README.md``. ``open`` does not
            create it, so it must already exist.
        model_name: Model identifier. The text before the first ``/`` is
            used as the user name.
    """
    d = ModelDownloader()
    corpus_name = d.query("corpus", name=model_name)[0]
    task_name = d.query("task", name=model_name)[0]
    # Keep only the part of the URL that precedes "files/".
    url_name = d.query("url", name=model_name)[0].split("files/")[0]
    user_name = model_name.split("/")[0]
    # Rewrite "jp" to "ja" (the ISO 639-1 code for Japanese).
    lang_name = d.query("lang", name=model_name)[0].replace("jp", "ja")

    # Replacement for <add_more_tags> per task. A task missing from this
    # mapping leaves the placeholder untouched. Values containing "\n- "
    # continue the template's "- item" list onto extra lines (assumes the
    # placeholder sits inside such a list; confirm against the template).
    more_tags = {
        "asr": "automatic-speech-recognition",
        "tts": "text-to-speech",
        "enh": "speech-enhancement\n- audio-to-audio",
    }

    # Context managers guarantee both files are closed, and README.md is
    # flushed, even if a replacement below raises. An explicit UTF-8 encoding
    # avoids depending on the platform's locale.
    with open("TEMPLATE_Readme.md", encoding="utf-8") as template_Readme:
        with open(
            repo_name + "/README.md", "w", encoding="utf-8"
        ) as new_Readme:
            for line in template_Readme:
                if "<add_more_tags>" in line and task_name in more_tags:
                    line = line.replace("<add_more_tags>",
                                        more_tags[task_name])
                if "<add_lang>" in line:
                    if lang_name == "multilingual":
                        lang_value = "en\n- zh\n- ja\n- multilingual"
                    else:
                        lang_value = lang_name
                    line = line.replace("<add_lang>", lang_value)
                line = line.replace("<add_model_name>", model_name)
                line = line.replace("<add_url>", url_name)
                line = line.replace("<add_name>", user_name)
                line = line.replace("<add_corpus>", corpus_name)
                line = line.replace("<add_task_name>", task_name.upper())
                line = line.replace("<add_recipe_task_name>",
                                    task_name.lower() + "1")
                if "<add_tts_reference>" in line:
                    # tts_reference is a module-level name defined outside
                    # this function. It is only looked up for TTS models.
                    reference = tts_reference if task_name == "tts" else ""
                    line = line.replace("<add_tts_reference>", reference)
                new_Readme.write(line)


def test_get_model_names_non_matching():
    """A name query for an unknown task returns an empty list."""
    names = ModelDownloader().query("name", task="dummy")
    assert names == []


def test_get_model_names_and_urls():
    """Smoke test: querying several columns at once must not raise."""
    columns = ["name", "url"]
    ModelDownloader().query(columns, task="asr")


def test_download_and_unpack_names_with_condition():
    """Smoke test: a name query filtered by task must not raise.

    Despite the test's name, only a query is performed here; nothing is
    downloaded or unpacked.
    """
    downloader = ModelDownloader()
    downloader.query("name", task="asr")