def test_script_module():
    """The TorchScript-compiled module must transcribe exactly like the eager one.

    A randomly initialised model is scripted with ``torch.jit.script`` and both
    versions are fed the same random batch; only the first prediction of each
    is compared.
    """
    eager_model = QuartznetModule(list(ascii_lowercase))
    scripted_model = torch.jit.script(eager_model)

    # 10 random signals of 1337 samples each; the contents are irrelevant,
    # only eager/scripted agreement matters.
    batch = torch.randn(10, 1337)

    eager_text = eager_model.predict(batch)[0]
    scripted_text = scripted_model.predict(batch)[0]
    assert eager_text == scripted_text
def test_change_vocab():
    """Replacing the vocabulary must update hparams, the text pipeline and the decoder."""
    module = QuartznetModule(list(ascii_lowercase))
    # A fresh list is handed to change_vocab, independent of the one compared below.
    module.change_vocab(list("abc"))

    assert module.hparams.initial_vocab_tokens == ["a", "b", "c"]

    # comparing to 10 to account for the 3 initial tokens plus
    # the few special tokens automatically added.
    size_limit = 10
    assert len(module.text_pipeline.vocab) < size_limit
    assert module.decoder.out_channels < size_limit
def test_dev_run_train(sample_manifest):
    """Smoke-test a ``fast_dev_run`` training pass on the sample manifest.

    Args:
        sample_manifest: pytest fixture; presumably the path to a small
            manifest file (TODO confirm against conftest). It is reused as the
            train, validation and test split.

    All available GPUs are used when CUDA is present. Otherwise the run falls
    back to CPU, because asking Lightning for ``gpus=-1`` on a machine without
    GPUs makes it reject the request instead of running the test.
    """
    module = QuartznetModule(list(ascii_lowercase))
    data = ManifestDatamodule(
        train_manifest=sample_manifest,
        val_manifest=sample_manifest,
        test_manifest=sample_manifest,
        # presumably forwarded to the DataLoaders: load in the main process
        num_workers=0,
    )
    trainer = pl.Trainer(
        fast_dev_run=True,
        logger=None,
        checkpoint_callback=None,
        gpus=-1 if torch.cuda.is_available() else None,
    )
    trainer.fit(module, datamodule=data)
def test_expected_prediction_from_pretrained_model():
    """Transcribe a known sample clip with a pretrained QuartzNet and check the text.

    The sample clip is fetched with ``download_url``. ``load_from_nemo`` is
    assumed to need the network as well (TODO confirm). If either raises
    ``HTTPError``, the test is reported as skipped. Returning silently would let
    pytest count a network outage as a pass. Only these two calls are guarded,
    so failures in decoding or in the assertions still surface normally.
    """
    folder = get_default_cache_folder()
    sample_name = "f0001_us_f0001_00001.wav"

    try:
        # Loading the sample file
        download_url(
            "https://github.com/fastaudio/10_Speakers_Sample/raw/76f365de2f4d282ec44450d68f5b88de37b8b7ad/train/f0001_us_f0001_00001.wav",
            download_folder=str(folder),
            filename=sample_name,
            resume=True,
        )
        # Preparing the model
        module = QuartznetModule.load_from_nemo(
            checkpoint_name=NemoCheckpoint.QuartzNet5x5LS_En
        )
    except HTTPError as err:
        # pytest.skip raises, so `module` is always bound past this point.
        pytest.skip(f"Could not download test resources: {err}")

    audio, sr = torchaudio.load(folder / sample_name)
    assert sr == 16000
    output = module.predict(audio)
    expected = "the world needs opportunities for new leaders and new ideas"
    assert output[0].strip() == expected
lr_unfrozen=4e-3, betas=[0.8, 0.5]) wandb.init(config=hyperparameter_defaults, project="thunder-speech", entity="madeupmasters") config = wandb.config full_dm = ManifestDatamodule( train_manifest="train/data/manifests/prepared_train_manifest.json", val_manifest="train/data/manifests/prepared_test_manifest.json", test_manifest="train/data/manifests/prepared_test_manifest.json", num_workers=24, bs=config.bs, ) model = QuartznetModule.load_from_nemo( checkpoint_name="QuartzNet15x5Base-En", ) labels = [ " ", "ɑ", "d", "m", "ɛ", "v", "ʒ", "ð", "ʃ", "θ", "g", "i", "ŋ",
def test_try_to_load_without_parameters_raises_error():
    """Calling ``load_from_nemo`` with no arguments must raise ``ValueError``."""
    # Callable form: pytest invokes the function itself and fails if no
    # ValueError is raised.
    pytest.raises(ValueError, QuartznetModule.load_from_nemo)