def _init_geometry_gw(
    geom_x: geometry.Geometry,
    geom_y: geometry.Geometry,
    a: jnp.ndarray,
    b: jnp.ndarray,
    epsilon: Union[epsilon_scheduler.Epsilon, float],
    loss: GWLoss,
    **kwargs
) -> geometry.Geometry:
    """Builds the initial Gromov-Wasserstein geometry from the weights.

    The cost matrix follows Equation 6, Proposition 1 of
    http://proceedings.mlr.press/v48/peyre16.pdf, evaluated at the outer
    product of the two weight vectors.

    Args:
      geom_x: a Geometry object for the first view.
      geom_y: a Geometry object for the second view.
      a: jnp.ndarray<float>[num_a,], weights.
      b: jnp.ndarray<float>[num_b,], weights.
      epsilon: a regularization parameter or an epsilon_scheduler.Epsilon
        object.
      loss: a GWLoss object; its ``left_x`` and ``right_y`` members are
        passed as ``fn`` to ``apply_cost``.
      **kwargs: additional keyword arguments forwarded to the
        ``geometry.Geometry`` constructor.

    Returns:
      A Geometry object for Gromov-Wasserstein.
    """
    # Initial transport matrix in the balanced case, following
    # http://proceedings.mlr.press/v48/peyre16.pdf: the outer product a b^T.
    coupling = a[:, None] * b[None, :]
    row_mass = coupling.sum(1)
    col_mass = coupling.sum(0)
    marginal_term = _marginal_dependent_cost(
        row_mass, col_mass, geom_x, geom_y, loss)

    # Cross term: apply the x-side cost to the coupling, then the y-side cost
    # to the transposed intermediate result.
    left_applied = geom_x.apply_cost(coupling, axis=1, fn=loss.left_x)
    cross_term = geom_y.apply_cost(
        left_applied.T, axis=1, fn=loss.right_y).T

    return geometry.Geometry(
        cost_matrix=marginal_term - cross_term, epsilon=epsilon, **kwargs)
def test_geom_vs_point_cloud(self):
    """Sinkhorn on a point cloud vs. on its explicit cost matrix."""
    cloud = pointcloud.PointCloud(self.x, self.y)
    dense = geometry.Geometry(cloud.cost_matrix)
    # Solve both problems with identical marginals and keep the f potentials.
    potentials = [
        sinkhorn.sinkhorn(geom, a=self.a, b=self.b).f
        for geom in (cloud, dense)
    ]
    self.assertAllClose(potentials[0], potentials[1])
def _ot(X, Y, metric='euclidean', reg=0.1):
    """Compute an entropic-regularized coupling between X and Y using OTT.

    The cost matrix is built with ``cdist`` on the transposes of ``X`` and
    ``Y``, so the items being matched are the columns of the inputs.

    Parameters
    ----------
    X : array
        Source data; ``X.T`` is passed to ``cdist``.
    Y : array
        Target data; ``Y.T`` is passed to ``cdist``.
    metric : str (optional)
        Metric used to create the transport cost matrix,
        see the full list in the scipy.spatial.distance.cdist doc.
    reg : float (optional)
        Level of entropic regularization, passed as ``epsilon`` to the
        OTT geometry.

    Returns
    -------
    R : ndarray
        Transport matrix multiplied by ``len(X.T)``.
    transformed : ndarray
        ``R.dot(X.T).T``.
    """
    n = len(X.T)
    pairwise_cost = cdist(X.T, Y.T, metric=metric)
    solver = transport.Transport(
        geometry.Geometry(cost_matrix=pairwise_cost, epsilon=reg),
        max_iterations=1000,
        threshold=1e-3,
    )
    solver.solve()
    # Rescale the coupling by the number of matched items.
    R = np.asarray(solver.matrix * n)
    transformed = R.dot(X.T).T
    return R, transformed
def loss_g(a, x, implicit=True):
    """Regularized OT cost of Sinkhorn on an explicit squared-distance matrix.

    ``self``, ``epsilon`` and ``lse_mode`` are read from the enclosing scope.

    Args:
      a: weights passed to sinkhorn as ``a``.
      x: points whose pairwise cost against ``self.y`` is computed.
      implicit: forwarded as ``implicit_differentiation``.

    Returns:
      The ``reg_ot_cost`` of the sinkhorn output.
    """
    # |x_i|^2 + |y_j|^2 - 2 <x_i, y_j>, assembled by broadcasting.
    sq_norm_x = jnp.sum(x**2, axis=1)[:, jnp.newaxis]
    sq_norm_y = jnp.sum(self.y**2, axis=1)[jnp.newaxis, :]
    cost_matrix = sq_norm_x + sq_norm_y - 2 * jnp.dot(x, self.y.T)
    geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
    result = sinkhorn.sinkhorn(
        geom,
        a=a,
        b=self.b,
        tau_a=0.8,
        tau_b=0.87,
        threshold=1e-4,
        lse_mode=lse_mode,
        implicit_differentiation=implicit,
    )
    return result.reg_ot_cost
