Beispiel #1
0
def test_spec_augmented_dataset():
    """Check that a SpecAugmented dataset built from three pairs has length 3."""
    hparams = model.HParams()
    labels = ['yes', 'no', 'yes']
    # One random 1-D tensor of length `sampling_rate` per label.
    signals = [torch.rand(hparams.sampling_rate) for _ in labels]
    pairs = zip(signals, labels)

    # The dataset is handed a mapping keyed by the tensors themselves.
    dataset = datasets.SpecAugmented(dict(pairs), hparams, masked=True)
    assert len(dataset) == 3
Beispiel #2
0
def test_splitting():
    """Check that two split fractions yield exactly two subsets."""
    hparams = model.HParams(splits=[0.6, 0.4])
    labels = ['yes', 'no', 'yes']
    # Materialise (random tensor, label) tuples up front.
    examples = [(torch.rand(hparams.sampling_rate), label) for label in labels]

    parts = datasets.splits(examples, hparams)
    assert len(parts) == 2
Beispiel #3
0
def test_batch_collation():
    """Check shapes, padding and lengths produced by the batch collate function."""
    hparams = model.HParams()
    width = hparams.sampling_rate
    labels = ['yes', 'no', 'yes']
    specs = [torch.rand(hparams.n_mels, width) for _ in labels]

    collate = datasets.batch(hparams)
    x, nx, y, ny = collate(zip(specs, labels))

    # Inputs are stacked along a new leading batch dimension.
    assert x.shape == (3, hparams.n_mels, width)
    # Targets are padded to the longest label (3 characters).
    assert y.shape == (3, 3)
    # 'no' is one character shorter, so its last slot holds 0.
    assert y[1, -1] == 0
    # Every input has the same width; target lengths follow the labels.
    assert nx.equal(torch.tensor([width] * 3))
    assert ny.equal(torch.tensor([3, 2, 3]))
Beispiel #4
0
def test_hparams_graphemes_idx():
    """Check the grapheme-to-index mapping for the YESNO grapheme set."""
    hparams = model.HParams(graphemes=datasets.YESNO_GRAPHEMES)
    mapping = hparams.graphemes_idx()

    # Expected graphemes listed in index order, starting at 0.
    ordered = ['ε', 'e', 'k', 'l', 'n', 'o', 'r', ' ']
    expected = {grapheme: index for index, grapheme in enumerate(ordered)}
    assert mapping == expected
Beispiel #5
0
def test_decode_argmax():
    """Check that predictions of an untrained model decode to one item per example."""
    n_examples = 3
    hparams = model.HParams(mixed_precision=False)
    to_features = datasets.transform(hparams)
    net = model.DeepSpeech(hparams)

    features = []
    for _ in range(n_examples):
        features.append(to_features(torch.rand(hparams.sampling_rate)))
    pairs = zip(features, ['yes', 'no', 'yes'])

    x, xn, y, yn = datasets.batch(hparams)(pairs)
    yhat, _ = predict.predict(net, x, xn)

    # Only decodability is checked, not the decoded content.
    decoded = predict.decode_argmax(yhat, hparams)
    assert len(decoded) == n_examples
Beispiel #6
0
def test_deepspeech_train():
    """Smoke test: train DeepSpeech on the YesNo dataset for one epoch."""
    # Put wandb in dry-run mode so the test does not call home.
    os.environ['WANDB_MODE'] = 'dryrun'

    # Model and training hyperparameters.
    model_cfg = model.HParams(graphemes=datasets.YESNO_GRAPHEMES)
    train_cfg = train.HParams(max_epochs=1, batch_size=8)

    # Build the network and the train/test subsets.
    net = model.DeepSpeech(model_cfg)
    subsets = datasets.splits(datasets.YesNo(model_cfg), model_cfg)
    trainset, testset = subsets

    # Run training; the test passes if nothing raises.
    trainer = train.Trainer(net, trainset, testset, train_cfg)
    trainer.train()
