def test_discrete_barycenter_grid(self, lse_mode, debiased, epsilon):
    """Tests the discrete barycenters on a 5x5x5 grid.

    Puts two masses on opposing ends of the hypercube with small noise in
    between. Checks that their W barycenter sits (mostly) at the middle of the
    hypercube (i.e. flat index (5x5x5 - 1) / 2).

    Args:
      lse_mode: bool, lse or scaling computations.
      debiased: bool, use (or not) debiasing as proposed in
        https://arxiv.org/abs/2006.02575
      epsilon: float, regularization parameter
    """
    size = jnp.array([5, 5, 5])
    grid_3d = grid.Grid(grid_size=size, epsilon=epsilon)

    # Uniform background noise, flattened to one weight per grid point.
    a = jnp.ones(size).ravel()
    b = jnp.ones(size).ravel()

    # Concentrate the mass on the first / last flat index respectively.
    # `.at[...].set(...)` is the supported functional update; the older
    # `jax.ops.index_update` helper is no longer available in JAX.
    a = a.at[0].set(10000)
    b = b.at[-1].set(10000)
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)

    threshold = 1e-2
    _, _, bar, errors = db.discrete_barycenter(
        grid_3d,
        a=jnp.stack((a, b)),
        threshold=threshold,
        lse_mode=lse_mode,
        debiased=debiased,
    )

    # Flat index of the central grid point.
    center = (jnp.prod(size) - 1) // 2
    self.assertGreater(bar[center], 0.7)
    self.assertGreater(1, bar[center])

    # Last finite entry of the error trace must be below the threshold.
    err = errors[jnp.isfinite(errors)][-1]
    self.assertGreater(threshold, err)
def reg_ot(x):
    """Return the regularized OT cost for a grid built from coordinates ``x``.

    ``a``, ``b`` and ``lse_mode`` are free variables here and must be supplied
    by the enclosing scope.

    Args:
      x: grid coordinates, passed to ``grid.Grid`` as its ``x`` argument.

    Returns:
      The ``reg_ot_cost`` attribute of the sinkhorn output.
    """
    return sinkhorn.sinkhorn(
        grid.Grid(x=x, epsilon=1.0),
        a=a,
        b=b,
        threshold=0.1,
        lse_mode=lse_mode,
    ).reg_ot_cost
def test_autograd_sinkhorn_x_grid(self, lse_mode):
    """Check the autodiff gradient of the regularized OT cost w.r.t. weights.

    A 2 x 3 x 4 grid is built from explicit per-axis coordinates. The gradient
    of ``reg_ot_cost`` with respect to the first marginal ``a`` is contracted
    with a centered random direction and compared with a central finite
    difference along that same direction.

    Args:
      lse_mode: flag forwarded to ``sinkhorn.sinkhorn``.
    """
    step = 1e-4  # perturbation magnitude
    key_a, key_b, key_delta = jax.random.split(self.rng, 3)

    x = (
        jnp.array([.0, 1.0], dtype=jnp.float32),
        jnp.array([.3, .4, .7], dtype=jnp.float32),
        jnp.array([1.0, 1.3, 2.4, 3.7], dtype=jnp.float32),
    )
    grid_size = tuple(axis.shape[0] for axis in x)

    # Strictly positive weights, normalized to sum to one.
    a = jax.random.uniform(key_a, grid_size) + 1
    b = jax.random.uniform(key_b, grid_size) + 1
    a = a.ravel() / jnp.sum(a)
    b = b.ravel() / jnp.sum(b)

    geom = grid.Grid(x=x, epsilon=1)

    def reg_ot(a, b):
        out = sinkhorn.sinkhorn(
            geom, a=a, b=b, threshold=0.001, lse_mode=lse_mode)
        return out.reg_ot_cost

    # Gradient is taken with respect to the first argument, ``a``.
    _, grad_reg_ot = jax.jit(jax.value_and_grad(reg_ot))(a, b)

    direction = jax.random.uniform(key_delta, grid_size).ravel()
    direction = direction - jnp.mean(direction)  # center perturbation

    cost_plus = reg_ot(a + step * direction, b)
    cost_minus = reg_ot(a - step * direction, b)
    finite_diff = (cost_plus - cost_minus) / (2 * step)
    directional_grad = jnp.sum(direction * grad_reg_ot)

    self.assertAllClose(
        directional_grad, finite_diff, rtol=1e-03, atol=1e-02)
