def test_encoder_cache(normalize_before):
    """Verify that cached incremental encoding reproduces full-sequence encoding.

    Runs the check twice: once on a single encoder layer, once through the
    encoder's ``forward_one_step`` API. In both cases the output computed
    with a cache built from the length-(T-1) prefix must match the output
    computed from scratch on the full length-T input.
    """
    att_dim, in_dim = 4, 5
    encoder = Encoder(
        idim=in_dim,
        attention_dim=att_dim,
        linear_units=3,
        num_blocks=2,
        normalize_before=normalize_before,
        dropout_rate=0.0,
        input_layer="embed")
    first_block = encoder.encoders[0]
    feats = torch.randn(2, 5, att_dim)
    full_mask = subsequent_mask(feats.shape[1]).unsqueeze(0)
    trunc_mask = full_mask[:, :-1, :-1]
    encoder.eval()
    with torch.no_grad():
        # Layer-level: feed the prefix first, then reuse its output as the
        # cache for the full input; result must match the cache-free pass.
        expected = first_block(feats, full_mask, None)[0]
        layer_cache = first_block(feats[:, :-1], trunc_mask, None)[0]
        actual = first_block(feats, full_mask, cache=layer_cache)[0]
        numpy.testing.assert_allclose(expected.numpy(), actual.numpy(), rtol=1e-5)

        # Encoder-level: same consistency check through forward_one_step,
        # driven by integer token ids (input_layer="embed").
        tokens = torch.randint(0, in_dim, feats.shape[:2])
        expected = encoder.forward_one_step(tokens, full_mask)[0]
        _, _, enc_cache = encoder.forward_one_step(tokens[:, :-1], trunc_mask)
        actual, _, _ = encoder.forward_one_step(tokens, full_mask, cache=enc_cache)
        numpy.testing.assert_allclose(expected.numpy(), actual.numpy(), rtol=1e-5)
decoder = Decoder( odim=odim, attention_dim=adim, linear_units=3, num_blocks=2, dropout_rate=0.0) decoder.eval() else: encoder = Encoder( idim=odim, attention_dim=adim, linear_units=3, num_blocks=2, dropout_rate=0.0, input_layer="embed") encoder.eval() xlen = 100 xs = torch.randint(0, odim, (1, xlen)) memory = torch.randn(2, 500, adim) mask = subsequent_mask(xlen).unsqueeze(0) result = {"cached": [], "baseline": []} n_avg = 10 for key, value in result.items(): cache = None print(key) for i in range(xlen): x = xs[:, :i + 1] m = mask[:, :i + 1, :i + 1] start = time()