Beispiel #7
0
def test_deepspeech_fwd_augmented():
    """Check the forward-pass output shape on spec-augmented inputs."""
    n_examples = 5
    hparams = model.HParams()
    to_features = datasets.transform(hparams)
    augment = datasets.spec_augment(hparams)
    net = model.DeepSpeech(hparams)

    # Same order as the data loader and trainer: transform, augment, collate.
    features = []
    for _ in range(n_examples):
        signal = torch.rand(hparams.sampling_rate)
        features.append(augment(to_features(signal)))
    labels = np.random.choice(['yes', 'no'], n_examples)

    collate = datasets.batch(hparams)
    x, nx, y, ny = collate(zip(features, labels))
    out, _ = net.forward(x, nx, y, ny)

    # NOTE(review): 49 is presumably the output frame count for the default
    # HParams -- confirm against the model.
    expected_shape = (n_examples, 49, hparams.n_graphemes())
    assert out.shape == expected_shape
Beispiel #8
0
def test_deepspeech_fwd():
    """Check the forward-pass output shape on transformed (non-augmented) inputs."""
    n_examples = 5
    hparams = model.HParams()
    to_features = datasets.transform(hparams)
    net = model.DeepSpeech(hparams)

    # Same order as the data loader and trainer: transform, then collate.
    features = []
    for _ in range(n_examples):
        features.append(to_features(torch.rand(hparams.sampling_rate)))
    labels = np.random.choice(['yes', 'no'], n_examples)

    collate = datasets.batch(hparams)
    x, nx, y, ny = collate(zip(features, labels))

    "need: (B, H, W) batches of melspecs, (B, W) batches of graphemes."
    print(x.shape, nx, y.shape, ny)

    out, _ = net.forward(x, nx, y, ny)

    # NOTE(review): 49 is presumably the output frame count for the default
    # HParams -- confirm against the model.
    expected_shape = (n_examples, 49, hparams.n_graphemes())
    assert out.shape == expected_shape
Beispiel #9
0
def test_deepspeech_modules_registered():
    """Check the exact, ordered state-dict keys of a single-layer DeepSpeech."""
    hparams = model.HParams(n_layers=1, dilation_stacks=1)
    net = model.DeepSpeech(hparams)

    # Iterating the state dict yields its keys in registration order.
    registered = list(net.state_dict())

    expected = [
        'conv.weight', 'conv.bias',
        'dense_a.weight', 'dense_a.bias',
        'dense_b.weight', 'dense_b.bias',
        # forward direction of the GRU
        'gru.weight_ih_l0', 'gru.weight_hh_l0',
        'gru.bias_ih_l0', 'gru.bias_hh_l0',
        # reverse direction of the GRU
        'gru.weight_ih_l0_reverse', 'gru.weight_hh_l0_reverse',
        'gru.bias_ih_l0_reverse', 'gru.bias_hh_l0_reverse',
        'dense_end.weight', 'dense_end.bias',
    ]

    assert registered == expected
Beispiel #10
0
def test_ctc_collapse_batch():
    """Check CTC collapsing on strings with repeats and blank ('ε') symbols."""
    hparams = model.HParams(graphemes=np.array(['ε', 'a', 'b']))

    # (raw sequence, expected collapsed sequence)
    cases = [
        ('aababb', 'abab'),
        ('aababεbε', 'ababb'),
    ]
    for raw, collapsed in cases:
        assert predict.ctc_collapse_batch([raw], hparams) == [collapsed]
Beispiel #11
0
def test_onecycle():
    """Check that the one-cycle schedule is built with the requested step count."""
    train_cfg = train.HParams(batch_size=1, max_epochs=1)
    net = model.DeepSpeech(model.HParams())
    sgd = torch.optim.SGD(net.parameters(), lr=train_cfg.learning_rate)

    n_steps = 9
    scheduler = utils.onecycle(sgd, n_steps, train_cfg)
    assert scheduler.total_steps == n_steps
Beispiel #12
0
def test_hparams_blank():
    """Check that the first default grapheme is the 'ε' symbol."""
    first_grapheme = model.HParams().graphemes[0]
    assert first_grapheme == 'ε'
Beispiel #13
0
def test_hparams_override():
    """Check that a keyword passed to HParams is readable back as an attribute."""
    hparams = model.HParams(use_mixed_precision=False)
    flag = hparams.use_mixed_precision
    assert flag is False