def test_autograd_sinkhorn_grid(self, lse_mode):
    """Check the autodiff gradient of the regularized OT cost w.r.t. weights.

    Uses a (2, 3, 4) grid and differentiates ``reg_ot_cost`` with respect to
    the first marginal ``a`` with ``implicit_differentiation=False``. The
    result is contracted with a centered random direction and compared with a
    central finite difference.

    Args:
      lse_mode: flag forwarded to ``sinkhorn.sinkhorn``.
    """
    step = 1e-3  # perturbation magnitude
    key_a, key_b, key_delta = jax.random.split(self.rng, 3)
    grid_size = (2, 3, 4)

    # Strictly positive weights, normalized to sum to one.
    a = jax.random.uniform(key_a, grid_size) + 1.0
    b = jax.random.uniform(key_b, grid_size) + 1.0
    a = a.ravel() / jnp.sum(a)
    b = b.ravel() / jnp.sum(b)

    geom = grid.Grid(grid_size=grid_size, epsilon=0.2)

    def reg_ot(a, b):
        out = sinkhorn.sinkhorn(
            geom,
            a=a,
            b=b,
            threshold=0.1,
            lse_mode=lse_mode,
            implicit_differentiation=False,
        )
        return out.reg_ot_cost

    # Gradient is taken with respect to the first argument, ``a``.
    _, grad_reg_ot = jax.jit(jax.value_and_grad(reg_ot))(a, b)

    direction = jax.random.uniform(key_delta, grid_size).ravel()
    direction = direction - jnp.mean(direction)  # center perturbation

    cost_plus = reg_ot(a + step * direction, b)
    cost_minus = reg_ot(a - step * direction, b)
    finite_diff = (cost_plus - cost_minus) / (2 * step)
    directional_grad = jnp.sum(direction * grad_reg_ot)

    self.assertAllClose(
        directional_grad, finite_diff, rtol=1e-03, atol=1e-02)
def test_separable_grid(self, lse_mode):
    """Two histograms in a grid of size 5 x 6 x 7 in the hypercube^3.

    Runs sinkhorn on the grid geometry and checks that the last finite entry
    of the error trace is below the requested threshold.

    Args:
      lse_mode: flag forwarded to ``sinkhorn.sinkhorn``.
    """
    grid_size = (5, 6, 7)
    keys = jax.random.split(self.rng, 2)
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)

    # Adding zero weights to test proper handling, then ravel. Index 0 / 3
    # selects a whole slice along the first axis of the 3-D histogram.
    # `.at[...].set(...)` replaces the `jax.ops.index_update` helper, which is
    # no longer available in JAX.
    a = a.at[0].set(0).ravel()
    a = a / jnp.sum(a)
    b = b.at[3].set(0).ravel()
    b = b / jnp.sum(b)

    threshold = 0.01
    geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
    errors = sinkhorn.sinkhorn(
        geom, a=a, b=b, threshold=threshold, lse_mode=lse_mode).errors

    # Last finite entry of the error trace must be below the threshold.
    err = errors[jnp.isfinite(errors)][-1]
    self.assertGreater(threshold, err)
def test_apply_cost(self):
    """Compare ``apply_cost`` on a grid with the matching point cloud.

    A (5, 6, 7) grid geometry is compared against a point cloud whose points
    are the mesh coordinates, each axis rescaled by ``max(1, n - 1)``. The grid
    result is also checked against a direct product with the point cloud's
    cost matrix.
    """
    grid_size = (5, 6, 7)
    geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)

    # One integer coordinate array per axis, rescaled and flattened.
    mesh = np.mgrid[tuple(slice(0, n) for n in grid_size)]
    scaled_axes = [
        jnp.array(axis.ravel()) / jnp.maximum(1, n - 1)
        for axis, n in zip(mesh, grid_size)
    ]
    xyz = jnp.stack(scaled_axes).transpose()
    geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)

    vec = jax.random.uniform(self.rng, grid_size).ravel()
    self.assertAllClose(geom_mat.apply_cost(vec), geom_grid.apply_cost(vec))

    expected = np.dot(geom_mat.cost_matrix, vec)
    self.assertAllClose(geom_grid.apply_cost(vec)[0, :], expected)
