def test_attention_values(batch_size, n_q, n_c, d, p):
    """Check ``attention.attend`` with an explicit ``value`` against a reference.

    Queries of shape ``(batch_size, n_q, d)``, contexts of shape
    ``(batch_size, n_c, d)`` and values of shape ``(batch_size, n_c, p)`` are
    sampled from a standard normal distribution. For each query, the expected
    weights are a max-shifted softmax over the query/context dot products and
    the expected output is the correspondingly weighted sum of value rows.

    The arguments are the tensor dimensions (presumably provided by pytest
    parametrization or fixtures defined elsewhere).
    """
    # Sampling order (queries, contexts, values) matters under a fixed seed.
    queries = np.random.normal(0, 1, (batch_size, n_q, d))
    contexts = np.random.normal(0, 1, (batch_size, n_c, d))
    values = np.random.normal(0, 1, (batch_size, n_c, p))

    weights, outputs = attention.attend(
        Volatile(torch.from_numpy(queries)),
        Volatile(torch.from_numpy(contexts)),
        value=Volatile(torch.from_numpy(values)),
        return_weight=True,
    )
    weights = weights.data.numpy()
    outputs = outputs.data.numpy()

    assert weights.shape == (batch_size, n_q, n_c)
    assert outputs.shape == (batch_size, n_q, p)

    # Visit every (batch, query) pair in row-major order.
    for b, t in np.ndindex(batch_size, n_q):
        query = queries[b, t]
        scores = [np.dot(query, contexts[b, k]) for k in range(n_c)]

        # Subtract the largest score before exponentiating (stable softmax).
        peak = max(scores)
        unnormalized = [np.exp(score - peak) for score in scores]
        total = sum(unnormalized)

        expected_weights = [u / total for u in unnormalized]
        assert np.allclose(expected_weights, weights[b, t])

        expected_output = sum(
            weight * values[b, k]
            for k, weight in enumerate(expected_weights)
        )
        assert np.allclose(expected_output, outputs[b, t])
def test_attention_masked(batch_size, n_q, n_c, d, context_sizes):
    """Check ``attention.attend`` when contexts are limited by ``context_sizes``.

    Queries ``(batch_size, n_q, d)`` and contexts ``(batch_size, n_c, d)`` are
    sampled from a standard normal distribution and no separate ``value`` is
    passed. For batch item ``i`` only the first ``context_sizes[i]`` context
    rows (all ``n_c`` rows when ``context_sizes`` is ``None``) take part in the
    reference softmax; the remaining weights are expected to be close to zero.
    The expected output is the weighted sum of the participating context rows.

    Two integer arrays record which weight and output entries were compared,
    and the test finally asserts that every entry was covered.
    """
    queries = np.random.normal(0, 1, (batch_size, n_q, d))
    contexts = np.random.normal(0, 1, (batch_size, n_c, d))

    weights, outputs = attention.attend(
        Volatile(torch.from_numpy(queries)),
        Volatile(torch.from_numpy(contexts)),
        context_sizes=context_sizes,
        return_weight=True,
    )
    weights = weights.data.numpy()
    outputs = outputs.data.numpy()

    assert weights.shape == (batch_size, n_q, n_c)
    assert outputs.shape == (batch_size, n_q, d)

    # Coverage maps: an entry flips to 1 once it has been compared.
    w_checked = np.zeros((batch_size, n_q, n_c), dtype=int)
    z_checked = np.zeros((batch_size, n_q, d), dtype=int)

    for b, t in np.ndindex(batch_size, n_q):
        if context_sizes is None:
            valid = n_c
        else:
            valid = context_sizes[b]

        query = queries[b, t]
        # Index row by row so an oversized context size still raises.
        scores = [np.dot(query, contexts[b, k]) for k in range(valid)]

        # Subtract the largest score before exponentiating (stable softmax).
        peak = max(scores)
        unnormalized = [np.exp(score - peak) for score in scores]
        total = sum(unnormalized)
        expected_weights = [u / total for u in unnormalized]

        # Unmasked positions must match the reference softmax ...
        assert np.allclose(expected_weights, weights[b, t, :valid])
        w_checked[b, t, :valid] = 1
        # ... and masked positions must carry (numerically) zero weight.
        assert np.allclose(0, weights[b, t, valid:])
        w_checked[b, t, valid:] = 1

        expected_output = sum(
            weight * contexts[b, k]
            for k, weight in enumerate(expected_weights)
        )
        assert np.allclose(expected_output, outputs[b, t])
        z_checked[b, t, :] = 1

    assert np.all(w_checked == 1)
    assert np.all(z_checked == 1)