Beispiel #14
0
def test_encode_texts():
    """Check that a batch of texts is encoded to per-text index tensors."""
    grapheme_index = model.HParams().graphemes_idx()
    encoded = utils.encode_texts(['aa', 'bb'], grapheme_index)

    # Stacked row-wise, the encodings should be [[1, 1], [2, 2]].
    expected = torch.tensor([[1, 1], [2, 2]])
    assert torch.stack(encoded).equal(expected)
Beispiel #15
0
def test_hparams():
    """Check default sampling rate and the derived FFT / frame-count values."""
    # NOTE(review): a second `test_hparams` is visible further down; if both
    # live in the same module the later definition shadows this one -- confirm.
    defaults = model.HParams()

    assert defaults.sampling_rate == 8000
    assert defaults.n_fft() == 160

    n_frames_in = 101
    assert defaults.n_downsampled_frames(n_frames_in) == 51
Beispiel #16
0
def test_decode_text():
    """Check that a list of grapheme indices decodes to the expected string."""
    indices = [4, 5, 6, 9, 14, 9, 20, 5, 12, 25, 27, 13, 1, 25, 2]
    text = utils.decode_text(indices, model.HParams())
    assert text == 'definitely mayb'
Beispiel #17
0
def test_lrfinder():
    """Check that the LR-finder schedule for 9 steps has a gamma close to 10."""
    net = model.DeepSpeech(model.HParams())
    sgd = torch.optim.SGD(net.parameters(), lr=1e-8)
    train_cfg = train.HParams(batch_size=1, max_epochs=1)

    scheduler = utils.lrfinder(sgd, 9, train_cfg)

    # Compare as tensors so torch's default tolerances apply.
    gamma = torch.tensor(scheduler.gamma)
    expected = torch.tensor(10.)
    assert torch.isclose(gamma, expected)
Beispiel #18
0
def test_hparams():
    """Check that dict(HParams) has no truthy 'graphemes_idx' entry."""
    # NOTE(review): another `test_hparams` is visible earlier; if both live in
    # the same module this definition shadows the earlier one -- confirm.
    as_dict = dict(model.HParams())
    entry = as_dict.get('graphemes_idx')
    assert not entry
Beispiel #19
0
def test_encode_text():
    """Check that a single text is encoded to the expected grapheme indices."""
    grapheme_index = model.HParams().graphemes_idx()
    encoded = utils.encode_text('definitely mayb', grapheme_index)

    expected = [4, 5, 6, 9, 14, 9, 20, 5, 12, 25, 27, 13, 1, 25, 2]
    assert encoded.tolist() == expected
Beispiel #20
0
def test_decode_texts():
    """Check that a batch of index sequences decodes to the expected strings."""
    batch = [[1, 2], [2, 1]]
    decoded = utils.decode_texts(batch, model.HParams())
    assert decoded == ['ab', 'ba']
Beispiel #21
0
def load(run_path):
    """Load config and model from wandb.

    Reads the model and training configs for ``run_path`` via
    ``utils.load_wandb_cfg``, builds both ``HParams`` objects and a
    ``DeepSpeech`` model, then passes the model through ``utils.load_chkpt``.

    Returns a tuple of (result of ``utils.load_chkpt``, training HParams).
    """
    model_cfg, train_cfg = utils.load_wandb_cfg(run_path)

    # Build both hyperparameter objects before constructing the model.
    model_hparams = model.HParams(**model_cfg)
    train_hparams = train.HParams(**train_cfg)

    net = model.DeepSpeech(model_hparams)
    restored = utils.load_chkpt(net, run_path)
    return restored, train_hparams
Beispiel #22
0
def test_greedy():
    """Check argmax decoding of a hand-built score tensor for one example."""
    hparams = model.HParams(graphemes=np.array(['ε', 'a', 'b']))

    # Per-step argmax is index 0, 1, 2, i.e. 'ε', 'a', 'b'.
    steps = [
        [0.5, 0.4, 0.1],
        [0.4, 0.5, 0.1],
        [0.3, 0.2, 0.6],
    ]
    scores = torch.tensor(steps).unsqueeze(0)  # B, W, C

    decoded = predict.decode_argmax(scores, hparams)
    assert decoded == ['ab']