def potential(a, x, implicit, d):
    """Sum of the Sinkhorn ``f`` potential multiplied elementwise by ``d``.

    ``y``, ``b``, ``epsilon``, ``tau_a``, ``tau_b`` and ``lse_mode`` are read
    from the enclosing scope.

    Args:
      a: weights passed to sinkhorn as ``a``.
      x: points whose pairwise cost against ``y`` is computed.
      implicit: forwarded as ``implicit_differentiation``.
      d: array multiplied elementwise with the ``f`` potential.

    Returns:
      ``jnp.sum(out.f * d)`` for the sinkhorn output ``out``.
    """
    # |x_i|^2 + |y_j|^2 - 2 <x_i, y_j>, assembled by broadcasting.
    sq_norm_x = jnp.sum(x**2, axis=1)[:, jnp.newaxis]
    sq_norm_y = jnp.sum(y**2, axis=1)[jnp.newaxis, :]
    cost_matrix = sq_norm_x + sq_norm_y - 2 * jnp.dot(x, y.T)
    geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
    result = sinkhorn.sinkhorn(
        geom,
        a=a,
        b=b,
        tau_a=tau_a,
        tau_b=tau_b,
        lse_mode=lse_mode,
        threshold=1e-4,
        implicit_differentiation=implicit,
        inner_iterations=2,
    )
    return jnp.sum(result.f * d)
def fit(self, X, Y):
    """Fit the coupling between X and Y and store it in ``self.R``.

    Parameters
    ----------
    X : (n_samples, n_features) nd array
        Source data.
    Y : (n_samples, n_features) nd array
        Target data.

    Returns
    -------
    self
        The estimator, with ``R`` set to the transport matrix multiplied by
        ``len(X.T)``.
    """
    # OTT is imported lazily, inside the method.
    from ott.geometry import geometry
    from ott.tools import transport

    n = len(X.T)
    geom = geometry.Geometry(
        cost_matrix=cdist(X.T, Y.T, metric=self.metric),
        epsilon=self.reg,
    )
    solver = transport.Transport(
        geom, max_iterations=self.max_iter, threshold=self.tol)
    solver.solve()
    self.R = np.asarray(solver.matrix * n)
    return self
def test_apply_cost_and_kernel(self):
    """Cost/kernel application agrees across online, offline and dense forms."""
    n, m, p, b = 5, 8, 10, 7
    keys = jax.random.split(self.rng, 5)
    x = jax.random.normal(keys[0], (n, p))
    y = jax.random.normal(keys[1], (m, p)) + 1
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    vec0 = jax.random.normal(keys[2], (n, b))
    vec1 = jax.random.normal(keys[3], (m, b))
    tol = dict(rtol=1e-03, atol=1e-02)

    def cost_products(geom):
        # Apply the cost along axis 0 and then axis 1.
        along0 = geom.apply_cost(vec0, axis=0)
        along1 = geom.apply_cost(vec1, axis=1)
        return along0, along1

    def kernel_products(geom):
        # Apply the kernel (second argument 1.) along axis 0 and then axis 1.
        along0 = geom.apply_kernel(vec0, 1., axis=0)
        along1 = geom.apply_kernel(vec1, 1., axis=1)
        return along0, along1

    # power=2 cost: online vs. offline point cloud vs. explicit cost matrix.
    online0, online1 = cost_products(
        pointcloud.PointCloud(x, y, power=2, online=True))
    offline0, offline1 = cost_products(
        pointcloud.PointCloud(x, y, power=2, online=False))
    dense0, dense1 = cost_products(geometry.Geometry(cost))
    self.assertAllClose(online0, offline0, **tol)
    self.assertAllClose(online1, offline1, **tol)
    self.assertAllClose(dense0, offline0, **tol)
    self.assertAllClose(dense1, offline1, **tol)

    # power=1 cost: online vs. offline point cloud.
    online0, online1 = cost_products(
        pointcloud.PointCloud(x, y, power=1, online=True))
    offline0, offline1 = cost_products(
        pointcloud.PointCloud(x, y, power=1, online=False))
    self.assertAllClose(online0, offline0, **tol)
    self.assertAllClose(online1, offline1, **tol)

    # power=2 kernel: online vs. offline point cloud.
    online0, online1 = kernel_products(
        pointcloud.PointCloud(x, y, power=2, online=True))
    offline0, offline1 = kernel_products(
        pointcloud.PointCloud(x, y, power=2, online=False))
    self.assertAllClose(online0, offline0, **tol)
    self.assertAllClose(online1, offline1, **tol)