def test_grid_vs_euclidean(self, lse_mode):
    """Compare the regularized OT cost on a grid and on a point cloud.

    A (5, 6, 7) grid geometry and a point cloud built from the rescaled mesh
    coordinates are both passed to sinkhorn with the same marginals; their
    ``reg_ot_cost`` values must agree.

    Args:
      lse_mode: flag forwarded to ``sinkhorn.sinkhorn``.
    """
    grid_size = (5, 6, 7)
    key_a, key_b = jax.random.split(self.rng, 2)

    # Random histograms, normalized to sum to one.
    a = jax.random.uniform(key_a, grid_size)
    b = jax.random.uniform(key_b, grid_size)
    a = a.ravel() / jnp.sum(a)
    b = b.ravel() / jnp.sum(b)

    epsilon = 0.1
    geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)

    # One integer coordinate array per axis, rescaled and flattened.
    mesh = np.mgrid[tuple(slice(0, n) for n in grid_size)]
    scaled_axes = [
        jnp.array(axis.ravel()) / jnp.maximum(1, n - 1)
        for axis, n in zip(mesh, grid_size)
    ]
    xyz = jnp.stack(scaled_axes).transpose()
    geometry_mat = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon)

    out_mat = sinkhorn.sinkhorn(geometry_mat, a=a, b=b, lse_mode=lse_mode)
    out_grid = sinkhorn.sinkhorn(geometry_grid, a=a, b=b, lse_mode=lse_mode)
    self.assertAllClose(out_mat.reg_ot_cost, out_grid.reg_ot_cost)
def test_apply_transport_grid(self, lse_mode):
    """Compare transport application on a grid and on a point cloud.

    Sinkhorn is run on a (5, 6, 7) grid geometry and on a point cloud built
    from the rescaled mesh coordinates. The resulting potentials are then used
    to apply the transport (and its transpose) to batches of random vectors;
    both geometries must give the same result, and the result must be free of
    NaNs.

    Args:
      lse_mode: flag forwarded to ``sinkhorn.sinkhorn``.
    """
    grid_size = (5, 6, 7)
    # Four independent keys: two for the marginals, two for the test vectors.
    # (Indexing past the end of the key array would not raise in JAX; the
    # index is clamped, silently reusing the last key.)
    keys = jax.random.split(self.rng, 4)

    # Random histograms, normalized to sum to one.
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)
    a = a.ravel() / jnp.sum(a)
    b = b.ravel() / jnp.sum(b)

    geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)
    x, y, z = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
    xyz = jnp.stack([
        jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
        jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
        jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
    ]).transpose()
    geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)

    sink_mat = sinkhorn.sinkhorn(geom_mat, a=a, b=b, lse_mode=lse_mode)
    sink_grid = sinkhorn.sinkhorn(geom_grid, a=a, b=b, lse_mode=lse_mode)

    batch_a = 3
    batch_b = 4
    num_points = int(np.prod(grid_size))
    vec_a = jax.random.normal(keys[2], [batch_a, num_points])
    vec_b = jax.random.normal(keys[3], [batch_b, num_points])
    # Rescale every row so that it sums to one.
    vec_a = vec_a / jnp.sum(vec_a, axis=1)[:, jnp.newaxis]
    vec_b = vec_b / jnp.sum(vec_b, axis=1)[:, jnp.newaxis]

    mat_transport_t_vec_a = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_a, axis=0)
    mat_transport_vec_b = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_b, axis=1)

    grid_transport_t_vec_a = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_a, axis=0)
    grid_transport_vec_b = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_b, axis=1)

    self.assertAllClose(mat_transport_t_vec_a, grid_transport_t_vec_a)
    self.assertAllClose(mat_transport_vec_b, grid_transport_vec_b)
    # An identity check against `True` can never fail for an array result;
    # evaluate the truth value so NaNs are actually detected.
    self.assertFalse(bool(jnp.any(jnp.isnan(mat_transport_t_vec_a))))