# NOTE: imports reconstructed for this excerpt; the espnet module paths are
# assumptions based on ESPnet's pytorch-backend layout
from argparse import Namespace

import numpy as np
import pytest
import torch

from espnet.nets.pytorch_backend.e2e_tts_transformer import Transformer
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask


# the full test module parametrizes `model_dict` over many model
# configurations; a minimal stand-in keeps this excerpt runnable
@pytest.mark.parametrize("model_dict", [{}])
def test_attention_masking(model_dict):
    # make args
    model_args = make_transformer_args(**model_dict)

    # setup batch
    idim = 5
    odim = 10
    ilens = [10, 5]
    olens = [20, 15]
    batch = prepare_inputs(idim, odim, ilens, olens)

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))

    # test encoder self-attention
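    # strategy: overwrite the padded frames of the shorter sample with NaN;
    # if the mask fails to exclude them, NaN propagates into the attention
    # weights and the assertions below catch it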
    xs = model.encoder.embed(batch["xs"])
    xs[1, ilens[1]:] = float("nan")
    x_masks = model._source_mask(batch["ilens"])
    a = model.encoder.encoders[0].self_attn
    a(xs, xs, xs, x_masks)
    aws = a.attn.detach().numpy()
    assert not np.isnan(aws).any()
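    # each softmax row over valid positions sums to 1, so the valid block of
    # one sample sums to (num_heads * ilen), while the padded block must be
    # exactly zero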
    for aw, ilen in zip(aws, batch["ilens"]):
        np.testing.assert_almost_equal(aw[:, :ilen, :ilen].sum(),
                                       float(aw.shape[0] * ilen),
                                       decimal=4)
        assert aw[:, ilen:, ilen:].sum() == 0.0

    # test encoder-decoder attention
    ys = model.decoder.embed(batch["ys"])
    ys[1, olens[1]:] = float("nan")
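    # cross-attention mask of shape (B, max_olen, max_ilen): True where both
    # the target (query) position and the source (key) position are real frames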
    xy_masks = model._source_to_target_mask(batch["ilens"], batch["olens"])
    a = model.decoder.decoders[0].src_attn
    a(ys, xs, xs, xy_masks)
    aws = a.attn.detach().numpy()
    assert not np.isnan(aws).any()
    for aw, ilen, olen in zip(aws, batch["ilens"], batch["olens"]):
        np.testing.assert_almost_equal(aw[:, :olen, :ilen].sum(),
                                       float(aw.shape[0] * olen),
                                       decimal=4)
        assert aw[:, olen:, ilen:].sum() == 0.0

    # test decoder self-attention
    ys = model.decoder.embed(batch["ys"])
    ys[1, olens[1]:] = float("nan")
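    # the target mask combines the padding mask with a causal (subsequent)
    # mask, so each frame may only attend to itself and earlier valid frames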
    y_masks = model._target_mask(batch["olens"])
    a = model.decoder.decoders[0].self_attn
    a(ys, ys, ys, y_masks)
    aws = a.attn.detach().numpy()
    assert not np.isnan(aws).any()
    for aw, olen in zip(aws, batch["olens"]):
        np.testing.assert_almost_equal(aw[:, :olen, :olen].sum(),
                                       float(aw.shape[0] * olen),
                                       decimal=4)
        assert aw[:, olen:, olen:].sum() == 0.0
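

# ---------------------------------------------------------------------------
# `make_transformer_args` and `prepare_inputs` live in the full test module
# and are not part of this excerpt.  The sketches below are assumptions that
# only reproduce the interface the tests rely on; the actual ESPnet helpers
# may differ in defaults and details.
# ---------------------------------------------------------------------------
def make_transformer_args(**kwargs):
    """Merge test overrides into a default hyperparameter dict (sketch).

    The real helper supplies a default for every argument consumed by
    ``Transformer.__init__``; only the merging pattern is shown here.
    """
    defaults = dict(
        adim=32,
        aheads=4,
        elayers=2,
        eunits=32,
        dlayers=2,
        dunits=32,
        reduction_factor=1,
        dprenet_dropout_rate=0.5,
    )  # ... remaining model defaults elided
    defaults.update(kwargs)
    return defaults


def prepare_inputs(idim, odim, ilens, olens):
    """Build a zero-padded random batch (sketch).

    ``xs`` are token-id sequences for the encoder embedding, ``ys`` are
    target feature sequences; padded positions are left as zeros.
    """
    xs = torch.zeros(len(ilens), max(ilens), dtype=torch.long)
    ys = torch.zeros(len(olens), max(olens), odim)
    for i, ilen in enumerate(ilens):
        xs[i, :ilen] = torch.randint(0, idim, (ilen,))
    for i, olen in enumerate(olens):
        ys[i, :olen] = torch.randn(olen, odim)
    return {
        "xs": xs,
        "ilens": torch.LongTensor(ilens),
        "ys": ys,
        "olens": torch.LongTensor(olens),
    }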


@pytest.mark.parametrize("model_dict", [{}])  # minimal stand-in, as above
def test_forward_and_inference_are_equal(model_dict):
    # make args
    model_args = make_transformer_args(dprenet_dropout_rate=0.0, **model_dict)

    # setup batch
    idim = 5
    odim = 10
    ilens = [10]
    olens = [20]
    batch = prepare_inputs(idim, odim, ilens, olens)
    xs = batch["xs"]
    ilens = batch["ilens"]
    ys = batch["ys"]
    olens = batch["olens"]

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))
    model.eval()

    # TODO(kan-bayashi): update following ugly part
    with torch.no_grad():
        # --------- forward calculation ---------
        x_masks = model._source_mask(ilens)
        hs_fp, h_masks = model.encoder(xs, x_masks)
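        # with reduction factor r > 1 the decoder predicts r frames per step,
        # so the teacher-forcing targets keep only every r-th frame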
        if model.reduction_factor > 1:
            ys_in = ys[:, model.reduction_factor - 1::model.reduction_factor]
            olens_in = olens.new(
                [olen // model.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens
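        # teacher-forcing shift: prepend an all-zero frame and drop the last
        # one, so the decoder input at step t is the target frame t - 1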
        ys_in = model._add_first_frame_and_remove_last_frame(ys_in)
        y_masks = model._target_mask(olens_in)
        zs, _ = model.decoder(ys_in, y_masks, hs_fp, h_masks)
        before_outs = model.feat_out(zs).view(zs.size(0), -1, model.odim)
        logits = model.prob_out(zs).view(zs.size(0), -1)
        after_outs = before_outs + model.postnet(
            before_outs.transpose(1, 2)).transpose(1, 2)
        # --------- forward calculation ---------

        # --------- inference calculation ---------
        hs_ir, _ = model.encoder(xs, None)
        maxlen = ys_in.shape[1]
        minlen = ys_in.shape[1]
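        # maxlen == minlen: decode for exactly the ground-truth length so the
        # step-by-step outputs align one-to-one with the batched forward pass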
        idx = 0
        # this is the inference calculation, but we use ground truth to check the behavior
        ys_in_ = ys_in[0, idx].view(1, 1, model.odim)
        np.testing.assert_array_equal(
            ys_in_.new_zeros(1, 1, model.odim).detach().cpu().numpy(),
            ys_in_.detach().cpu().numpy(),
        )
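        # incremental decoding with teacher forcing: each step extends the
        # decoder input with the next ground-truth frame (not the prediction),
        # so with dropout disabled the loop must reproduce the batched forward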
        outs, probs = [], []
        while True:
            idx += 1
            y_masks = subsequent_mask(idx).unsqueeze(0)
            # decode one step; z: (B, idx, adim)
            z = model.decoder.forward_one_step(ys_in_, y_masks, hs_ir)[0]
            outs += [model.feat_out(z).view(1, -1, model.odim)]  # [(1, r, odim), ...]
            probs += [torch.sigmoid(model.prob_out(z))[0]]  # [(r,), ...]
            if idx >= maxlen:
                if idx < minlen:
                    continue
                # (1, L, odim) -> (1, odim, L)
                outs = torch.cat(outs, dim=1).transpose(1, 2)
                if model.postnet is not None:
                    outs = outs + model.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                break
            ys_in_ = torch.cat((ys_in_, ys_in[0, idx].view(1, 1, model.odim)),
                               dim=1)  # (1, idx + 1, odim)
        # --------- inference calculation ---------

        # check both are equal
        np.testing.assert_array_almost_equal(
            hs_fp.detach().cpu().numpy(),
            hs_ir.detach().cpu().numpy(),
        )
        np.testing.assert_array_almost_equal(
            after_outs.squeeze(0).detach().cpu().numpy(),
            outs.detach().cpu().numpy(),
        )
        np.testing.assert_array_almost_equal(
            torch.sigmoid(logits.squeeze(0)).detach().cpu().numpy(),
            probs.detach().cpu().numpy(),
        )
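

# Running this excerpt (assumed): the full ESPnet test module drives both
# tests through a much larger `model_dict` parametrization; with the minimal
# stand-ins above they can be executed directly, e.g.:
#
#     pytest -q test_e2e_tts_transformer.py
#
# (the file name is an assumption; any pytest-discoverable name works)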