def test_attention_masking(model_dict):
    """Verify that padded positions receive no attention weight.

    Builds a two-utterance batch in which the second utterance is shorter
    than the first, then overwrites the padded tail of that shorter
    utterance with NaN right after embedding. Attention that is computed
    from, or leaks onto, those frames would be expected to show up as NaN.

    The first layer is checked for three attention modules: encoder
    self-attention, decoder source (encoder-decoder) attention, and decoder
    self-attention. For each one the test asserts that:

    * no attention weight is NaN;
    * per utterance, the valid-query x valid-key block sums to
      ``w.shape[0] * n_valid_queries`` (``w.shape[0]`` presumably being the
      number of heads, i.e. one unit of mass per valid query row and head);
    * the padded-query x padded-key corner is exactly zero.

    Args:
        model_dict: keyword overrides forwarded to ``make_transformer_args``.
    """

    def _check_attention(attn_layer, query_lens, key_lens):
        # Weights recorded by the layer's most recent call. Iterating over
        # the first axis yields one array per utterance, presumably shaped
        # (heads, query, key) -- TODO confirm.
        weights = attn_layer.attn.detach().numpy()
        assert not np.isnan(weights).any()
        for w, q_len, k_len in zip(weights, query_lens, key_lens):
            np.testing.assert_almost_equal(
                w[:, :q_len, :k_len].sum(),
                float(w.shape[0] * q_len),
                decimal=4,
            )
            assert w[:, q_len:, k_len:].sum() == 0.0

    model_args = make_transformer_args(**model_dict)

    # two utterances; the second one is padded in both input and output
    idim, odim = 5, 10
    ilens, olens = [10, 5], [20, 15]
    batch = prepare_inputs(idim, odim, ilens, olens)

    model = Transformer(idim, odim, Namespace(**model_args))
    encoder_layer = model.encoder.encoders[0]
    decoder_layer = model.decoder.decoders[0]

    # encoder self-attention: queries and keys are both the source frames
    xs = model.encoder.embed(batch["xs"])
    xs[1, ilens[1]:] = float("nan")
    enc_self_attn = encoder_layer.self_attn
    enc_self_attn(xs, xs, xs, model._source_mask(batch["ilens"]))
    _check_attention(enc_self_attn, batch["ilens"], batch["ilens"])

    # encoder-decoder attention: target-frame queries over source-frame keys
    ys = model.decoder.embed(batch["ys"])
    ys[1, olens[1]:] = float("nan")
    src_attn = decoder_layer.src_attn
    src_attn(
        ys, xs, xs,
        model._source_to_target_mask(batch["ilens"], batch["olens"]),
    )
    _check_attention(src_attn, batch["olens"], batch["ilens"])

    # decoder self-attention on a freshly embedded target sequence
    ys = model.decoder.embed(batch["ys"])
    ys[1, olens[1]:] = float("nan")
    dec_self_attn = decoder_layer.self_attn
    dec_self_attn(ys, ys, ys, model._target_mask(batch["olens"]))
    _check_attention(dec_self_attn, batch["olens"], batch["olens"])
