def test_transformer_multi_gpu_trainable(model_dict):
    # make args
    model_args = make_transformer_args(**model_dict)

    # setup batch
    idim = 40
    odim = 40
    ilens = [10, 5, 10, 5]
    olens = [20, 15, 20, 15]
    device = torch.device("cuda")
    batch = prepare_inputs(idim, odim, ilens, olens,
                           model_args["spk_embed_dim"], device=device)

    # define model
    ngpu = 2
    device_ids = list(range(ngpu))
    model = Transformer(idim, odim, Namespace(**model_args))
    model = torch.nn.DataParallel(model, device_ids)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters())

    # trainable
    loss = model(**batch).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # check gradient of ScaledPositionalEncoding
    if model.module.use_scaled_pos_enc:
        assert model.module.encoder.embed[-1].alpha.grad is not None
        assert model.module.decoder.embed[-1].alpha.grad is not None
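# Why the multi-GPU test above reduces with ``.mean()``: when every DataParallel
# replica returns a 0-dim loss, the outputs are gathered into a vector of shape
# (ngpu,), one entry per device, so a scalar is needed before ``backward()``.
# A minimal illustrative sketch of that behaviour (toy module, not part of the
# original test suite; reuses the module-level ``torch`` import and needs at
# least two visible GPUs):
class _ScalarLossToy(torch.nn.Module):
    """Toy module whose forward returns a scalar loss per replica."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x).pow(2).mean()  # 0-dim tensor on each device


def _demo_data_parallel_loss_reduction():
    if torch.cuda.device_count() < 2:
        return  # skip the demo on machines without multiple GPUs
    dp = torch.nn.DataParallel(_ScalarLossToy().cuda(), device_ids=[0, 1])
    loss = dp(torch.randn(8, 4).cuda())  # gathered per-GPU losses, shape (2,)
    loss.mean().backward()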
def test_attention_masking(model_dict):
    # make args
    model_args = make_transformer_args(**model_dict)

    # setup batch
    idim = 40
    odim = 40
    ilens = [40, 40]
    olens = [40, 40]
    batch = prepare_inputs(idim, odim, ilens, olens)

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))

    # test encoder self-attention
    x_masks = model._source_mask(batch["ilens"])
    xs, x_masks = model.encoder.embed(batch["xs"], x_masks)
    xs[1, ilens[1]:] = float("nan")
    a = model.encoder.encoders[0].self_attn
    a(xs, xs, xs, x_masks)
    aws = a.attn.detach().numpy()
    for aw, ilen in zip(aws, batch["ilens"]):
        ilen = floor(floor(((ilen - 1) / 2) - 1) / 2)  # due to 4x down sampling
        assert not np.isnan(aw[:, :ilen, :ilen]).any()
        np.testing.assert_almost_equal(
            aw[:, :ilen, :ilen].sum(),
            float(aw.shape[0] * ilen),
            decimal=4,
            err_msg=f"ilen={ilen}, aw.shape={aw.shape}",
        )
        assert aw[:, ilen:, ilen:].sum() == 0.0

    # test encoder-decoder attention
    ys = model.decoder.embed(batch["ys"])
    ys[1, olens[1]:] = float("nan")
    xy_masks = x_masks
    a = model.decoder.decoders[0].src_attn
    a(ys, xs, xs, xy_masks)
    aws = a.attn.detach().numpy()
    for aw, ilen, olen in zip(aws, batch["ilens"], batch["olens"]):
        ilen = floor(floor(((ilen - 1) / 2) - 1) / 2)  # due to 4x down sampling
        assert not np.isnan(aw[:, :olen, :ilen]).any()
        np.testing.assert_almost_equal(
            aw[:, :olen, :ilen].sum(), float(aw.shape[0] * olen), decimal=4)
        assert aw[:, olen:, ilen:].sum() == 0.0

    # test decoder self-attention
    y_masks = model._target_mask(batch["olens"])
    a = model.decoder.decoders[0].self_attn
    a(ys, ys, ys, y_masks)
    aws = a.attn.detach().numpy()
    for aw, olen in zip(aws, batch["olens"]):
        assert not np.isnan(aw[:, :olen, :olen]).any()
        np.testing.assert_almost_equal(
            aw[:, :olen, :olen].sum(), float(aw.shape[0] * olen), decimal=4)
        assert aw[:, olen:, olen:].sum() == 0.0
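# The masking assertions above recompute the encoder output lengths by hand:
# assuming the encoder prenet is a Conv2dSubsampling block (two kernel-3,
# stride-2 convolutions without padding), each convolution maps a length T to
# floor((T - 1) / 2), which gives the nested floor(...) expression and the
# "4x down sampling" comment. A small helper expressing the same formula
# (illustrative only, not used by the tests):
def _subsampled_length(ilen):
    """Frame count after two kernel-3, stride-2 convolutions (4x subsampling)."""
    return ((ilen - 1) // 2 - 1) // 2


assert _subsampled_length(40) == floor(floor(((40 - 1) / 2) - 1) / 2) == 9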
def test_transformer_trainable_and_decodable(model_dict):
    # make args
    model_args = make_transformer_args(**model_dict)
    inference_args = make_inference_args()

    # setup batch
    idim = 40
    odim = 40
    ilens = [10, 5]
    olens = [20, 15]
    batch = prepare_inputs(idim, odim, ilens, olens, model_args["spk_embed_dim"])

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))
    optimizer = torch.optim.Adam(model.parameters())

    # trainable
    loss = model(**batch).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # check gradient of ScaledPositionalEncoding
    if model.use_scaled_pos_enc:
        assert model.encoder.embed[-1].alpha.grad is not None
        assert model.decoder.embed[-1].alpha.grad is not None

    # decodable
    model.eval()
    with torch.no_grad():
        if model_args["spk_embed_dim"] is None:
            spemb = None
        else:
            spemb = batch["spembs"][0]
        model.inference(
            batch["xs"][0][:batch["ilens"][0]],
            Namespace(**inference_args),
            spemb=spemb,
        )
        model.calculate_all_attentions(**batch)
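# ``make_inference_args`` above is a helper defined elsewhere in this module;
# judging from ``Namespace(**inference_args)`` it returns a plain dict of
# decoding options. A hypothetical sketch with typical autoregressive TTS
# options (the exact keys and defaults here are assumptions, not the original
# helper):
def _make_inference_args_sketch(**kwargs):
    """Bundle decoding options for ``model.inference`` into a dict."""
    defaults = {
        "threshold": 0.5,    # stop once the stop-token probability exceeds this
        "minlenratio": 0.0,  # generate at least minlenratio * input length frames
        "maxlenratio": 5.0,  # generate at most maxlenratio * input length frames
    }
    defaults.update(kwargs)
    return defaults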
def test_forward_and_inference_are_equal(model_dict):
    # make args
    model_args = make_transformer_args(dprenet_dropout_rate=0.0, **model_dict)

    # setup batch
    idim = 40
    odim = 40
    ilens = [60]
    olens = [60]
    batch = prepare_inputs(idim, odim, ilens, olens)
    xs = batch["xs"]
    ilens = batch["ilens"]
    ys = batch["ys"]
    olens = batch["olens"]

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))
    model.eval()

    # TODO(kan-bayashi): update following ugly part
    with torch.no_grad():
        # --------- forward calculation ---------
        x_masks = model._source_mask(ilens)
        hs_fp, h_masks = model.encoder(xs, x_masks)
        if model.reduction_factor > 1:
            ys_in = ys[:, model.reduction_factor - 1::model.reduction_factor]
            olens_in = olens.new([olen // model.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens
        ys_in = model._add_first_frame_and_remove_last_frame(ys_in)
        y_masks = model._target_mask(olens_in)
        zs, _ = model.decoder(ys_in, y_masks, hs_fp, h_masks)
        before_outs = model.feat_out(zs).view(zs.size(0), -1, model.odim)
        logits = model.prob_out(zs).view(zs.size(0), -1)
        after_outs = before_outs + model.postnet(
            before_outs.transpose(1, 2)).transpose(1, 2)
        # --------- forward calculation ---------

        # --------- inference calculation ---------
        hs_ir, _ = model.encoder(xs, None)
        maxlen = ys_in.shape[1]
        minlen = ys_in.shape[1]
        idx = 0
        # this is the inference calculation, but we use the ground truth
        # frames to check the behavior
        ys_in_ = ys_in[0, idx].view(1, 1, model.odim)
        np.testing.assert_array_equal(
            ys_in_.new_zeros(1, 1, model.odim).detach().cpu().numpy(),
            ys_in_.detach().cpu().numpy(),
        )
        outs, probs = [], []
        while True:
            idx += 1
            y_masks = subsequent_mask(idx).unsqueeze(0)
            z = model.decoder.forward_one_step(ys_in_, y_masks, hs_ir)[0]  # (B, idx, adim)
            outs += [model.feat_out(z).view(1, -1, model.odim)]  # [(1, r, odim), ...]
            probs += [torch.sigmoid(model.prob_out(z))[0]]  # [(r), ...]
            if idx >= maxlen:
                if idx < minlen:
                    continue
                outs = torch.cat(outs, dim=1).transpose(1, 2)  # (1, L, odim) -> (1, odim, L)
                if model.postnet is not None:
                    outs = outs + model.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                break
            ys_in_ = torch.cat(
                (ys_in_, ys_in[0, idx].view(1, 1, model.odim)), dim=1)  # (1, idx + 1, odim)
        # --------- inference calculation ---------

        # check both are equal
        np.testing.assert_array_almost_equal(
            hs_fp.detach().cpu().numpy(),
            hs_ir.detach().cpu().numpy(),
        )
        np.testing.assert_array_almost_equal(
            after_outs.squeeze(0).detach().cpu().numpy(),
            outs.detach().cpu().numpy(),
        )
        np.testing.assert_array_almost_equal(
            torch.sigmoid(logits.squeeze(0)).detach().cpu().numpy(),
            probs.detach().cpu().numpy(),
        )
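# All of the tests above rely on a ``prepare_inputs`` helper defined elsewhere
# in this module. A minimal sketch of the padded batch layout they consume
# (the keys ``xs``/``ilens``/``ys``/``olens``/``spembs`` appear above; the
# ``labels`` stop-token targets and the random contents are assumptions, not
# the original helper):
def _prepare_inputs_sketch(idim, odim, ilens, olens, spk_embed_dim=None,
                           device=torch.device("cpu")):
    """Build a padded TTS batch of input features (xs) and target features (ys)."""
    xs = torch.zeros(len(ilens), max(ilens), idim)
    ys = torch.zeros(len(olens), max(olens), odim)
    labels = torch.zeros(len(olens), max(olens))
    for i, (ilen, olen) in enumerate(zip(ilens, olens)):
        xs[i, :ilen] = torch.randn(ilen, idim)
        ys[i, :olen] = torch.randn(olen, odim)
        labels[i, olen - 1:] = 1.0  # stop-token target on the final frame(s)
    batch = {
        "xs": xs.to(device),
        "ilens": torch.LongTensor(ilens).to(device),
        "ys": ys.to(device),
        "labels": labels.to(device),
        "olens": torch.LongTensor(olens).to(device),
    }
    if spk_embed_dim is not None:
        # speaker embeddings, only present when spk_embed_dim is configured
        batch["spembs"] = torch.randn(len(ilens), spk_embed_dim).to(device)
    return batch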