import numpy as np
import torch
from torch.autograd import Variable

import attention


def Volatile(tensor):
    # Assumed helper (not shown in the original excerpt): wraps a tensor in
    # a volatile Variable, the pre-0.4 PyTorch idiom for inference-only
    # graphs. The real suite may define or import this differently.
    return Variable(tensor, volatile=True)


def test_attention_values(batch_size, n_q, n_c, d, p):
    """Check attend() against a per-element softmax-attention reference."""
    q = np.random.normal(0, 1, (batch_size, n_q, d))
    c = np.random.normal(0, 1, (batch_size, n_c, d))
    v = np.random.normal(0, 1, (batch_size, n_c, p))

    w_out, z_out = attention.attend(
        Volatile(torch.from_numpy(q)),
        Volatile(torch.from_numpy(c)),
        value=Volatile(torch.from_numpy(v)),
        return_weight=True)
    w_out = w_out.data.numpy()
    z_out = z_out.data.numpy()

    assert w_out.shape == (batch_size, n_q, n_c)
    assert z_out.shape == (batch_size, n_q, p)

    for i in range(batch_size):
        for j in range(n_q):
            # Reference weights: numerically stable softmax over the
            # dot-product scores of query j against every context vector.
            s = [np.dot(q[i, j], c[i, k]) for k in range(n_c)]
            max_s = max(s)
            exp_s = [np.exp(si - max_s) for si in s]
            sum_exp_s = sum(exp_s)
            w_ref = [ei / sum_exp_s for ei in exp_s]
            assert np.allclose(w_ref, w_out[i, j])

            # Reference output: weighted average of the value vectors.
            z_ref = sum(w_ref[k] * v[i, k] for k in range(n_c))
            assert np.allclose(z_ref, z_out[i, j])
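# For reference, a vectorized NumPy equivalent of the per-element check in
# test_attention_values above. This is only a sketch of the expected math
# (softmax over the context axis of the q/c dot products, then a weighted
# sum of the values); it is not part of the library under test.
def _reference_attention(q, c, v):
    s = np.einsum('bqd,bcd->bqc', q, c)        # (batch, n_q, n_c) scores
    s = s - s.max(axis=-1, keepdims=True)      # stabilize the softmax
    w = np.exp(s)
    w /= w.sum(axis=-1, keepdims=True)         # attention weights
    z = np.einsum('bqc,bcp->bqp', w, v)        # weighted value sum
    return w, z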
def test_attention_masked(batch_size, n_q, n_c, d, context_sizes):
    """Check attend() with per-batch context masking (no separate values,
    so the context vectors themselves are attended over)."""
    q = np.random.normal(0, 1, (batch_size, n_q, d))
    c = np.random.normal(0, 1, (batch_size, n_c, d))

    w_out, z_out = attention.attend(
        Volatile(torch.from_numpy(q)),
        Volatile(torch.from_numpy(c)),
        context_sizes=context_sizes,
        return_weight=True)
    w_out = w_out.data.numpy()
    z_out = z_out.data.numpy()

    assert w_out.shape == (batch_size, n_q, n_c)
    assert z_out.shape == (batch_size, n_q, d)

    # Bookkeeping arrays to verify that every output element is checked
    # exactly once, including the masked-out tail of each context.
    w_checked = np.zeros((batch_size, n_q, n_c), dtype=int)
    z_checked = np.zeros((batch_size, n_q, d), dtype=int)

    for i in range(batch_size):
        for j in range(n_q):
            # Only the first n context vectors are valid for this batch
            # element; the rest must receive exactly zero weight.
            n = context_sizes[i] if context_sizes is not None else n_c
            s = [np.dot(q[i, j], c[i, k]) for k in range(n)]
            max_s = max(s)
            exp_s = [np.exp(sk - max_s) for sk in s]
            sum_exp_s = sum(exp_s)
            w_ref = [ek / sum_exp_s for ek in exp_s]

            for k in range(n_c):
                if k < n:
                    assert np.allclose(w_ref[k], w_out[i, j, k])
                else:
                    assert np.allclose(0, w_out[i, j, k])
                w_checked[i, j, k] = 1

            z_ref = sum(w_ref[k] * c[i, k] for k in range(n))
            for k in range(d):
                assert np.allclose(z_ref[k], z_out[i, j, k])
                z_checked[i, j, k] = 1

    assert np.all(w_checked == 1)
    assert np.all(z_checked == 1)
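# A minimal sketch of how these tests might be invoked directly. The shape
# values and context sizes below are hypothetical, chosen only for
# illustration; in the actual suite these arguments are presumably supplied
# by pytest fixtures or @pytest.mark.parametrize.
if __name__ == '__main__':
    test_attention_values(batch_size=2, n_q=3, n_c=5, d=4, p=6)
    test_attention_masked(batch_size=2, n_q=3, n_c=5, d=4,
                          context_sizes=[3, 5])
    test_attention_masked(batch_size=2, n_q=3, n_c=5, d=4,
                          context_sizes=None)
    print('all attention tests passed')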