def test_decode_argmax():
    batch_size = 3
    p = model.HParams(mixed_precision=False)
    tfm = datasets.transform(p)
    m = model.DeepSpeech(p)
    data = zip(
        [tfm(torch.rand(p.sampling_rate)) for x in range(batch_size)],
        ['yes', 'no', 'yes'],
    )
    x, xn, y, yn = datasets.batch(p)(data)
    yhat, _ = predict.predict(m, x, xn)
    decoded = predict.decode_argmax(yhat, p)

    # make sure we are decodable
    assert len(decoded) == batch_size


def test_deepspeech_train():
    # do not call home to wandb
    os.environ['WANDB_MODE'] = 'dryrun'

    # hyperparams
    p = model.HParams(graphemes=datasets.YESNO_GRAPHEMES)
    tp = train.HParams(max_epochs=1, batch_size=8)

    # build
    m = model.DeepSpeech(p)
    trainset, testset = datasets.splits(datasets.YesNo(p), p)

    # train
    t = train.Trainer(m, trainset, testset, tp)
    t.train()


def test_deepspeech_fwd_augmented():
    batch_size = 5
    p = model.HParams()
    transform = datasets.transform(p)
    augment = datasets.spec_augment(p)
    m = model.DeepSpeech(p)

    # follow the same order as in data loader and trainer
    x = [augment(transform(torch.rand(p.sampling_rate))) for x in range(batch_size)]
    y = np.random.choice(['yes', 'no'], batch_size)
    x, nx, y, ny = datasets.batch(p)(zip(x, y))
    x, _ = m.forward(x, nx, y, ny)
    assert x.shape == (batch_size, 49, p.n_graphemes())


def test_deepspeech_fwd():
    batch_size = 5
    p = model.HParams()
    transform = datasets.transform(p)
    m = model.DeepSpeech(p)

    # follow the same order as in data loader and trainer
    x = [transform(torch.rand(p.sampling_rate)) for x in range(batch_size)]
    y = np.random.choice(['yes', 'no'], batch_size)
    x, nx, y, ny = datasets.batch(p)(zip(x, y))

    # need: (B, H, W) batches of melspecs, (B, W) batches of graphemes.
    print(x.shape, nx, y.shape, ny)
    x, _ = m.forward(x, nx, y, ny)
    assert x.shape == (batch_size, 49, p.n_graphemes())


def test_deepspeech_modules_registered():
    m = model.DeepSpeech(model.HParams(n_layers=1, dilation_stacks=1))
    got = list(m.state_dict().keys())
    want = [
        'conv.weight', 'conv.bias',
        'dense_a.weight', 'dense_a.bias',
        'dense_b.weight', 'dense_b.bias',
        'gru.weight_ih_l0', 'gru.weight_hh_l0',
        'gru.bias_ih_l0', 'gru.bias_hh_l0',
        'gru.weight_ih_l0_reverse', 'gru.weight_hh_l0_reverse',
        'gru.bias_ih_l0_reverse', 'gru.bias_hh_l0_reverse',
        'dense_end.weight', 'dense_end.bias',
    ]
    assert got == want


def load(run_path):
    "Load config and model from wandb"
    p, ptrain = utils.load_wandb_cfg(run_path)
    p, ptrain = model.HParams(**p), train.HParams(**ptrain)
    return utils.load_chkpt(model.DeepSpeech(p), run_path), ptrain


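# A hypothetical usage sketch (not part of the original module) showing how
# load() might be combined with predict.predict and predict.decode_argmax, as
# exercised in test_decode_argmax above. run_path and wav are supplied by the
# caller; the dummy 'yes' label only satisfies the batching API.
def example_transcribe(run_path, wav):
    p, _ = utils.load_wandb_cfg(run_path)
    p = model.HParams(**p)
    m, _ = load(run_path)
    tfm = datasets.transform(p)
    x, xn, _, _ = datasets.batch(p)(zip([tfm(wav)], ['yes']))
    yhat, _ = predict.predict(m, x, xn)
    return predict.decode_argmax(yhat, p)

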
def test_onecycle():
    cfg = train.HParams(batch_size=1, max_epochs=1)
    m = model.DeepSpeech(model.HParams())
    optimizer = torch.optim.SGD(m.parameters(), lr=cfg.learning_rate)
    schedule = utils.onecycle(optimizer, 9, cfg)
    assert schedule.total_steps == 9


def test_lrfinder():
    m = model.DeepSpeech(model.HParams())
    optimizer = torch.optim.SGD(m.parameters(), lr=1e-8)
    p = train.HParams(batch_size=1, max_epochs=1)
    schedule = utils.lrfinder(optimizer, 9, p)
    assert torch.isclose(torch.tensor(schedule.gamma), torch.tensor(10.))


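# A standalone sketch (not part of the original tests) of why the expected
# gamma above is 10: sweeping exponentially from lr=1e-8 over 9 steps with
# gamma=10 ends at 1e-8 * 10**9 = 10. It assumes utils.lrfinder builds an
# ExponentialLR-style schedule; torch.optim.lr_scheduler.ExponentialLR is used
# here only to illustrate the arithmetic.
def example_lr_sweep():
    w = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([w], lr=1e-8)
    schedule = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=10.)
    for _ in range(9):
        optimizer.step()
        schedule.step()
    # after 9 steps the learning rate has swept up to ~10
    assert abs(optimizer.param_groups[0]['lr'] - 10.) < 1e-6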