Example #1
import torch

from espnet.nets.asr_interface import dynamic_import_asr


def prepare(E2E, args, mtlalpha=0.0):
    # Build a small ASR model plus a synthetic padded batch for testing.
    # E2E is the registered model name (e.g. "transformer"); args is an
    # argparse.Namespace carrying the remaining hyperparameters.
    args.mtlalpha = mtlalpha
    args.char_list = ["a", "e", "i", "o", "u"]
    idim = 40
    odim = 5
    model = dynamic_import_asr(E2E, "pytorch")(idim, odim, args)
    batchsize = 5
    x = torch.randn(batchsize, 40, idim)
    ilens = [40, 30, 20, 15, 10]
    n_token = odim - 1
    # avoid 0 for eps in ctc
    y = (torch.rand(batchsize, 10) * n_token % (n_token - 1)).long() + 1
    olens = [3, 9, 10, 2, 3]
    for i in range(batchsize):
        # mask the region beyond each true length with -1
        x[i, ilens[i] :] = -1
        y[i, olens[i] :] = -1

    data = []
    for i in range(batchsize):
        data.append(
            (
                "utt%d" % i,
                {
                    "input": [{"shape": [ilens[i], idim]}],
                    "output": [{"shape": [olens[i]]}],
                },
            )
        )
    return model, x, torch.tensor(ilens), y, data, args
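The helper resolves the model class by name through dynamic_import_asr before instantiating it. A minimal sketch of that lookup, assuming ESPnet is installed and that "transformer" is a registered PyTorch ASR model name in this version:

from espnet.nets.asr_interface import dynamic_import_asr

# Look up the E2E class for the "pytorch" backend by name ("transformer" is an assumption).
E2E_class = dynamic_import_asr("transformer", "pytorch")
print(E2E_class.__module__)  # e.g. espnet.nets.pytorch_backend.e2e_asr_transformer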
Example #2
import pytest
import torch

from espnet.nets.asr_interface import dynamic_import_asr


def test_train_pytorch_dtype(dtype, device, model, conf):
    # Skip combinations that the current environment cannot run.
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("no cuda device is available")
    if device == "cpu" and dtype == "float16":
        pytest.skip(
            "cpu float16 implementation is not available in pytorch yet")
    if device == "cpu" and "trans_type" in conf and conf[
            "trans_type"] == "warp-rnnt":
        pytest.skip("warp-rnnt is not supported in CPU mode")

    idim = 10
    odim = 10
    model = dynamic_import_asr(model, "pytorch").build(idim, odim, **conf)
    dtype = getattr(torch, dtype)
    device = torch.device(device)
    model.to(dtype=dtype, device=device)

    x = torch.rand(2, 10, idim, dtype=dtype, device=device)
    ilens = torch.tensor([10, 7], device=device)
    y = torch.randint(1, odim, (2, 3), device=device)
    opt = torch.optim.Adam(model.parameters())
    # one forward/backward/update step in the requested dtype and on the requested device
    loss = model(x, ilens, y)
    assert loss.dtype == dtype
    model.zero_grad()
    loss.backward()
    assert any(p.grad is not None for p in model.parameters())
    opt.step()
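In the original test suite this function is presumably driven by pytest parametrization over dtype, device, model name, and build configuration. The decorators below are an illustrative sketch of such a setup, not the exact ones from the source file:

import pytest


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize(
    "model, conf",
    [("transformer", {"adim": 4, "eunits": 3, "dunits": 3, "elayers": 2, "dlayers": 2})],
)
def test_train_pytorch_dtype(dtype, device, model, conf):
    ...  # body as in the example above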
Example #3
from espnet.nets.asr_interface import dynamic_import_asr


def test_asr_build(name, backend):
    # Build a tiny model via the build() factory and check that keyword
    # options reach the constructed instance.
    model = dynamic_import_asr(name, backend).build(
        10, 10, mtlalpha=0.123, adim=4, eunits=3, dunits=3, elayers=2, dlayers=2
    )
    assert model.mtlalpha == 0.123
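build() lets a test construct the model from plain keyword arguments instead of a fully populated argparse namespace. A standalone usage sketch, with the name/backend pair chosen as an assumption:

from espnet.nets.asr_interface import dynamic_import_asr

model = dynamic_import_asr("transformer", "pytorch").build(
    10, 10, mtlalpha=0.123, adim=4, eunits=3, dunits=3, elayers=2, dlayers=2
)
print(type(model).__name__, model.mtlalpha)  # expect mtlalpha == 0.123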
Example #4
import torch

from espnet.nets.asr_interface import dynamic_import_asr


def test_asr_quantize(name, backend):
    # Build a tiny model and check that dynamic quantization of its Linear
    # layers still yields a usable state dict.
    model = dynamic_import_asr(name, backend).build(
        10, 10, mtlalpha=0.123, adim=4, eunits=2, dunits=2, elayers=1, dlayers=1
    )
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    assert quantized_model.state_dict()
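torch.quantization.quantize_dynamic swaps the listed module types (here torch.nn.Linear) for dynamically quantized counterparts: weights are stored as int8 and activations are quantized on the fly at inference time. A quick way to see the substitution, assuming model is the float model built in the test above:

import torch

quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
print(quantized_model)  # quantized layers appear as DynamicQuantizedLinear in the repr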