def test_mask_alphas(alpha):
    """Check that ``-inf`` scores act like masked-out positions.

    The last three columns of the input are set to ``-inf``. The output on
    the full input must then match the output computed on the first three
    columns alone, with zeros in the ``-inf`` positions.
    """
    torch.manual_seed(42)
    n_kept = 3

    scores = torch.randn(2, 6)
    scores[:, n_kept:] = -float("inf")
    kept_scores = scores[:, :n_kept]

    full_probs = entmax_bisect(scores, alpha)
    kept_probs = entmax_bisect(kept_scores, alpha)

    # After removing the reference from the kept columns, every entry
    # (kept and masked alike) is expected to be zero.
    full_probs[:, :n_kept] -= kept_probs
    all_zero = torch.zeros_like(full_probs)
    assert torch.allclose(full_probs, all_zero)
def test_entmax_correct_multiple_alphas():
    """Compare a batched call with per-row alphas against row-by-row calls.

    ``entmax_bisect`` is called once with an ``(n, 1)`` tensor of alphas and
    then once per row with that row's alpha as a Python float; both routes
    must produce the same result.
    """
    n_rows = 4
    scores = torch.randn(n_rows, 6, dtype=torch.float64, requires_grad=True)
    alphas = 0.05 + 2.5 * torch.rand(
        (n_rows, 1), dtype=torch.float64, requires_grad=True
    )

    batched = entmax_bisect(scores, alphas)

    per_row = []
    for row in range(n_rows):
        row_scores = scores[row].unsqueeze(0)
        row_alpha = alphas[row].item()
        per_row.append(entmax_bisect(row_scores, row_alpha).squeeze())
    stacked = torch.stack(per_row)

    assert torch.allclose(batched, stacked)
def test_arbitrary_dimension(dim):
    """Check ``entmax_bisect`` along axis ``dim`` against per-slice calls.

    A 4-D input is transformed along ``dim`` with a tensor of alphas that has
    size 1 on that axis (one alpha per slice). Each 1-D slice along ``dim``
    is then recomputed on its own with the matching scalar alpha and compared
    with the corresponding slice of the batched result.

    Args:
        dim: index of the axis to normalize over; used directly as a list
            index into the shape and compared with ``enumerate`` positions.
    """
    shape = [3, 4, 2, 5]
    X = torch.randn(*shape, dtype=torch.float64)

    # Copy before editing: a plain assignment would alias ``shape`` and
    # silently overwrite ``shape[dim]`` as well.
    alpha_shape = list(shape)
    alpha_shape[dim] = 1
    alphas = 0.05 + 2.5 * torch.rand(alpha_shape, dtype=torch.float64)

    P = entmax_bisect(X, alpha=alphas, dim=dim)

    # Enumerate every index on the other axes; take a full slice on ``dim``.
    ranges = [
        [slice(None)] if i == dim else list(range(k))
        for i, k in enumerate(shape)
    ]

    for ix in product(*ranges):
        x = X[ix].unsqueeze(0)
        # ``alphas`` has size 1 along ``dim``, so this selects one element.
        alpha = alphas[ix].item()
        p_true = entmax_bisect(x, alpha=alpha, dim=-1)
        assert torch.allclose(P[ix], p_true)
def test_entmax15():
    """Check that ``entmax15`` agrees with ``entmax_bisect`` at alpha=1.5.

    The summed squared difference between the two outputs must stay
    below ``1e-7``.
    """
    scores = 0.5 * torch.randn(4, 6, dtype=torch.float32)

    exact = entmax15(scores, 1)
    bisected = entmax_bisect(scores, alpha=1.5)

    squared_error = (exact - bisected).pow(2).sum()
    assert squared_error < 1e-7
def project(cls, X, alpha, n_iter):
    """Return ``entmax_bisect`` of ``X`` with ``ensure_sum_one=True``.

    ``alpha`` and ``n_iter`` are forwarded unchanged; ``cls`` is not used.
    NOTE(review): presumably a classmethod of a class defined outside this
    view -- confirm the decorator at the definition site.
    """
    projected = entmax_bisect(
        X,
        alpha=alpha,
        n_iter=n_iter,
        ensure_sum_one=True,
    )
    return projected