import numpy
import pytest
import torch

# imports assume the ESPnet module layout for the Transformer decoder/encoder
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask


# normalize_before is assumed to come from pytest parametrization
@pytest.mark.parametrize("normalize_before", [True, False])
def test_decoder_cache(normalize_before):
    adim = 4
    odim = 5
    decoder = Decoder(
        odim=odim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        normalize_before=normalize_before,
        dropout_rate=0.0)
    dlayer = decoder.decoders[0]
    memory = torch.randn(2, 5, adim)

    x = torch.randn(2, 5, adim) * 100
    mask = subsequent_mask(x.shape[1]).unsqueeze(0)
    prev_mask = mask[:, :-1, :-1]
    decoder.eval()
    with torch.no_grad():
        # layer-level test: running the full input with the cache from the
        # length-(n-1) prefix must reproduce the uncached forward pass
        y = dlayer(x, mask, memory, None)[0]
        cache = dlayer(x[:, :-1], prev_mask, memory, None)[0]
        y_fast = dlayer(x, mask, memory, None, cache=cache)[0]
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=1e-5)

        # decoder-level test: forward_one_step with the cache from the prefix
        # must match forward_one_step over the full input without a cache
        x = torch.randint(0, odim, x.shape[:2])
        y, _ = decoder.forward_one_step(x, mask, memory)
        y_, cache = decoder.forward_one_step(
            x[:, :-1], prev_mask, memory, cache=decoder.init_state())
        y_fast, _ = decoder.forward_one_step(x, mask, memory, cache=cache)
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=1e-5)
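
# The cache contract verified above is what incremental decoding builds on:
# each forward_one_step call consumes the tokens generated so far plus the
# cache returned by the previous call.  The function below is a minimal
# greedy-decoding sketch based only on the calls exercised in this test
# (init_state, forward_one_step); the start/end token ids, the batch size of 1,
# and the assumption that the output holds scores for the newest position are
# illustrative, not part of the tested API.
def greedy_decode_sketch(decoder, memory, maxlen=10, sos=0, eos=None):
    ys = torch.full((1, 1), sos, dtype=torch.long)  # assumed start-of-sequence id
    cache = decoder.init_state()
    for _ in range(maxlen):
        mask = subsequent_mask(ys.shape[1]).unsqueeze(0)
        # scores for the newest position plus an updated per-layer cache
        y, cache = decoder.forward_one_step(ys, mask, memory, cache=cache)
        next_id = int(y.view(-1, y.size(-1))[-1].argmax())  # greedy pick
        ys = torch.cat([ys, torch.full((1, 1), next_id, dtype=torch.long)], dim=1)
        if eos is not None and next_id == eos:
            break
    return ys
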
# benchmark with synth dataset
from time import time

import matplotlib.pyplot as plt

adim = 4
odim = 5
model = "decoder"
if model == "decoder":
    decoder = Decoder(
        odim=odim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        dropout_rate=0.0)
    decoder.eval()
else:
    encoder = Encoder(
        idim=odim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        dropout_rate=0.0,
        input_layer="embed")
    encoder.eval()
xlen = 100
xs = torch.randint(0, odim, (1, xlen))
memory = torch.randn(2, 500, adim)
mask = subsequent_mask(xlen).unsqueeze(0)
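
# The lines above only build the synthetic inputs; the timing loop below is a
# sketch under the same forward_one_step/cache contract checked in
# test_decoder_cache, not the original benchmark.  It compares per-step wall
# time with and without the cache for the "decoder" branch and plots both
# curves; the loop structure and labels are illustrative.
if model == "decoder":
    cached_times, baseline_times = [], []
    with torch.no_grad():
        cache = decoder.init_state()
        for i in range(1, xlen + 1):
            ys = xs[:, :i]
            m = mask[:, :i, :i]
            start = time()
            _, cache = decoder.forward_one_step(ys, m, memory, cache=cache)
            cached_times.append(time() - start)
            start = time()
            decoder.forward_one_step(ys, m, memory)  # no cache: recompute all positions
            baseline_times.append(time() - start)
    plt.plot(cached_times, label="with cache")
    plt.plot(baseline_times, label="without cache")
    plt.xlabel("hypothesis length")
    plt.ylabel("time per step [sec]")
    plt.legend()
    plt.show()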