def test_vec_log_sum_exp_batch_stable():
    """Batched vec_log_sum_exp over dim 2 must equal reducing each batch item alone."""
    size = np.random.randint(22, 41)
    first = torch.rand(1, size, size)
    second = torch.rand(1, size, size)
    batch = torch.cat([first, second], dim=0)
    # Reduce each single-item batch on its own, then stitch the results together
    # along the batch axis so they line up with the batched reduction below.
    separately = torch.cat(
        [vec_log_sum_exp(first, 2), vec_log_sum_exp(second, 2)], dim=0
    )
    together = vec_log_sum_exp(batch, 2)
    np.testing.assert_allclose(separately.numpy(), together.numpy())
def test_vec_log_sum_exp_shape():
    """Reducing a random 3-D tensor over ``dim`` leaves size 1 on that axis.

    Every other axis must keep its original size, and the output must have
    exactly three dimensions.
    """
    dim = torch.randint(0, 3, (1,)).item()
    # Plain Python ints: the shape comparison is int-to-int rather than
    # int-to-0-d-tensor, and the assertion failure message is readable.
    shape = torch.randint(1, 21, (3,)).tolist()
    in_ = torch.rand(*shape)
    out = vec_log_sum_exp(in_, dim)
    expected = list(shape)
    expected[dim] = 1
    # Compare the full shape so that extra trailing dimensions are also caught
    # (a per-index loop over len(shape) would miss them).
    assert tuple(out.shape) == tuple(expected)
def test_vec_log_sum_exp():
    """vec_log_sum_exp over one row must agree with explicit_log_sum_exp."""
    length = np.random.randint(5, 31)
    vec = torch.rand(1, length)
    ours = vec_log_sum_exp(vec, 1).squeeze()
    # explicit_log_sum_exp is given a dict mapping column index -> Python float.
    xs = {col: value.item() for col, value in enumerate(vec[0])}
    gold = explicit_log_sum_exp(xs)
    np.testing.assert_allclose(ours, gold, rtol=1e-6)
def test_vec_log_sum_exp_ones():
    """The log-sum-exp of n ones is log(n * e), i.e. log(n) + 1."""
    n_elems = np.random.randint(1, 21)
    in_ = torch.ones(1, n_elems)
    lse = vec_log_sum_exp(in_, 1).squeeze()
    # torch.ones defaults to float32, while the reference is a float64 value.
    # The default rtol=1e-7 is about one float32 ulp, so it can fail on
    # rounding alone; allow float32-level error instead.
    np.testing.assert_allclose(
        lse.detach().numpy(), math.log(n_elems * math.e), rtol=1e-6
    )