def test_forward_and_inference_are_equal(model_dict):
    """Check that the teacher-forced forward pass matches step-wise decoding.

    A single-utterance batch is pushed through a Transformer built from
    ``model_dict`` (in eval mode, without gradients) in two ways:

    * forward: the whole target sequence (subsampled by the reduction
      factor when it is > 1) is fed to the decoder at once;
    * inference: ``model.decoder.recognize`` is called one step at a time,
      but each step is fed the ground-truth frame instead of the model's
      own prediction, so the two paths should produce the same numbers.

    The encoder outputs, the (postnet-refined) features and the sigmoid
    stop probabilities of both paths are compared with
    ``np.testing.assert_array_almost_equal``. The decoder prenet dropout
    rate is forced to 0.0, presumably so the prenet stays deterministic
    even in eval mode -- TODO confirm.

    Args:
        model_dict: keyword overrides forwarded to ``make_transformer_args``.
    """
    # make args
    model_args = make_transformer_args(dprenet_dropout_rate=0.0, **model_dict)

    # setup batch (one utterance)
    idim = 5
    odim = 10
    batch = prepare_inputs(idim, odim, [10], [20])
    xs = batch["xs"]
    ilens = batch["ilens"]
    ys = batch["ys"]
    olens = batch["olens"]

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))
    model.eval()

    # TODO(kan-bayashi): update following ugly part
    with torch.no_grad():
        # --------- forward calculation (teacher forcing) ---------
        x_masks = model._source_mask(ilens)
        hs_fp, _ = model.encoder(xs, x_masks)
        r = model.reduction_factor
        if r > 1:
            # keep frames r-1, 2r-1, ...: the last frame of each group of r
            ys_in = ys[:, r - 1::r]
            olens_in = olens.new([olen // r for olen in olens])
        else:
            ys_in, olens_in = ys, olens
        ys_in = model._add_first_frame_and_remove_last_frame(ys_in)
        y_masks = model._target_mask(olens_in)
        xy_masks = model._source_to_target_mask(ilens, olens_in)
        zs, _ = model.decoder(ys_in, y_masks, hs_fp, xy_masks)
        before_outs = model.feat_out(zs).view(zs.size(0), -1, model.odim)
        logits = model.prob_out(zs).view(zs.size(0), -1)
        # Same `postnet is not None` guard as the inference path below, so
        # configurations without a postnet do not crash here.
        after_outs = before_outs
        if model.postnet is not None:
            after_outs = after_outs + model.postnet(
                before_outs.transpose(1, 2)).transpose(1, 2)

        # --------- inference calculation (step by step) ---------
        hs_ir, _ = model.encoder(xs, None)
        maxlen = ys_in.shape[1]
        # This is the inference calculation, but the ground-truth frame is
        # fed at every step so the result is comparable with the forward
        # pass above.
        ys_prefix = ys_in[0, 0].view(1, 1, model.odim)
        # the first decoder input is expected to be an all-zero frame
        np.testing.assert_array_equal(
            ys_prefix.new_zeros(1, 1, model.odim).detach().cpu().numpy(),
            ys_prefix.detach().cpu().numpy(),
        )
        step_outs, step_probs = [], []
        for idx in range(1, maxlen + 1):
            y_masks = subsequent_mask(idx).unsqueeze(0)
            # z presumably holds only the newest decoder position, (1, adim),
            # given the (1, r, odim) view applied to it below -- TODO confirm
            z = model.decoder.recognize(ys_prefix, y_masks, hs_ir)
            step_outs.append(model.feat_out(z).view(1, -1, model.odim))  # (1, r, odim)
            step_probs.append(torch.sigmoid(model.prob_out(z))[0])  # (r)
            if idx < maxlen:
                # extend the prefix with the next ground-truth frame
                ys_prefix = torch.cat(
                    (ys_prefix, ys_in[0, idx].view(1, 1, model.odim)), dim=1
                )  # (1, idx + 1, odim)
        outs = torch.cat(step_outs, dim=1).transpose(1, 2)  # (1, L, odim) -> (1, odim, L)
        if model.postnet is not None:
            outs = outs + model.postnet(outs)  # (1, odim, L)
        outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
        probs = torch.cat(step_probs, dim=0)

        # --------- check both are equal ---------
        np.testing.assert_array_almost_equal(
            hs_fp.detach().cpu().numpy(),
            hs_ir.detach().cpu().numpy(),
        )
        np.testing.assert_array_almost_equal(
            after_outs.squeeze(0).detach().cpu().numpy(),
            outs.detach().cpu().numpy(),
        )
        np.testing.assert_array_almost_equal(
            torch.sigmoid(logits.squeeze(0)).detach().cpu().numpy(),
            probs.detach().cpu().numpy(),
        )