Example #1

# Imports assumed for this snippet: it comes from ESPnet's TTS Transformer
# tests, where make_transformer_args, make_inference_args, and prepare_inputs
# are helper functions defined in the same test module, and model_dict is
# supplied by pytest parametrization.
import torch
from argparse import Namespace
from espnet.nets.pytorch_backend.e2e_tts_transformer import Transformer

def test_transformer_gpu_trainable_and_decodable(model_dict):
    # make args
    model_args = make_transformer_args(**model_dict)
    inference_args = make_inference_args()

    idim = 40
    odim = 40
    ilens = [10, 5, 10, 5]
    olens = [20, 15, 20, 15]
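    # NOTE: a CUDA device is assumed here; upstream, this test is typically
    # guarded by a pytest skipif when torch.cuda.is_available() is False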
    device = torch.device("cuda")
    batch = prepare_inputs(idim,
                           odim,
                           ilens,
                           olens,
                           model_args["spk_embed_dim"],
                           device=device)

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters())

    # trainable
    loss = model(**batch).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # check gradient of ScaledPositionalEncoding
    if model.use_scaled_pos_enc:
        assert model.encoder.embed[-1].alpha.grad is not None
        assert model.decoder.embed[-1].alpha.grad is not None

    # decodable
    model.eval()
    with torch.no_grad():
        if model_args["spk_embed_dim"] is None:
            spemb = None
        else:
            spemb = batch["spembs"][0]
        model.inference(
            batch["xs"][0][:batch["ilens"][0]],
            Namespace(**inference_args),
            spemb=spemb,
        )
        model.calculate_all_attentions(**batch)
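
Both examples rely on a prepare_inputs helper that builds a random padded batch. The sketch below is a minimal reconstruction based only on how the batch is consumed above (token-ID inputs xs, padded target features ys, stop-token labels, length tensors, and optional speaker embeddings spembs); the exact helper lives in the ESPnet test module, so the shapes and padding choices here are assumptions, not the original code.

import torch

def prepare_inputs(idim, odim, ilens, olens,
                   spk_embed_dim=None, device=torch.device("cpu")):
    """Build a random padded batch shaped like the one the tests consume."""
    ilens_t = torch.LongTensor(ilens).to(device)
    olens_t = torch.LongTensor(olens).to(device)
    # token-ID input sequences, zero-padded to the longest in the batch
    xs = torch.zeros(len(ilens), max(ilens), dtype=torch.long, device=device)
    for i, l in enumerate(ilens):
        xs[i, :l] = torch.randint(1, idim, (l,))
    # target acoustic feature sequences, zero-padded likewise
    ys = torch.zeros(len(olens), max(olens), odim, device=device)
    for i, l in enumerate(olens):
        ys[i, :l] = torch.randn(l, odim)
    # stop-token labels: 1.0 from the final valid frame onward
    labels = ys.new_zeros(ys.size(0), ys.size(1))
    for i, l in enumerate(olens):
        labels[i, l - 1:] = 1.0
    batch = {"xs": xs, "ilens": ilens_t, "ys": ys,
             "labels": labels, "olens": olens_t}
    if spk_embed_dim is not None:
        batch["spembs"] = torch.randn(len(ilens), spk_embed_dim, device=device)
    return batch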
Example #2

# Imports assumed for this snippet (same ESPnet test module as Example #1;
# make_transformer_args and prepare_inputs are helpers defined alongside it,
# and model_dict is again supplied by pytest parametrization):
import numpy as np
import torch
from argparse import Namespace
from espnet.nets.pytorch_backend.e2e_tts_transformer import Transformer
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

def test_forward_and_inference_are_equal(model_dict):
    # make args
    model_args = make_transformer_args(dprenet_dropout_rate=0.0, **model_dict)

    # setup batch
    idim = 40
    odim = 40
    ilens = [60]
    olens = [60]
    batch = prepare_inputs(idim, odim, ilens, olens)
    xs = batch["xs"]
    ilens = batch["ilens"]
    ys = batch["ys"]
    olens = batch["olens"]

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))
    model.eval()

    # TODO(kan-bayashi): update following ugly part
    with torch.no_grad():
        # --------- forward calculation ---------
        x_masks = model._source_mask(ilens)
        hs_fp, h_masks = model.encoder(xs, x_masks)
        if model.reduction_factor > 1:
            ys_in = ys[:, model.reduction_factor - 1::model.reduction_factor]
            olens_in = olens.new(
                [olen // model.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens
        ys_in = model._add_first_frame_and_remove_last_frame(ys_in)
        y_masks = model._target_mask(olens_in)
        zs, _ = model.decoder(ys_in, y_masks, hs_fp, h_masks)
        before_outs = model.feat_out(zs).view(zs.size(0), -1, model.odim)
        logits = model.prob_out(zs).view(zs.size(0), -1)
        after_outs = before_outs + model.postnet(
            before_outs.transpose(1, 2)).transpose(1, 2)
        # --------- forward calculation ---------

        # --------- inference calculation ---------
        hs_ir, _ = model.encoder(xs, None)
        maxlen = ys_in.shape[1]
        minlen = ys_in.shape[1]
        idx = 0
        # this mirrors the inference loop, but feeds ground-truth frames to check the behavior
        ys_in_ = ys_in[0, idx].view(1, 1, model.odim)
        np.testing.assert_array_equal(
            ys_in_.new_zeros(1, 1, model.odim).detach().cpu().numpy(),
            ys_in_.detach().cpu().numpy(),
        )
        outs, probs = [], []
        while True:
            idx += 1
            y_masks = subsequent_mask(idx).unsqueeze(0)
            z = model.decoder.forward_one_step(
                ys_in_, y_masks, hs_ir)[0]  # (B, idx, adim)
            outs += [model.feat_out(z).view(1, -1, model.odim)]  # [(1, r, odim), ...]
            probs += [torch.sigmoid(model.prob_out(z))[0]]  # [(r), ...]
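            # maxlen == minlen == the ground-truth length here, so decoding
            # stops after exactly as many frames as the reference target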
            if idx >= maxlen:
                if idx < minlen:
                    continue
                outs = torch.cat(outs, dim=1).transpose(1, 2)  # (1, L, odim) -> (1, odim, L)
                if model.postnet is not None:
                    outs = outs + model.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                break
            ys_in_ = torch.cat((ys_in_, ys_in[0, idx].view(1, 1, model.odim)),
                               dim=1)  # (1, idx + 1, odim)
        # --------- inference calculation ---------

        # check both are equal
        np.testing.assert_array_almost_equal(
            hs_fp.detach().cpu().numpy(),
            hs_ir.detach().cpu().numpy(),
        )
        np.testing.assert_array_almost_equal(
            after_outs.squeeze(0).detach().cpu().numpy(),
            outs.detach().cpu().numpy(),
        )
        np.testing.assert_array_almost_equal(
            torch.sigmoid(logits.squeeze(0)).detach().cpu().numpy(),
            probs.detach().cpu().numpy(),
        )
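
For reference, subsequent_mask(idx) in the decoding loop above builds the causal self-attention mask that lets each decoder position attend only to itself and earlier frames. A minimal equivalent in plain torch, shown as a sketch rather than ESPnet's actual implementation:

import torch

def causal_mask(size):
    """Lower-triangular boolean mask with the same semantics as
    ESPnet's subsequent_mask(size): mask[i, j] is True iff j <= i."""
    return torch.tril(torch.ones(size, size, dtype=torch.bool))

print(causal_mask(3))
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])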