Example No. 1
import numpy as np
import torch

# sequence_mask and qkvm are assumed to be in scope from the module under test.
def generate_qkvm():
    # Random sizes could be drawn, but small fixed sizes keep the comparison below cheap:
    # B = np.random.randint(5, 10); T = np.random.randint(20, 30); H = np.random.randint(100, 200)
    B = 3  # batch size
    T = 4  # sequence length
    H = 5  # hidden size

    q = torch.rand(B, H)
    k = torch.rand(B, T, H)
    v = torch.rand(B, T, H)

    # Random valid lengths; force roughly half of the batch to the full length T.
    lengths = torch.randint(1, T, size=(B, ))
    lengths[torch.randint(0, B, size=(B // 2, ))] = T

    m = sequence_mask(lengths)

    # Zero the padded positions of k and v, and build per-example, length-trimmed
    # copies of k, v and the mask (each kept as a batch of size 1).
    ks = []
    vs = []
    ms = []
    for i, l in enumerate(lengths):
        k[i, l:] = 0
        v[i, l:] = 0
        ks.append(k[i, :l].unsqueeze(0))
        vs.append(v[i, :l].unsqueeze(0))
        ms.append(m[i, :l].unsqueeze(0))

    qs = [x.unsqueeze(0) for x in q]  # one query per example, each as a batch of size 1

    # Return the batched result alongside the per-example results for comparison.
    return qkvm(q, k, v, m), qkvm(qs, ks, vs, ms)
Example No. 2
def test_attention_masked_valid_probs(lengths):
    bsz, lengths, seq_len = lengths
    mask = sequence_mask(lengths)
    scores = torch.rand(bsz, seq_len)
    # Mask out the padded positions (where the mask is 0) before the softmax.
    score_mask = scores.masked_fill(mask == 0, -1e9)
    attention_weights = F.softmax(score_mask, dim=1)
    # Each row must still be a valid probability distribution.
    for row in attention_weights:
        np.testing.assert_allclose(torch.sum(row).numpy(), 1.0, rtol=1e-5)
Example No. 3
def test_mask_valid_locs(lengths):
    bsz, lengths, seq_len = lengths
    mask = sequence_mask(lengths)
    np_mask = np.zeros((bsz, seq_len))
    for i in range(bsz):
        for j in range(seq_len):
            if j < lengths.data[i]:
                np_mask[i, j] = 1
    np.testing.assert_allclose(mask.data.numpy(), np_mask)
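Note: the tests in these examples treat sequence_mask(lengths) as returning a [B, T] mask that is 1 (True) at positions before each sequence's length and 0 afterwards, which is exactly what Example No. 3 verifies. A minimal sketch of such a helper, written here under that assumption (the library's actual implementation, e.g. eight_mile.pytorch.layers.sequence_mask, may differ in details such as dtype or the max_len default):

import torch

def sequence_mask_sketch(lengths: torch.Tensor, max_len: int = -1) -> torch.Tensor:
    # Boolean [B, T] mask: True where position < length, False at padded positions.
    if max_len < 0:
        max_len = int(lengths.max().item())
    positions = torch.arange(max_len, device=lengths.device)  # [T]
    return positions.unsqueeze(0) < lengths.unsqueeze(1)      # [B, T]

For example, sequence_mask_sketch(torch.tensor([2, 4]), 5) marks the first 2 and the first 4 of the 5 positions as True in its two rows.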
Example No. 4
def _call_model(m, inputs):
    from eight_mile.pytorch.layers import sequence_mask
    m.eval()
    lengths = inputs.get('lengths')
    x = m.embeddings(inputs)
    max_seqlen = x.shape[1]
    # [B, T] -> [B, 1, 1, T] so the mask broadcasts over attention heads and query positions.
    mask = sequence_mask(lengths, max_seqlen).to(x.device).unsqueeze(1).unsqueeze(1)
    return m.generator((x, mask))
Example No. 5
def test_mask_mxlen(lengths):
    bsz, lengths, seq_len = lengths
    # Ask for a mask wider than the longest sequence; the extra columns must all stay 0.
    extra = np.random.randint(2, 11)
    mask = sequence_mask(lengths, seq_len + extra)
    np_mask = np.zeros((bsz, seq_len + extra))
    for i in range(bsz):
        for j in range(seq_len + extra):
            if j < lengths.data[i]:
                np_mask[i, j] = 1
    np.testing.assert_allclose(mask.data.numpy(), np_mask)
Example No. 6
def test_attention_masked_ignores_pad(lengths):
    bsz, lengths, seq_len = lengths
    mask = sequence_mask(lengths)
    scores = torch.rand(bsz, seq_len)
    # Mask out the padded positions (where the mask is 0) before the softmax.
    score_mask = scores.masked_fill(mask == 0, -1e9)
    attention_weights = F.softmax(score_mask, dim=1)
    for row, length in zip(attention_weights, lengths):
        if length.item() == seq_len:
            continue
        # The padded tail of every row should receive (effectively) zero weight.
        padded = row[length.item():]
        np.testing.assert_allclose(padded.data.numpy(), 0.0, atol=1e-6)
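The two attention tests above share the same masked-softmax pattern. Factored into a small standalone helper (a sketch for illustration, not part of the library under test):

import torch
import torch.nn.functional as F

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Softmax over the last dimension, pushing positions where mask == 0 to ~zero weight.
    return F.softmax(scores.masked_fill(mask == 0, -1e9), dim=-1)

Using -1e9 rather than float('-inf') keeps the softmax finite (it degrades to a uniform distribution) even if every position in a row happens to be masked.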
Example No. 7
def attn_values_seq_mask(attn, qkv):
    q, k, v = qkv
    B, H, T, _ = q.shape
    # With a zero query, attention is uniform over the unmasked positions, so every
    # output vector should equal the mean of v over the first lens[b] timesteps.
    q = q.zero_()
    lens = torch.from_numpy(np.random.randint(1, T, size=B))
    mask = sequence_mask(lens, T).unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
    res = attn((q, k, v, mask))
    res = res.numpy()
    gold = v.numpy()
    for b in range(B):
        for h in range(H):
            # Expected output: mean of the values over the valid timesteps only.
            expected = np.mean(gold[b, h, :lens[b], :], axis=0)
            for t in range(T):
                np.testing.assert_allclose(res[b, h, t, :], expected, atol=1e-5)
Example No. 8
def test_windowed_ra():
    num_heads = 4
    d_model = 64
    rpr_k = 1
    batchsize = 2
    nctx = 12
    d_k = d_model // num_heads

    old = SeqScaledDotProductRelativeAttention(pdrop=0.)
    new = SeqScaledWindowedRelativeAttention(pdrop=0.)

    rpr_key_emb = torch.nn.Embedding(2 * rpr_k + 1, d_k)
    rpr_value_emb = torch.nn.Embedding(2 * rpr_k + 1, d_k)

    Q = torch.randn(batchsize, num_heads, nctx, d_k)
    K = torch.randn(batchsize, num_heads, nctx, d_k)
    V = torch.randn(batchsize, num_heads, nctx, d_k)
    lengths = torch.randint(2, nctx, (batchsize,))
    seq_mask = sequence_mask(lengths, max_len=nctx)
    in_mask = seq_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
    out_mask = seq_mask.unsqueeze(1).unsqueeze(-1)  # [B, 1, T, 1]

    # manually create a ra_mask to prevent attention beyond rpr_k positions
    ones = torch.ones(nctx, nctx)
    ra_mask = torch.triu(ones, diagonal=-rpr_k) - torch.triu(ones, diagonal=rpr_k + 1)
    mask = in_mask * ra_mask.unsqueeze(0).unsqueeze(0)
    rpr_key_old, rpr_value_old = make_rpr(rpr_key_emb, rpr_value_emb, rpr_k, nctx)
    old.eval()
    out_old = old((Q, K, V, rpr_key_old, rpr_value_old, mask))
    # Overwrite the padded output positions so the comparison only covers valid timesteps.
    out_old = out_old.masked_fill(out_mask == False, 1).detach().numpy()
    print(out_old.shape)

    # using the windowed relative attention with the original sequence mask
    rpr_key_new, rpr_value_new = unfold_rpr(rpr_key_emb, rpr_value_emb, rpr_k)
    new.eval()
    out_new = new((Q, K, V, rpr_key_new, rpr_value_new, in_mask))
    out_new = out_new.masked_fill(out_mask == False, 1).detach().numpy()
    print(out_new.shape)

    assert np.allclose(out_old, out_new, atol=1e-6)
Example No. 9
def test_seq_mask_valid_count(lengths):
    bsz, lengths, _ = lengths
    mask = sequence_mask(lengths)
    # The number of True entries equals the total number of valid positions.
    assert mask.sum() == lengths.sum()
Example No. 10
def test_mask_shape(lengths):
    bsz, lengths, seq_len = lengths
    mask = sequence_mask(lengths)
    assert mask.size(0) == bsz
    assert mask.size(1) == seq_len
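Most of the test functions above unpack a pytest lengths fixture into (bsz, lengths, seq_len). A hypothetical fixture with that shape, consistent with what the tests assume (in particular, at least one sequence must reach seq_len so that sequence_mask(lengths) spans seq_len columns), could look like:

import pytest
import torch

@pytest.fixture
def lengths():
    # Hypothetical fixture: (batch size, [B] tensor of valid lengths, max sequence length).
    bsz = 4
    seq_len = 10
    lengths = torch.randint(1, seq_len, size=(bsz,))
    lengths[0] = seq_len  # guarantee one full-length sequence so the mask is [bsz, seq_len]
    return bsz, lengths, seq_len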