def test_simplex_project():
    """Check `utils.simplex_project` on three fixed input vectors."""
    # An all-zero vector is expected to project to the uniform mixture.
    from_zeros = utils.simplex_project(np.array([0, 0, 0]))
    uniform = [1 / 3] * 3
    assert np.allclose(from_zeros, uniform), \
        "projecting [0, 0, 0] didn't result in uniform"

    # Both entries positive, summing to more than one.
    from_large = utils.simplex_project(np.array([1.2, 1.4]))
    assert np.allclose(from_large, [.4, .6]), \
        "simplex project didn't return correct result"

    # One negative entry.
    from_negative = utils.simplex_project(np.array([-0.1, 0.8]))
    assert np.allclose(from_negative, [0.05, 0.95]), \
        "simplex project didn't return correct result"
def test_simplex_project():
    """Test simplex project against a table of inputs and expected outputs."""
    # Each row: (input vector, expected projection, assertion message).
    cases = (
        ([0, 0, 0], [1 / 3] * 3,
         "projecting [0, 0, 0] didn't result in uniform"),
        ([1.2, 1.4], [.4, .6],
         "simplex project didn't return correct result"),
        ([-0.1, 0.8], [0.05, 0.95],
         "simplex project didn't return correct result"),
    )
    # Rows are checked in order, so the first failing case stops the test
    # before any later projection is attempted.
    for point, expected, message in cases:
        projected = utils.simplex_project(np.array(point))
        assert np.allclose(projected, expected), message
def test_ndim_fixed_point(dim):
    """Check the fixed point found for a cyclic shift of `dim` entries.

    The result must be non-negative, sum to one, and have every entry
    within 0.01 of `1 / dim`.
    """
    def rotate(mix):
        """Shift every entry of `mix` one position forward, wrapping."""
        return np.roll(mix, 1)

    initial = utils.simplex_project(np.random.rand(dim))
    # NOTE(review): other callers pass `init_disc=`; confirm that `disc=` is
    # an accepted keyword of `fixedpoint.fixed_point`.
    found = fixedpoint.fixed_point(rotate, initial, disc=100)

    assert np.all(found >= 0)
    assert np.isclose(found.sum(), 1)
    deviation = np.abs(found - 1 / dim)
    assert np.all(deviation <= 0.01)
def test_fixed_point(dim, rate, tol, disc_mult):
    """Check that `fixed_point` lands within `tol` of a known target.

    The map and its target come from `simple_fixed_point(dim, rate)`; the
    initial discretization passed to the solver is `dim * disc_mult`.
    """
    # The starting point is drawn before the target so the order of random
    # draws stays fixed.
    initial = utils.simplex_project(np.random.rand(dim))
    target, func = simple_fixed_point(dim, rate)
    print(target)

    found = fixedpoint.fixed_point(
        func, initial, tol=tol, init_disc=dim * disc_mult)
    print()

    error = np.abs(found - target)
    assert np.all(error <= tol)
def simple_fixed_point(dim, rate):
    """Build a random target and a map that moves inputs toward it.

    Parameters
    ----------
    dim : int
        Length of the random vector that is projected to form the target.
    rate : float
        Fraction of the gap to the target that one application closes.

    Returns
    -------
    target : result of `utils.simplex_project` on a random vector
    func : callable
        Maps `mix` to `mix + rate * (target - mix)`, computed on a copy so
        the argument is not modified.
    """
    target = utils.simplex_project(np.random.rand(dim))

    def func(mix):
        """Return a copy of `mix` stepped toward `target` by `rate`."""
        stepped = mix.copy()
        gap = target - stepped
        # In-place add on the copy keeps the copy's dtype.
        stepped += rate * gap
        return stepped

    return target, func
def labeled_subsimplex(label_func, init, tol=1e-3, stop=None, init_disc=1):
    """Find the approximate center of a fully labeled subsimplex.

    Parameters
    ----------
    label_func : ndarray -> int
        A proper labeling function. A labeling function takes an element of
        the d-simplex and returns a label in [0, d). It is proper if the
        label always corresponds to a dimension in the support.
    init : ndarray
        An initial guess for where the fully labeled element might be. It is
        projected onto the simplex before use.
    tol : float, optional
        The tolerance for the returned value. It sets the default stopping
        threshold `round(1 / tol) * 2 + 1` on the discretization.
    stop : ndarray -> bool, optional
        Function of the current sub-simplex that returns true when the
        search should stop. By default the search stops once the current
        discretization exceeds the threshold derived from `tol`.
    init_disc : int, optional
        The initial discretization amount for the mixture. The initial
        discretization relates to the number of possible starting points,
        which may reach different subsimplices. Setting this higher may make
        finding one subsimplex slower, but allows the possibility of finding
        more. The value actually used is `max(init.size, init_disc, 8)`.

    Returns
    -------
    sub_simplex
        The discretized simplex divided by its discretization, taken from
        the first iteration at which `stop` returns true.

    Notes
    -----
    Implementation from [1]_ and [2]_

    .. [1] Kuhn and Mackinnon 1975. Sandwich Method for Finding Fixed Points.
    .. [2] Kuhn 1968. Simplicial Approximation Of Fixed Points.
    """
    k = max(init.size, init_disc, 8)
    # XXX There's definitely a more principled way to set `2 / k`
    limit = round(1 / tol) * 2 + 1

    if stop is None:
        def stop(_):
            # `k` is looked up at call time, so this sees every refinement
            # made by the loop below.
            return k > limit

    lattice = _discretize_mixture(utils.simplex_project(init), k)
    while True:
        candidate = lattice / k
        if stop(candidate):
            return candidate
        # Refine: the sandwich step runs on the current lattice, then both
        # the lattice and its discretization are rescaled.
        lattice = _sandwich(label_func, lattice) * 2
        k = (k - 1) * 2
def test_rps_fixed_point(tol):
    """Check the fixed point of a randomly weighted three-strategy cycle.

    The expected mixture is the first three components of the solution to a
    4x4 linear system built from the random weights; the solver's result
    must match it to within `tol` in every coordinate.
    """
    # pytest: disable-msg=invalid-name
    # The starting point is drawn before the weights so the order of random
    # draws stays fixed.
    initial = utils.simplex_project(np.random.rand(3))
    weights = 1 + 3 * np.random.random(3)
    vara, varb, varc = weights

    system = [
        [0, -vara, 1, 1],
        [1, 0, -varb, 1],
        [-varc, 1, 0, 1],
        [1, 1, 1, 0],
    ]
    rhs = [0, 0, 0, 1]
    # The last unknown of the system is dropped; only the first three
    # components are compared against the solver's output.
    expected = np.linalg.solve(system, rhs)[:-1]

    def func(inp):
        """Example fixed point function"""
        inter = np.roll(inp, 1) - weights * np.roll(inp, -1)
        gain = np.maximum(0, inter - inter.dot(inp))
        unnormalized = gain + inp
        return unnormalized / unnormalized.sum()

    found = fixedpoint.fixed_point(func, initial, tol=tol)
    error = np.abs(found - expected)
    assert np.all(error <= tol)
def test_simplex_project_random(array):
    """Check shape, non-negativity and unit sums of a projected array.

    The projection must keep the input's shape, contain no negative entry,
    and sum to one along the last axis.
    """
    projected = utils.simplex_project(array)

    assert projected.shape == array.shape
    assert np.all(projected >= 0)
    totals = projected.sum(-1)
    assert np.allclose(totals, 1)