def setUp(self):
    """Create random points, normalized weights and a dense geometry."""
    super().setUp()
    self.rng = jax.random.PRNGKey(0)
    self.dim, self.n, self.m = 3, 10, 11
    # Split into 10 keys: the first replaces self.rng, the other nine are
    # kept for the tests.
    self.rng, *rngs = jax.random.split(self.rng, 10)
    self.rngs = rngs
    self.x = jax.random.uniform(rngs[0], (self.n, self.dim))
    self.y = jax.random.uniform(rngs[1], (self.m, self.dim))

    # Weights are offset by .1 and then normalized to sum to one.
    raw_a = jax.random.uniform(rngs[2], (self.n, )) + .1
    raw_b = jax.random.uniform(rngs[3], (self.m, )) + .1
    self.a = raw_a / jnp.sum(raw_a)
    self.b = raw_b / jnp.sum(raw_b)
    self.epsilon = 0.05

    # |x_i|^2 + |y_j|^2 - 2 <x_i, y_j>, assembled by broadcasting.
    sq_norm_x = jnp.sum(self.x**2, axis=1)[:, jnp.newaxis]
    sq_norm_y = jnp.sum(self.y**2, axis=1)[jnp.newaxis, :]
    self.geometry = geometry.Geometry(
        cost_matrix=(sq_norm_x + sq_norm_y - 2 * jnp.dot(self.x, self.y.T)),
        epsilon=self.epsilon)
def loss_fn(cm):
    """Regularized OT cost for cost matrix ``cm`` with uniform marginals.

    ``lse_mode`` is read from the enclosing scope.

    Args:
      cm: cost matrix; its two leading dimensions set the marginal sizes.

    Returns:
      A pair ``(reg_ot_cost, (geom, f, g))`` holding the cost, the geometry
      built from ``cm`` and the two sinkhorn potentials.
    """
    num_rows, num_cols = cm.shape[0], cm.shape[1]
    uniform_a = jnp.ones(num_rows) / num_rows
    uniform_b = jnp.ones(num_cols) / num_cols
    geom = geometry.Geometry(cm, epsilon=0.5)
    result = sinkhorn.sinkhorn(geom, uniform_a, uniform_b, lse_mode=lse_mode)
    aux = (geom, result.f, result.g)
    return result.reg_ot_cost, aux
def _update_geometry_gw(
    geom: geometry.Geometry,
    geom_x: geometry.Geometry,
    geom_y: geometry.Geometry,
    f: jnp.ndarray,
    g: jnp.ndarray,
    loss: GWLoss,
    **kwargs
) -> geometry.Geometry:
    """Updates the geometry object for GW by recomputing its cost matrix.

    The cost matrix equation follows Equation 6, Proposition 1 of
    http://proceedings.mlr.press/v48/peyre16.pdf.

    Let :math:`p` [num_a,] be the marginal of the transport matrix for samples
    from geom_x and :math:`q` [num_b,] be the marginal of the transport matrix
    for samples from geom_y. Let :math:`T` [num_a, num_b] be the transport
    matrix. The cost matrix equation can be written as:

    cost_matrix = marginal_dep_term
                  - left_x(cost_x) :math:`T` right_y(cost_y):math:`^T`

    Args:
      geom: a Geometry object carrying the cost matrix of Gromov Wasserstein.
      geom_x: a Geometry object for the first view.
      geom_y: a Geometry object for the second view.
      f: jnp.ndarray<float>[num_a,], potentials.
      g: jnp.ndarray<float>[num_b,], potentials.
      loss: a GWLoss object; its ``left_x`` and ``right_y`` members are used.
      **kwargs: additional keyword arguments forwarded to the
        ``geometry.Geometry`` constructor.

    Returns:
      A Geometry object for Gromov-Wasserstein, reusing ``geom``'s epsilon.
    """

    def _is_sqeuclidean(candidate):
        # True for a PointCloud with power 2.0 and a Euclidean cost function.
        if not isinstance(candidate, pointcloud.PointCloud):
            return False
        return (candidate.power == 2.0 and
                isinstance(candidate._cost_fn, costs.Euclidean))

    def _is_online(candidate):
        return (isinstance(candidate, pointcloud.PointCloud) and
                candidate._online)

    def _cost_applier(candidate):
        # vec_apply_cost is chosen only when both the geometry and the loss
        # are squared-Euclidean; otherwise the generic apply_cost is used.
        if _is_sqeuclidean(candidate) and isinstance(loss, GWSqEuclLoss):
            return candidate.vec_apply_cost
        return candidate.apply_cost

    # Computes left_applied = cost_matrix_x * transport.
    if _is_online(geom_x) or _is_sqeuclidean(geom_x):
        plan = geom.transport_from_potentials(f, g)
        left_applied = _cost_applier(geom_x)(plan, axis=1, fn=loss.left_x)
    else:
        # Here the transport is applied to the transformed x-side cost
        # matrix directly from the potentials.
        left_applied = geom.apply_transport_from_potentials(
            f, g, loss.left_x(geom_x.cost_matrix), axis=0)

    # Computes the new cost matrix.
    row_marginal = geom.marginal_from_potentials(f, g, axis=1)
    col_marginal = geom.marginal_from_potentials(f, g, axis=0)
    marginal_term = _marginal_dependent_cost(
        row_marginal, col_marginal, geom_x, geom_y, loss)
    cross_term = _cost_applier(geom_y)(
        left_applied.T, axis=1, fn=loss.right_y).T

    return geometry.Geometry(
        cost_matrix=marginal_term - cross_term,
        epsilon=geom._epsilon,
        **kwargs)