Beispiel #1
0
def test_transformer_multi_gpu_trainable(model_dict):
    """Run one Adam step on a Transformer wrapped in ``DataParallel``.

    The model is replicated over CUDA device ids 0 and 1. The loss returned
    by the wrapped model is averaged with ``.mean()`` before backpropagation.
    When the model uses scaled positional encodings, the test also asserts
    that the ``alpha`` parameter of the last encoder and decoder embedding
    layers received a gradient.

    Args:
        model_dict: Keyword overrides forwarded to ``make_transformer_args``.
    """
    args = make_transformer_args(**model_dict)

    # Toy batch of four utterances with 40-dim input and output features.
    in_dim = out_dim = 40
    cuda = torch.device("cuda")
    batch = prepare_inputs(
        in_dim,
        out_dim,
        [10, 5, 10, 5],
        [20, 15, 20, 15],
        args["spk_embed_dim"],
        device=cuda,
    )

    # Replicate the model over the first two GPUs.
    n_gpus = 2
    net = torch.nn.DataParallel(
        Transformer(in_dim, out_dim, Namespace(**args)),
        list(range(n_gpus)),
    )
    net.to(cuda)
    opt = torch.optim.Adam(net.parameters())

    # A single optimisation step must run end to end.
    # The forward pass may yield one loss per replica, so average them.
    loss = net(**batch).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

    # The scaled positional encodings (last embed layer) must receive gradients.
    inner = net.module
    if inner.use_scaled_pos_enc:
        for stack in (inner.encoder, inner.decoder):
            assert stack.embed[-1].alpha.grad is not None
Beispiel #2
0
def test_transformer_gpu_trainable_and_decodable(model_dict):
    """Train one step on CUDA, then run inference and attention extraction.

    The test proceeds in three stages:

    1. A single Adam step on a toy batch must succeed.
    2. If scaled positional encodings are enabled, the ``alpha`` parameter of
       the last encoder and decoder embedding layers must receive a gradient.
    3. In eval mode and under ``torch.no_grad()``, ``inference`` runs on the
       first (length-trimmed) utterance and ``calculate_all_attentions`` runs
       on the whole batch.

    Args:
        model_dict: Keyword overrides forwarded to ``make_transformer_args``.
    """
    args = make_transformer_args(**model_dict)
    decode_args = make_inference_args()

    # Toy batch of four utterances with 40-dim input and output features.
    in_dim = out_dim = 40
    cuda = torch.device("cuda")
    batch = prepare_inputs(
        in_dim,
        out_dim,
        [10, 5, 10, 5],
        [20, 15, 20, 15],
        args["spk_embed_dim"],
        device=cuda,
    )

    net = Transformer(in_dim, out_dim, Namespace(**args))
    net.to(cuda)
    opt = torch.optim.Adam(net.parameters())

    # A single optimisation step must run end to end.
    loss = net(**batch).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

    # The scaled positional encodings (last embed layer) must receive gradients.
    if net.use_scaled_pos_enc:
        for stack in (net.encoder, net.decoder):
            assert stack.embed[-1].alpha.grad is not None

    # Decoding must work in eval mode without gradient tracking.
    net.eval()
    with torch.no_grad():
        # Trim the first utterance to its true length before decoding.
        first_len = batch["ilens"][0]
        first_feats = batch["xs"][0][:first_len]
        # Speaker embeddings are only present when spk_embed_dim is set.
        speaker = None if args["spk_embed_dim"] is None else batch["spembs"][0]
        net.inference(first_feats, Namespace(**decode_args), spemb=speaker)
        net.calculate_all_attentions(**batch)