コード例 #1
0
def sample_data():
    """Fetch and unpack the LapsBM CI sample archive into the cache folder.

    The archive is downloaded (resuming a partial download if one exists) into
    ``get_default_cache_folder()`` and then extracted in place.

    Returns:
        Path inside the cache folder named ``lapsbm-backup-lapsbm-ci``, which is
        where the archive's contents are expected to land after extraction.
    """
    cache_dir = get_default_cache_folder()
    archive_url = (
        "https://github.com/scart97/lapsbm-backup/archive/refs/tags/lapsbm-ci.tar.gz"
    )
    download_url(archive_url, download_folder=cache_dir, resume=True)

    archive_file = cache_dir / "lapsbm-backup-lapsbm-ci.tar.gz"
    extract_archive(archive_file, cache_dir)

    extracted_dir = cache_dir / "lapsbm-backup-lapsbm-ci"
    return extracted_dir
コード例 #2
0
def download_checkpoint(name: NemoCheckpoint, checkpoint_folder: str = None) -> Path:
    """Download a NeMo checkpoint by identifier, reusing an existing local copy.

    Args:
        name: Checkpoint identifier (a ``NemoCheckpoint`` member); its ``value``
            is used as the download URL.
        checkpoint_folder: Folder where the checkpoint will be saved to. When
            ``None``, ``get_default_cache_folder()`` is used. The folder is
            created (including parents) if it does not exist yet.

    Returns:
        Path to the saved checkpoint file. It is named after the last path
        component of the URL.
    """
    if checkpoint_folder is None:
        checkpoint_folder = get_default_cache_folder()

    url = name.value
    filename = url.split("/")[-1]
    checkpoint_path = Path(checkpoint_folder) / filename
    if not checkpoint_path.exists():
        # A caller-supplied folder may not exist yet; make sure the download
        # has a destination directory before handing the path to wget.
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        wget.download(url, out=str(checkpoint_path))

    return checkpoint_path
コード例 #3
0
def test_expected_prediction_from_pretrained_model():
    """Check that QuartzNet5x5LS_En transcribes a known sample correctly.

    If fetching the sample audio or loading the pretrained model raises
    ``HTTPError``, the test is reported as skipped rather than silently passing.
    Failures in the prediction itself are never masked.
    """
    # pytest reports a raised ``unittest.SkipTest`` as a skipped test.
    from unittest import SkipTest

    folder = get_default_cache_folder()
    try:
        # Loading the sample file
        download_url(
            "https://github.com/fastaudio/10_Speakers_Sample/raw/76f365de2f4d282ec44450d68f5b88de37b8b7ad/train/f0001_us_f0001_00001.wav",
            download_folder=str(folder),
            filename="f0001_us_f0001_00001.wav",
            resume=True,
        )
        # Loading the pretrained model (this may fetch the checkpoint over the network)
        module = QuartznetModule.load_from_nemo(
            checkpoint_name=NemoCheckpoint.QuartzNet5x5LS_En
        )
    except HTTPError as err:
        raise SkipTest(f"Could not download test resources: {err}") from err

    audio, sr = torchaudio.load(folder / "f0001_us_f0001_00001.wav")
    assert sr == 16000

    output = module.predict(audio)
    expected = "the world needs opportunities for new leaders and new ideas"
    assert output[0].strip() == expected
コード例 #4
0
ファイル: test_utils.py プロジェクト: scart97/thunder-speech
def test_get_default_cache_folder():
    """The default cache folder must exist on disk and be a directory."""
    path = get_default_cache_folder()
    assert path.exists()
    # ``exists()`` alone would also accept a regular file at this path.
    assert path.is_dir()