Example no. 1
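A unit test, apparently from ESPnet's transformer test suite, that checks the encoder's incremental cache: the output computed for the first n-1 frames is reused as a cache, and feeding the full sequence together with that cache must reproduce the from-scratch output. The listing omitted the imports and the pytest parametrization; the versions restored below are a plausible reconstruction, not the verbatim original. A usage sketch follows the test.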
import numpy
import pytest
import torch

# Assumed imports from ESPnet's PyTorch transformer backend, where this
# test appears to originate; they are not part of the scraped listing.
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask


# normalize_before is a test parameter; exercising both settings is assumed.
@pytest.mark.parametrize("normalize_before", [True, False])
def test_encoder_cache(normalize_before):
    adim = 4
    idim = 5
    encoder = Encoder(
        idim=idim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        normalize_before=normalize_before,
        dropout_rate=0.0,
        input_layer="embed")
    elayer = encoder.encoders[0]
    x = torch.randn(2, 5, adim)
    mask = subsequent_mask(x.shape[1]).unsqueeze(0)
    prev_mask = mask[:, :-1, :-1]
    encoder.eval()
    with torch.no_grad():
        # layer-level test: the outputs for the first n-1 frames serve as
        # the cache; feeding the full sequence along with that cache must
        # reproduce the from-scratch result
        y = elayer(x, mask, None)[0]
        cache = elayer(x[:, :-1], prev_mask, None)[0]
        y_fast = elayer(x, mask, cache=cache)[0]
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=1e-5)

        # encoder-level test: forward_one_step threads a per-layer cache;
        # input_layer="embed" expects integer token ids, so x is re-drawn
        x = torch.randint(0, idim, x.shape[:2])
        y = encoder.forward_one_step(x, mask)[0]
        y_, _, cache = encoder.forward_one_step(x[:, :-1], prev_mask)
        y_fast, _, _ = encoder.forward_one_step(x, mask, cache=cache)
        numpy.testing.assert_allclose(y.numpy(), y_fast.numpy(), rtol=1e-5)
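
For context, here is a minimal sketch of how the cache verified above supports incremental, frame-by-frame encoding. The setup mirrors the test; the loop, the `tokens` variable, and the sizes are illustrative, not taken from the source.

import torch
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

idim, adim = 5, 4
encoder = Encoder(idim=idim, attention_dim=adim, linear_units=3,
                  num_blocks=2, dropout_rate=0.0, input_layer="embed")
encoder.eval()

tokens = torch.randint(0, idim, (1, 10))  # hypothetical token ids
full_mask = subsequent_mask(tokens.size(1)).unsqueeze(0)

cache = None
with torch.no_grad():
    for i in range(1, tokens.size(1) + 1):
        # Feed the growing prefix; the cache already holds the earlier
        # frames, so each call effectively computes only the newest one.
        y, _, cache = encoder.forward_one_step(
            tokens[:, :i], full_mask[:, :i, :i], cache=cache)
# y[:, -1] holds the encoding of the most recent frame.

Note that the cache passed into each call covers exactly one frame fewer than the input, which is the same contract the test exercises with prev_mask.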
Example no. 2
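This excerpt sets up a timing benchmark: it builds either a Decoder or an Encoder with matching sizes, then times step-by-step decoding over growing prefixes, once reusing the cache ("cached") and once recomputing from scratch ("baseline"). The scraped listing cuts off the enclosing function header, the imports (Decoder, torch, subsequent_mask, and time), and the tail of the timing loop; a sketch of a plausible continuation follows the excerpt.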
    # `use_decoder` is a hypothetical stand-in for the branch condition,
    # which the listing cut off along with the enclosing function header.
    if use_decoder:
        decoder = Decoder(
            odim=odim,
            attention_dim=adim,
            linear_units=3,
            num_blocks=2,
            dropout_rate=0.0)
        decoder.eval()
    else:
        encoder = Encoder(
            idim=odim,
            attention_dim=adim,
            linear_units=3,
            num_blocks=2,
            dropout_rate=0.0,
            input_layer="embed")
        encoder.eval()

    xlen = 100
    xs = torch.randint(0, odim, (1, xlen))
    memory = torch.randn(2, 500, adim)
    mask = subsequent_mask(xlen).unsqueeze(0)

    result = {"cached": [], "baseline": []}
    n_avg = 10
    # Run the step loop once reusing the cache, once recomputing from scratch.
    for key, value in result.items():
        cache = None
        print(key)
        for i in range(xlen):
            x = xs[:, :i + 1]
            m = mask[:, :i + 1, :i + 1]
            start = time()
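
The listing stops inside the timed region. The following self-contained sketch shows one plausible way to finish such a benchmark for the encoder branch; the inner n_avg averaging, the `new_cache` name, and the final summary print are assumptions, not the original code.

from time import time

import torch
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

adim, odim, xlen, n_avg = 4, 5, 100, 10
encoder = Encoder(idim=odim, attention_dim=adim, linear_units=3,
                  num_blocks=2, dropout_rate=0.0, input_layer="embed")
encoder.eval()

xs = torch.randint(0, odim, (1, xlen))
mask = subsequent_mask(xlen).unsqueeze(0)
result = {"cached": [], "baseline": []}

with torch.no_grad():
    for key, value in result.items():
        cache = None
        for i in range(xlen):
            x = xs[:, :i + 1]
            m = mask[:, :i + 1, :i + 1]
            start = time()
            for _ in range(n_avg):
                if key == "baseline":
                    cache = None  # no reuse: recompute every frame each step
                # The cache must cover exactly the first i frames of x.
                y, _, new_cache = encoder.forward_one_step(x, m, cache=cache)
            cache = new_cache  # advance the cache only after timing the step
            value.append((time() - start) / n_avg)

print("mean step time [s]: cached=%.6f baseline=%.6f"
      % (sum(result["cached"]) / xlen, sum(result["baseline"]) / xlen))

The expected outcome, and the point of the benchmark, is that the cached per-step cost stays roughly flat as the prefix grows, while the baseline's cost grows with the prefix length.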