Example #1
import torch

import datasets
import model
import predict


def test_decode_argmax():
    batch_size = 3
    p = model.HParams(mixed_precision=False)
    tfm = datasets.transform(p)
    m = model.DeepSpeech(p)
    data = zip([tfm(torch.rand(p.sampling_rate)) for _ in range(batch_size)],
               ['yes', 'no', 'yes'])

    x, xn, y, yn = datasets.batch(p)(data)
    yhat, _ = predict.predict(m, x, xn)
    decoded = predict.decode_argmax(yhat, p)  # check the output is decodable
    assert len(decoded) == batch_size
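
predict.decode_argmax is presumably a greedy CTC decode. For reference, a minimal sketch of greedy CTC decoding, assuming yhat is a (B, W, n_graphemes) tensor of per-frame scores with the blank symbol at index 0 (both assumptions; the project's decoder may differ):

import torch

def greedy_ctc_decode(yhat, graphemes, blank=0):
    # take the best grapheme per frame, collapse repeats, drop blanks
    out = []
    for path in yhat.argmax(dim=-1):               # (W,) indices per example
        collapsed = torch.unique_consecutive(path).tolist()
        out.append(''.join(graphemes[i] for i in collapsed if i != blank))
    return out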
Example #2
import os

import datasets
import model
import train


def test_deepspeech_train():

    # do not call home to wandb
    os.environ['WANDB_MODE'] = 'dryrun'

    # hyperparams
    p = model.HParams(graphemes=datasets.YESNO_GRAPHEMES)
    tp = train.HParams(max_epochs=1, batch_size=8)

    # build
    m = model.DeepSpeech(p)
    trainset, testset = datasets.splits(datasets.YesNo(p), p)

    # train
    t = train.Trainer(m, trainset, testset, tp)
    t.train()
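
WANDB_MODE='dryrun' is the legacy spelling; newer wandb releases use 'offline' for the same effect, and runs can also be disabled programmatically. Both forms below are from wandb's public API:

import os

os.environ['WANDB_MODE'] = 'offline'  # current equivalent of 'dryrun'

# or, where the test controls the init call itself:
# import wandb
# wandb.init(mode='disabled')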
Example #3
import numpy as np
import torch

import datasets
import model


def test_deepspeech_fwd_augmented():
    batch_size = 5
    p = model.HParams()
    transform = datasets.transform(p)
    augment = datasets.spec_augment(p)
    m = model.DeepSpeech(p)

    # follow the same order as in data loader and trainer
    x = [augment(transform(torch.rand(p.sampling_rate))) for _ in range(batch_size)]
    y = np.random.choice(['yes', 'no'], batch_size)
    x, nx, y, ny = datasets.batch(p)(zip(x, y))
    x, _ = m.forward(x, nx, y, ny)

    # (B, T, n_graphemes): 49 output frames for one second of input audio
    assert x.shape == (
        batch_size,
        49,
        p.n_graphemes()
    )
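
datasets.spec_augment applies masking on top of the melspec transform. For orientation, a SpecAugment-style augmentation can be sketched with torchaudio's stock transforms; the mask widths here are illustrative, not the project's hyperparameters:

import torchaudio

def spec_augment_sketch(melspec, freq_mask=15, time_mask=35):
    # melspec: (H, W) mel bins x frames; zero out a random band of
    # frequencies and a random span of frames (Park et al., 2019)
    masked = torchaudio.transforms.FrequencyMasking(freq_mask)(melspec)
    return torchaudio.transforms.TimeMasking(time_mask)(masked)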
Example #4
import numpy as np
import torch

import datasets
import model


def test_deepspeech_fwd():
    batch_size = 5
    p = model.HParams()
    transform = datasets.transform(p)
    m = model.DeepSpeech(p)

    # follow the same order as in data loader and trainer
    x = [transform(torch.rand(p.sampling_rate)) for _ in range(batch_size)]
    y = np.random.choice(['yes', 'no'], batch_size)
    x, nx, y, ny = datasets.batch(p)(zip(x, y))

    "need: (B, H, W) batches of melspecs, (B, W) batches of graphemes."
    print(x.shape, nx, y.shape, ny)

    x, _ = m.forward(x, nx, y, ny)

    assert x.shape == (
        batch_size,
        49,
        p.n_graphemes()
    )
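
The datasets.batch(p) collate evidently pads variable-width melspecs into a (B, H, W) tensor, encodes the transcripts, and keeps the original lengths. A minimal sketch of such a collate, with a hypothetical grapheme_idx mapping (the project's encoding may differ):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_sketch(data, grapheme_idx):
    # data: iterable of (melspec (H, W), transcript) pairs
    xs, ys = zip(*data)
    nx = torch.tensor([x.shape[-1] for x in xs])  # frames per example
    ny = torch.tensor([len(y) for y in ys])       # graphemes per example
    # pad on the time axis: stack as (B, W_max, H), then back to (B, H, W_max)
    x = pad_sequence([x.T for x in xs], batch_first=True).transpose(1, 2)
    y = pad_sequence([torch.tensor([grapheme_idx[g] for g in t]) for t in ys],
                     batch_first=True)
    return x, nx, y, ny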
Example #5
import model


def test_deepspeech_modules_registered():
    m = model.DeepSpeech(model.HParams(n_layers=1, dilation_stacks=1))
    got = list(m.state_dict().keys())
    want = [
        'conv.weight',
        'conv.bias',
        'dense_a.weight',
        'dense_a.bias',
        'dense_b.weight',
        'dense_b.bias',
        'gru.weight_ih_l0',
        'gru.weight_hh_l0',
        'gru.bias_ih_l0',
        'gru.bias_hh_l0',
        'gru.weight_ih_l0_reverse',
        'gru.weight_hh_l0_reverse',
        'gru.bias_ih_l0_reverse',
        'gru.bias_hh_l0_reverse',
        'dense_end.weight',
        'dense_end.bias'
    ]

    assert got == want
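
The expected keys follow directly from how nn.Module registers parameters: each submodule's parameters appear in state_dict under the attribute name it was assigned to, so the key list doubles as a structural fingerprint of the network. A toy illustration:

import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 4, 3)   # -> conv.weight, conv.bias
        self.dense_a = nn.Linear(4, 2)   # -> dense_a.weight, dense_a.bias

print(list(Tiny().state_dict().keys()))
# ['conv.weight', 'conv.bias', 'dense_a.weight', 'dense_a.bias']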
Example #6
import model
import train
import utils


def load(run_path):
    "Load config and model checkpoint from a wandb run."
    p, ptrain = utils.load_wandb_cfg(run_path)
    p, ptrain = model.HParams(**p), train.HParams(**ptrain)
    return utils.load_chkpt(model.DeepSpeech(p), run_path), ptrain
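
Typical usage, assuming utils.load_chkpt returns the restored model; the run path below is hypothetical:

m, ptrain = load('entity/project/run_id')  # hypothetical wandb run path
m.eval()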
Example #7
import torch

import model
import train
import utils


def test_onecycle():
    cfg = train.HParams(batch_size=1, max_epochs=1)
    m = model.DeepSpeech(model.HParams())
    optimizer = torch.optim.SGD(m.parameters(), lr=cfg.learning_rate)
    schedule = utils.onecycle(optimizer, 9, cfg)
    assert schedule.total_steps == 9
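
utils.onecycle presumably wraps torch's built-in one-cycle schedule, whose total_steps attribute is what the test asserts on. A minimal sketch of such a wrapper, where using cfg.learning_rate as the peak lr is an assumption:

import torch

def onecycle_sketch(optimizer, steps, cfg):
    # warm up then anneal over exactly `steps` optimizer steps
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=cfg.learning_rate, total_steps=steps)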
Example #8
import torch

import model
import train
import utils


def test_lrfinder():
    m = model.DeepSpeech(model.HParams())
    optimizer = torch.optim.SGD(m.parameters(), lr=1e-8)
    p = train.HParams(batch_size=1, max_epochs=1)
    schedule = utils.lrfinder(optimizer, 9, p)
    assert torch.isclose(torch.tensor(schedule.gamma), torch.tensor(10.))
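
The gamma ≈ 10 assertion follows if the finder sweeps the learning rate exponentially from the optimizer's starting lr (1e-8 here) up to some maximum over the given number of steps: (lr_max / lr0) ** (1/9) = 10 when lr_max = 10. A sketch under that assumption:

import torch

def lrfinder_sketch(optimizer, steps, p, lr_max=10.):
    # exponential sweep: the lr multiplies by gamma each step,
    # reaching lr_max after `steps` steps; lr_max=10. matches the test
    lr0 = optimizer.param_groups[0]['lr']
    gamma = (lr_max / lr0) ** (1.0 / steps)
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)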