def test_transformer_gpu_trainable_and_decodable(model_dict):
    # make args
    model_args = make_transformer_args(**model_dict)
    inference_args = make_inference_args()

    idim = 40
    odim = 40
    ilens = [10, 5, 10, 5]
    olens = [20, 15, 20, 15]
    device = torch.device("cuda")
    batch = prepare_inputs(idim, odim, ilens, olens,
                           model_args["spk_embed_dim"], device=device)

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters())

    # trainable
    loss = model(**batch).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # check gradient of ScaledPositionalEncoding
    if model.use_scaled_pos_enc:
        assert model.encoder.embed[-1].alpha.grad is not None
        assert model.decoder.embed[-1].alpha.grad is not None

    # decodable
    model.eval()
    with torch.no_grad():
        if model_args["spk_embed_dim"] is None:
            spemb = None
        else:
            spemb = batch["spembs"][0]
        model.inference(
            batch["xs"][0][:batch["ilens"][0]],
            Namespace(**inference_args),
            spemb=spemb,
        )
        model.calculate_all_attentions(**batch)
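
# The test below compares the batched, teacher-forced forward pass against the
# frame-by-frame inference path of the decoder: ground-truth frames are fed to
# forward_one_step, so both paths should yield numerically identical encoder
# states, feature outputs, and stop probabilities.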
def test_forward_and_inference_are_equal(model_dict):
    # make args
    model_args = make_transformer_args(dprenet_dropout_rate=0.0, **model_dict)

    # setup batch
    idim = 40
    odim = 40
    ilens = [60]
    olens = [60]
    batch = prepare_inputs(idim, odim, ilens, olens)
    xs = batch["xs"]
    ilens = batch["ilens"]
    ys = batch["ys"]
    olens = batch["olens"]

    # define model
    model = Transformer(idim, odim, Namespace(**model_args))
    model.eval()

    # TODO(kan-bayashi): update following ugly part
    with torch.no_grad():
        # --------- forward calculation ---------
        x_masks = model._source_mask(ilens)
        hs_fp, h_masks = model.encoder(xs, x_masks)
        if model.reduction_factor > 1:
            ys_in = ys[:, model.reduction_factor - 1::model.reduction_factor]
            olens_in = olens.new(
                [olen // model.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens
        ys_in = model._add_first_frame_and_remove_last_frame(ys_in)
        y_masks = model._target_mask(olens_in)
        zs, _ = model.decoder(ys_in, y_masks, hs_fp, h_masks)
        before_outs = model.feat_out(zs).view(zs.size(0), -1, model.odim)
        logits = model.prob_out(zs).view(zs.size(0), -1)
        after_outs = before_outs + model.postnet(
            before_outs.transpose(1, 2)).transpose(1, 2)
        # --------- forward calculation ---------

        # --------- inference calculation ---------
        hs_ir, _ = model.encoder(xs, None)
        maxlen = ys_in.shape[1]
        minlen = ys_in.shape[1]
        idx = 0
        # this is the inference calculation, but we feed ground-truth frames to check the behavior
        ys_in_ = ys_in[0, idx].view(1, 1, model.odim)
        # the first decoder input must be the all-zero go-frame
        np.testing.assert_array_equal(
            ys_in_.new_zeros(1, 1, model.odim).detach().cpu().numpy(),
            ys_in_.detach().cpu().numpy(),
        )
        outs, probs = [], []
        while True:
            idx += 1
            y_masks = subsequent_mask(idx).unsqueeze(0)
            z = model.decoder.forward_one_step(
                ys_in_, y_masks, hs_ir)[0]  # (B, idx, adim)
            outs += [model.feat_out(z).view(1, -1, model.odim)]  # [(1, r, odim), ...]
            probs += [torch.sigmoid(model.prob_out(z))[0]]  # [(r), ...]
            if idx >= maxlen:
                if idx < minlen:
                    continue
                outs = torch.cat(outs, dim=1).transpose(
                    1, 2)  # (1, L, odim) -> (1, odim, L)
                if model.postnet is not None:
                    outs = outs + model.postnet(outs)  # (1, odim, L)
                outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
                probs = torch.cat(probs, dim=0)
                break
            # teacher forcing: append the next ground-truth frame instead of the prediction
            ys_in_ = torch.cat(
                (ys_in_, ys_in[0, idx].view(1, 1, model.odim)),
                dim=1)  # (1, idx + 1, odim)
        # --------- inference calculation ---------

    # check both are equal
    np.testing.assert_array_almost_equal(
        hs_fp.detach().cpu().numpy(),
        hs_ir.detach().cpu().numpy(),
    )
    np.testing.assert_array_almost_equal(
        after_outs.squeeze(0).detach().cpu().numpy(),
        outs.detach().cpu().numpy(),
    )
    np.testing.assert_array_almost_equal(
        torch.sigmoid(logits.squeeze(0)).detach().cpu().numpy(),
        probs.detach().cpu().numpy(),
    )