import numpy
import pytest
import torch

# Assumed import paths: these tests exercise ESPnet's PyTorch transformer modules.
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask


@pytest.mark.parametrize("normalize_before", [True, False])
def test_decoder_cache(normalize_before):
    adim = 4
    odim = 5
    decoder = Decoder(
        odim=odim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        normalize_before=normalize_before,
        dropout_rate=0.0)
    dlayer = decoder.decoders[0]
    memory = torch.randn(2, 5, adim)

    x = torch.randn(2, 5, adim) * 100  # large magnitudes stress the layer-norm placement
    mask = subsequent_mask(x.shape[1]).unsqueeze(0)
    prev_mask = mask[:, :-1, :-1]
    decoder.eval()
    with torch.no_grad():
        # layer-level test: the cache returned for the first t-1 positions
        # must let the layer reproduce the full-sequence output
        y = dlayer(x, mask, memory, None)[0]
        cache = dlayer(x[:, :-1], prev_mask, memory, None)[0]
        y_fast = dlayer(x, mask, memory, None, cache=cache)[0]
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=1e-5)
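        # Hedged aside (an assumption about the DecoderLayer cache layout, not
        # asserted by the original test): the cache is the layer output for the
        # first t-1 positions, so it should also match y_fast's prefix.
        numpy.testing.assert_allclose(cache.numpy(), y_fast[:, :-1].numpy(), rtol=1e-5)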

        # decoder-level test: forward_one_step embeds its input, so switch x
        # from float features to integer token ids
        x = torch.randint(0, odim, x.shape[:2])
        y, _ = decoder.forward_one_step(x, mask, memory)
        # cache=None computes from scratch and returns the per-layer cache
        y_, cache = decoder.forward_one_step(x[:, :-1], prev_mask, memory, cache=None)
        y_fast, _ = decoder.forward_one_step(x, mask, memory, cache=cache)
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=1e-5)
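

# A minimal sketch (not in the original test) of what the cache buys in
# practice: incremental greedy decoding with forward_one_step, feeding the
# growing prefix but recomputing only the newest position. The function name,
# maxlen, and the <sos> id of 0 are illustrative assumptions.
def greedy_decode_sketch(decoder, memory, maxlen=10):
    ys = torch.zeros(1, 1, dtype=torch.long)  # assumed <sos> index 0
    cache = None
    for _ in range(maxlen):
        m = subsequent_mask(ys.size(1)).unsqueeze(0)
        # cache holds each layer's outputs for earlier positions, so only
        # the last position is computed at every step
        logp, cache = decoder.forward_one_step(ys, m, memory, cache=cache)
        ys = torch.cat([ys, logp.argmax(dim=-1, keepdim=True)], dim=1)
    return ys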


if __name__ == "__main__":
    # benchmark with a synthetic dataset
    from time import time

    import matplotlib.pyplot as plt

    adim = 4
    odim = 5
    model = "decoder"  # flip to "encoder" to benchmark the encoder instead
    if model == "decoder":
        decoder = Decoder(
            odim=odim,
            attention_dim=adim,
            linear_units=3,
            num_blocks=2,
            dropout_rate=0.0)
        decoder.eval()
    else:
        encoder = Encoder(
            idim=odim,
            attention_dim=adim,
            linear_units=3,
            num_blocks=2,
            dropout_rate=0.0,
            input_layer="embed")
        encoder.eval()

    xlen = 100
    xs = torch.randint(0, odim, (1, xlen))
    memory = torch.randn(1, 500, adim)  # batch size must match xs
    mask = subsequent_mask(xlen).unsqueeze(0)
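
    # A minimal timing sketch under the "decoder" setup above (an assumption
    # about how the benchmark proceeds, not the original measurement code):
    # decode xs step by step with and without the layer cache.
    for use_cache in (True, False):
        cache = None
        start = time()
        for i in range(1, xlen + 1):
            y, new_cache = decoder.forward_one_step(
                xs[:, :i], mask[:, :i, :i], memory, cache=cache
            )
            if use_cache:
                cache = new_cache
        print("cache=%s: %.3f sec" % (use_cache, time() - start))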