def test_deepspeech_train():
    # do not call home to wandb
    os.environ['WANDB_MODE'] = 'dryrun'

    # hyperparams
    p = model.HParams(graphemes=datasets.YESNO_GRAPHEMES)
    tp = train.HParams(max_epochs=1, batch_size=8)

    # build
    m = model.DeepSpeech(p)
    trainset, testset = datasets.splits(datasets.YesNo(p), p)

    # train
    t = train.Trainer(m, trainset, testset, tp)
    t.train()
def load(run_path):
    "Load config and model from wandb"
    p, ptrain = utils.load_wandb_cfg(run_path)
    p, ptrain = model.HParams(**p), train.HParams(**ptrain)
    return utils.load_chkpt(model.DeepSpeech(p), run_path), ptrain
def test_hparams_nsteps_batch_too_large():
    # a batch that covers the whole trainset still yields one step per epoch
    trainset_size = 80
    tp = train.HParams(batch_size=80, max_epochs=10)
    assert tp.n_steps(trainset_size) == (80 / 80) * 10
def test_hparams_nsteps():
    trainset_size = 80
    tp = train.HParams(batch_size=2, max_epochs=10)
    assert tp.n_steps(trainset_size) == (80 / 2) * 10
def test_hparams_nsteps_last_batch_small():
    trainset_size = 48
    tp = train.HParams(batch_size=40, max_epochs=4)
    assert tp.n_steps(trainset_size) == 8
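
# The three n_steps tests above imply a simple contract: steps per epoch is
# the number of batches needed to cover the training set (a final, smaller
# batch still counts as a step), multiplied by max_epochs. The sketch below
# is an assumption consistent with those expected values, not necessarily
# how train.HParams.n_steps is implemented.
import math

def _n_steps_sketch(trainset_size, batch_size, max_epochs):
    steps_per_epoch = math.ceil(trainset_size / batch_size)
    return steps_per_epoch * max_epochs

def test_n_steps_sketch_matches_expectations():
    assert _n_steps_sketch(80, 80, 10) == 10
    assert _n_steps_sketch(80, 2, 10) == 400
    assert _n_steps_sketch(48, 40, 4) == 8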
def test_onecycle():
    cfg = train.HParams(batch_size=1, max_epochs=1)
    m = model.DeepSpeech(model.HParams())
    optimizer = torch.optim.SGD(m.parameters(), lr=cfg.learning_rate)
    schedule = utils.onecycle(optimizer, 9, cfg)
    assert schedule.total_steps == 9
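
# utils.onecycle is only checked via schedule.total_steps above, which
# suggests it wraps torch's OneCycleLR keyed on the step count. A minimal
# sketch under that assumption; the wrapper name and max_lr default are
# illustrative, and the real max_lr presumably comes from the HParams.
def _onecycle_sketch(optimizer, n_steps, max_lr=1e-3):
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr, total_steps=n_steps)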
def test_lrfinder():
    m = model.DeepSpeech(model.HParams())
    optimizer = torch.optim.SGD(m.parameters(), lr=1e-8)
    p = train.HParams(batch_size=1, max_epochs=1)
    schedule = utils.lrfinder(optimizer, 9, p)
    assert torch.isclose(torch.tensor(schedule.gamma), torch.tensor(10.))
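
# The gamma == 10 expectation above is consistent with an exponential lr
# sweep from the optimizer's initial lr (1e-8) up to a final lr of 10 over
# 9 steps: (10 / 1e-8) ** (1 / 9) == 10. A sketch under that assumption;
# the function name and final_lr are illustrative, not the repo's API.
def _lrfinder_sketch(optimizer, n_steps, final_lr=10.0):
    initial_lr = optimizer.param_groups[0]['lr']
    # per-step multiplier so that initial_lr * gamma ** n_steps == final_lr
    gamma = (final_lr / initial_lr) ** (1.0 / n_steps)
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)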