コード例 #1
0
def test_script_module():
    """Check that the scripted module predicts the same as the eager one.

    The module is compiled with ``torch.jit.script`` and both versions are
    fed the same random tensor; the first prediction of each must be equal.
    """
    eager_module = QuartznetModule(list(ascii_lowercase))
    scripted_module = torch.jit.script(eager_module)
    random_input = torch.randn(10, 1337)
    # Evaluated in order: eager first, scripted second.
    eager_first, scripted_first = (
        candidate.predict(random_input)[0]
        for candidate in (eager_module, scripted_module)
    )
    assert eager_first == scripted_first
コード例 #2
0
def test_change_vocab():
    """Check that ``change_vocab`` updates hparams, vocab and decoder size."""
    replacement_tokens = ("a", "b", "c")
    # Upper bound: the 3 replacement tokens plus the few special tokens
    # that are automatically added.
    size_upper_bound = 10

    module = QuartznetModule(list(ascii_lowercase))
    # A fresh list is passed and a fresh list is compared against, so the
    # check does not depend on whether the module keeps a reference to it.
    module.change_vocab(list(replacement_tokens))

    assert module.hparams.initial_vocab_tokens == list(replacement_tokens)
    assert len(module.text_pipeline.vocab) < size_upper_bound
    assert module.decoder.out_channels < size_upper_bound
コード例 #3
0
def test_dev_run_train(sample_manifest):
    """Smoke-test ``Trainer.fit`` in ``fast_dev_run`` mode.

    The same ``sample_manifest`` is used for the train, validation and test
    splits, and data loading uses ``num_workers=0``.

    NOTE(review): ``gpus=-1`` is passed to the trainer, so this test
    presumably needs a CUDA-capable machine -- confirm how it is gated.
    """
    module = QuartznetModule(list(ascii_lowercase))

    # Every split points at the same manifest.
    split_manifests = dict.fromkeys(
        ("train_manifest", "val_manifest", "test_manifest"),
        sample_manifest,
    )
    data = ManifestDatamodule(num_workers=0, **split_manifests)

    trainer_options = {
        "fast_dev_run": True,
        "logger": None,
        "checkpoint_callback": None,
        "gpus": -1,
    }
    trainer = pl.Trainer(**trainer_options)
    trainer.fit(module, datamodule=data)
コード例 #4
0
def test_expected_prediction_from_pretrained_model():
    """Check the pretrained QuartzNet5x5LS_En checkpoint on a sample file.

    A sample wav file is downloaded into the default cache folder, the
    checkpoint is loaded through ``load_from_nemo``, and the first
    prediction (stripped) must match the expected sentence.

    If any step raises ``HTTPError`` the test returns early without
    failing.
    """
    sample_name = "f0001_us_f0001_00001.wav"
    expected = "the world needs opportunities for new leaders and new ideas"

    try:
        # Fetch the sample file.
        cache_dir = get_default_cache_folder()
        download_url(
            "https://github.com/fastaudio/10_Speakers_Sample/raw/76f365de2f4d282ec44450d68f5b88de37b8b7ad/train/f0001_us_f0001_00001.wav",
            download_folder=str(cache_dir),
            filename=sample_name,
            resume=True,
        )

        # Load the model and the audio.
        pretrained = QuartznetModule.load_from_nemo(
            checkpoint_name=NemoCheckpoint.QuartzNet5x5LS_En
        )
        waveform, sample_rate = torchaudio.load(cache_dir / sample_name)
        assert sample_rate == 16000

        transcriptions = pretrained.predict(waveform)
        assert transcriptions[0].strip() == expected
    except HTTPError:
        # Best effort: an HTTP failure ends the test without an error.
        return
コード例 #5
0
                               lr_unfrozen=4e-3,
                               betas=[0.8, 0.5])

# Start the Weights & Biases run. `hyperparameter_defaults` is defined
# earlier in the script (outside this excerpt) and is passed as the run config.
wandb.init(config=hyperparameter_defaults,
           project="thunder-speech",
           entity="madeupmasters")
# Read hyperparameters back from wandb.config rather than from
# `hyperparameter_defaults` directly.
# NOTE(review): presumably so that sweep overrides take effect -- confirm.
config = wandb.config
# Datamodule built from the prepared manifests; the batch size comes from
# the wandb config.
# NOTE(review): val_manifest and test_manifest point at the same
# prepared_test_manifest.json file -- confirm this is intended.
full_dm = ManifestDatamodule(
    train_manifest="train/data/manifests/prepared_train_manifest.json",
    val_manifest="train/data/manifests/prepared_test_manifest.json",
    test_manifest="train/data/manifests/prepared_test_manifest.json",
    num_workers=24,
    bs=config.bs,
)

# Load the model from the "QuartzNet15x5Base-En" checkpoint, selected by name.
model = QuartznetModule.load_from_nemo(
    checkpoint_name="QuartzNet15x5Base-En", )

labels = [
    " ",
    "ɑ",
    "d",
    "m",
    "ɛ",
    "v",
    "ʒ",
    "ð",
    "ʃ",
    "θ",
    "g",
    "i",
    "ŋ",
コード例 #6
0
def test_try_to_load_without_parameters_raises_error():
    """Check that ``load_from_nemo`` with no arguments raises ``ValueError``."""
    load_without_arguments = QuartznetModule.load_from_nemo
    with pytest.raises(ValueError):
        load_without_arguments()