Ejemplo n.º 1
0
def test_vec_log_sum_exp_batch_stable():
    """Check that a batched reduction matches item-by-item reductions.

    Two random ``(1, side, side)`` tensors are reduced over dim 2 one at a
    time, and again after being concatenated along dim 0. Both routes must
    produce the same values.
    """
    side = np.random.randint(22, 41)  # upper bound is exclusive
    first, second = torch.rand(1, side, side), torch.rand(1, side, side)

    # Reduce each item on its own, then glue the results along the batch dim.
    per_item = torch.cat(
        [vec_log_sum_exp(item, 2) for item in (first, second)],
        dim=0,
    )

    # Reduce both items in a single batched call.
    batched = vec_log_sum_exp(torch.cat([first, second], dim=0), 2)

    np.testing.assert_allclose(per_item.numpy(), batched.numpy())
Ejemplo n.º 2
0
def test_vec_log_sum_exp_shape():
    """Check that the reduced dimension is kept with size 1.

    A random 3-D tensor (each side between 1 and 20) is reduced over a
    randomly chosen dimension. The output must have the input's shape with
    that one dimension replaced by 1.
    """
    dim = torch.randint(0, 3, (1,)).item()
    # Use plain Python ints so building the input does not rely on implicit
    # tensor-to-int conversion, and the assertion does not rely on the
    # truthiness of a one-element tensor.
    shape = torch.randint(1, 21, (3,)).tolist()
    in_ = torch.rand(*shape)
    out = vec_log_sum_exp(in_, dim)
    shape[dim] = 1
    # Comparing whole shapes also catches a wrong number of dimensions and
    # shows both shapes in the failure message.
    assert list(out.shape) == shape
Ejemplo n.º 3
0
def test_vec_log_sum_exp():
    """Compare ``vec_log_sum_exp`` with ``explicit_log_sum_exp`` on one row.

    A random ``(1, width)`` tensor is reduced over dim 1. The same values,
    keyed by their column index, are passed to the reference function, and
    the two results must agree to a relative tolerance of 1e-6.
    """
    width = np.random.randint(5, 31)  # upper bound is exclusive
    scores = torch.rand(1, width)
    ours = vec_log_sum_exp(scores, 1).squeeze()

    # The reference function takes a mapping of column index -> Python float.
    by_index = dict(enumerate(scores[0].tolist()))
    gold = explicit_log_sum_exp(by_index)

    np.testing.assert_allclose(ours, gold, rtol=1e-6)
Ejemplo n.º 4
0
def test_vec_log_sum_exp_ones():
    """Check the reduction of an all-ones row against its closed form.

    For ``count`` ones, log(sum(exp(1))) equals log(count * e).
    """
    count = np.random.randint(1, 21)  # upper bound is exclusive
    expected = math.log(count * math.e)

    ones_row = torch.ones(1, count)
    reduced = vec_log_sum_exp(ones_row, 1).squeeze()

    np.testing.assert_allclose(reduced.detach().numpy(), expected)