def prepare(E2E, args, mtlalpha=0.0):
    """Build a model and a padded random batch for the ASR tests.

    Args:
        E2E: Model identifier handed to ``dynamic_import_asr`` together with
            the ``"pytorch"`` backend name.
        args: Configuration object; ``mtlalpha`` and ``char_list`` are set on
            it in place before the model is constructed.
        mtlalpha: Value written to ``args.mtlalpha``.

    Returns:
        A tuple ``(model, x, ilens, y, data, args)`` where ``x`` is a
        ``(5, 40, 40)`` float tensor, ``ilens`` is a tensor of input lengths,
        ``y`` is a ``(5, 10)`` integer tensor, and ``data`` is a list of
        ``(utterance_id, info)`` pairs describing the per-utterance shapes.
    """
    idim, odim = 40, 5
    batchsize = 5
    ilens = [40, 30, 20, 15, 10]
    olens = [3, 9, 10, 2, 3]

    args.mtlalpha = mtlalpha
    args.char_list = ["a", "e", "i", "o", "u"]
    # The model is created before any random tensors are drawn so that the
    # sequence of random-number draws stays the same.
    model = dynamic_import_asr(E2E, "pytorch")(idim, odim, args)

    x = torch.randn(batchsize, 40, idim)

    # Labels land in 1 .. n_token - 1, so index 0 is never produced
    # (original note: avoid 0 for eps in ctc).
    n_token = odim - 1
    y = (torch.rand(batchsize, 10) * n_token % (n_token - 1)).long() + 1

    # Everything past each utterance's true length is overwritten with -1.
    for idx, (in_len, out_len) in enumerate(zip(ilens, olens)):
        x[idx, in_len:] = -1
        y[idx, out_len:] = -1

    data = [
        (
            "utt%d" % idx,
            {
                "input": [{"shape": [in_len, idim]}],
                "output": [{"shape": [out_len]}],
            },
        )
        for idx, (in_len, out_len) in enumerate(zip(ilens, olens))
    ]
    return model, x, torch.tensor(ilens), y, data, args
def test_train_pytorch_dtype(dtype, device, model, conf):
    """Run one training step of a model in the given dtype on the given device.

    Args:
        dtype: Name of a ``torch`` dtype attribute (e.g. ``"float16"``).
        device: Device string passed to ``torch.device``.
        model: Model identifier handed to ``dynamic_import_asr``.
        conf: Mapping of keyword arguments forwarded to the model's ``build``.

    The test is skipped when CUDA is requested but unavailable, when float16
    is requested on CPU, or when ``warp-rnnt`` is requested on CPU.
    """
    running_on_cpu = device == "cpu"

    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("no cuda device is available")
    if running_on_cpu and dtype == "float16":
        pytest.skip(
            "cpu float16 implementation is not available in pytorch yet")
    if running_on_cpu and conf.get("trans_type") == "warp-rnnt":
        pytest.skip("warp-rnnt is not supported in CPU mode")

    idim = odim = 10
    net = dynamic_import_asr(model, "pytorch").build(idim, odim, **conf)

    torch_dtype = getattr(torch, dtype)
    torch_device = torch.device(device)
    net.to(dtype=torch_dtype, device=torch_device)

    # A batch of two utterances with input lengths 10 and 7 and three
    # target labels each, drawn from 1 .. odim - 1.
    feats = torch.rand(2, 10, idim, dtype=torch_dtype, device=torch_device)
    feat_lens = torch.tensor([10, 7], device=torch_device)
    targets = torch.randint(1, odim, (2, 3), device=torch_device)

    optimizer = torch.optim.Adam(net.parameters())
    loss = net(feats, feat_lens, targets)
    assert loss.dtype == torch_dtype

    net.zero_grad()
    loss.backward()
    # At least one parameter must have received a gradient.
    assert any(param.grad is not None for param in net.parameters())
    optimizer.step()
def test_asr_build(name, backend):
    """Check that ``build`` stores the requested ``mtlalpha`` on the model.

    Args:
        name: Model identifier handed to ``dynamic_import_asr``.
        backend: Backend name handed to ``dynamic_import_asr``.
    """
    expected_mtlalpha = 0.123
    model_class = dynamic_import_asr(name, backend)
    model = model_class.build(
        10,
        10,
        mtlalpha=expected_mtlalpha,
        adim=4,
        eunits=3,
        dunits=3,
        elayers=2,
        dlayers=2,
    )
    assert model.mtlalpha == expected_mtlalpha
def test_asr_quantize(name, backend):
    """Check that a dynamically quantized model exposes a non-empty state dict.

    Args:
        name: Model identifier handed to ``dynamic_import_asr``.
        backend: Backend name handed to ``dynamic_import_asr``.
    """
    model_class = dynamic_import_asr(name, backend)
    model = model_class.build(
        10,
        10,
        mtlalpha=0.123,
        adim=4,
        eunits=2,
        dunits=2,
        elayers=1,
        dlayers=1,
    )
    # Only ``torch.nn.Linear`` modules are targeted, with int8 weights.
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )
    assert quantized_model.state_dict()