def sample_data():
    """Fetch and unpack the LapsBM CI sample dataset into the cache folder.

    The tarball is downloaded (resuming a partial download when possible)
    into the default cache folder and then extracted next to it.

    Returns:
        Path of the directory created by extracting the archive.
    """
    cache_dir = get_default_cache_folder()
    download_url(
        "https://github.com/scart97/lapsbm-backup/archive/refs/tags/lapsbm-ci.tar.gz",
        download_folder=cache_dir,
        resume=True,
    )
    archive_file = cache_dir / "lapsbm-backup-lapsbm-ci.tar.gz"
    extract_archive(archive_file, cache_dir)
    extracted_dir = cache_dir / "lapsbm-backup-lapsbm-ci"
    return extracted_dir
def download_checkpoint(name: NemoCheckpoint, checkpoint_folder: str = None) -> Path:
    """Download a NeMo checkpoint by identifier, reusing a cached copy if present.

    Args:
        name: Checkpoint identifier; its ``value`` is the download URL and the
            last URL path segment is used as the local file name.
        checkpoint_folder: Folder where the checkpoint will be saved. When
            ``None`` (the default), the default cache folder is used. The
            folder is created if it does not exist yet.

    Returns:
        Path to the saved checkpoint file.
    """
    if checkpoint_folder is None:
        checkpoint_folder = get_default_cache_folder()
    folder = Path(checkpoint_folder)
    # A caller-supplied folder may not exist yet; the download cannot write
    # into a missing directory, so create it up front.
    folder.mkdir(parents=True, exist_ok=True)

    url = name.value
    filename = url.split("/")[-1]
    checkpoint_path = folder / filename
    # Only download when the file is not already cached locally.
    if not checkpoint_path.exists():
        wget.download(url, out=str(checkpoint_path))
    return checkpoint_path
def test_expected_prediction_from_pretrained_model():
    """Check the QuartzNet5x5LS_En transcription of a known sample clip.

    The sample wav file is downloaded into the default cache folder and
    transcribed with a module loaded from the NeMo checkpoint. If any step
    raises ``HTTPError`` the test returns early without asserting anything.
    """
    wav_name = "f0001_us_f0001_00001.wav"
    try:
        cache_dir = get_default_cache_folder()
        # The URL pins a specific commit, so the clip content is stable.
        download_url(
            "https://github.com/fastaudio/10_Speakers_Sample/raw/76f365de2f4d282ec44450d68f5b88de37b8b7ad/train/f0001_us_f0001_00001.wav",
            download_folder=str(cache_dir),
            filename=wav_name,
            resume=True,
        )

        model = QuartznetModule.load_from_nemo(
            checkpoint_name=NemoCheckpoint.QuartzNet5x5LS_En
        )
        waveform, sample_rate = torchaudio.load(cache_dir / wav_name)
        assert sample_rate == 16000

        transcript = model.predict(waveform)[0].strip()
        assert transcript == "the world needs opportunities for new leaders and new ideas"
    except HTTPError:
        # NOTE(review): a network failure makes this test pass silently;
        # consider pytest.skip(...) so it is reported as skipped instead.
        return
def test_get_default_cache_folder():
    """Check that get_default_cache_folder returns a path that exists."""
    assert get_default_cache_folder().exists()