def test_transport_wrong_init(self):
  """Checks that `transport.Transport` rejects malformed constructor args."""
  key_x, key_y = jax.random.split(self.rng, 2)
  n_x, n_y = 23, 48
  pts_x = jax.random.uniform(key_x, (n_x, 4))
  pts_y = jax.random.uniform(key_y, (n_y, 4))
  geom = pointcloud.PointCloud(pts_x, pts_y, epsilon=1e-3, online=True)

  # A geometry followed by an extra positional array must raise ValueError.
  with self.assertRaises(ValueError):
    transport.Transport(geom, pts_x, threshold=1e-3)

  # Four positional arrays must raise AttributeError.
  with self.assertRaises(AttributeError):
    transport.Transport(pts_x, pts_y, pts_x, pts_y, threshold=1e-3)
def transport_for_sort(inputs: jnp.ndarray,
                       weights: jnp.ndarray,
                       target_weights: jnp.ndarray,
                       kwargs) -> jnp.ndarray:
  """Runs sinkhorn on a fixed increasing target.

  The inputs are standardized and squashed into (0, 1) with a sigmoid, then
  transported onto `num_targets` evenly spaced points of [0, 1].

  Args:
    inputs: jnp.ndarray[num_points]. Must be one dimensional; an array of
      shape [num_points, 1] is also accepted.
    weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
    target_weights: jnp.ndarray[num_targets]: the weights of the targets. It may
      be of a different size than the weights.
    kwargs: a dictionary holding the sinkhorn keyword arguments and the
      pointcloud argument.

  Returns:
    The `transport.Transport` object built from the squashed inputs and the
    sorted target; its `matrix` attribute holds the transport matrix of the
    inputs onto the underlying sorted target. NOTE: the return annotation says
    `jnp.ndarray`, but the value returned is the Transport object itself.

  Raises:
    ValueError: if `inputs` is not one dimensional.
  """
  shape = inputs.shape
  if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
    # f-string so the offending shape actually appears in the message.
    raise ValueError(
        f'Shape ({shape}) not supported. The input should be one-dimensional.'
    )

  # reshape (instead of squeeze + expand_dims) keeps a single-point input as a
  # [1, 1] array rather than collapsing it to 0-d and failing.
  x = jnp.reshape(inputs, (-1, 1))
  # Standardize, then map into (0, 1) so inputs share the target's [0, 1] scale.
  x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
  # atleast_1d keeps size-1 weight vectors one dimensional after squeezing.
  a = jnp.atleast_1d(jnp.squeeze(weights))
  b = jnp.atleast_1d(jnp.squeeze(target_weights))
  num_targets = b.shape[0]
  y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
  return transport.Transport(x, y, a=a, b=b, **kwargs)
def _ot(X, Y, metric='euclidean', reg=0.1, max_iter=1000, tol=1e-3):
    """Compute the entropy-regularized optimal coupling between X and Y.

    OTT is used as the backend for acceleration. The columns of ``X`` and
    ``Y`` are the points being matched: the cost matrix is computed between
    ``X.T`` and ``Y.T``.

    Parameters
    ----------
    X : 2-D array
        Source data; each column is one point to be matched.
    Y : 2-D array
        Target data; each column is one point to be matched.
    metric : str, optional
        Metric used to create the transport cost matrix; see the full list in
        the ``scipy.spatial.distance.cdist`` documentation.
    reg : float, optional
        Level of entropic regularization (passed to OTT as ``epsilon``).
    max_iter : int, optional
        Maximum number of solver iterations (default 1000, the previously
        hard-coded value).
    tol : float, optional
        Convergence threshold of the solver (default 1e-3, the previously
        hard-coded value).

    Returns
    -------
    R : numpy.ndarray of shape (n_x, n_y)
        Mixing matrix: the optimal coupling multiplied by ``n_x``, the number
        of columns of ``X``.
    transformed : numpy.ndarray
        ``R.dot(X.T).T``. NOTE(review): this product is only defined when
        ``X`` and ``Y`` have the same number of columns.
    """
    n = len(X.T)
    cost_matrix = cdist(X.T, Y.T, metric=metric)
    geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=reg)
    P = transport.Transport(geom, max_iterations=max_iter, threshold=tol)
    P.solve()
    # Rescale so R is expressed as a mixing matrix rather than a coupling.
    R = np.asarray(P.matrix * n)
    return R, R.dot(X.T).T
def transport_for_sort(inputs: jnp.ndarray,
                       weights: jnp.ndarray,
                       target_weights: jnp.ndarray,
                       kwargs) -> jnp.ndarray:
  """Solves reg. OT, from inputs to a weighted family of increasing values.

  Args:
    inputs: jnp.ndarray[num_points]. Must be one dimensional; an array of
      shape [num_points, 1] is also accepted.
    weights: jnp.ndarray[num_points]. Weight vector `a` for input values.
    target_weights: jnp.ndarray[num_targets]: Weight vector of the target
      measure. It may be of different size than `weights`.
    kwargs: a dictionary holding the keyword arguments for `sinkhorn` and
      `PointCloud`.

  Returns:
    The `transport.Transport` object built from the squashed inputs and the
    increasing target; its `matrix` attribute holds the transport matrix of
    the inputs onto the underlying sorted target. NOTE: the return annotation
    says `jnp.ndarray`, but the value returned is the Transport object itself.

  Raises:
    ValueError: if `inputs` is not one dimensional.
  """
  shape = inputs.shape
  if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
    # f-string so the offending shape actually appears in the message.
    raise ValueError(
        f'Shape ({shape}) not supported. The input should be one-dimensional.'
    )

  # reshape (instead of squeeze + expand_dims) keeps a single-point input as a
  # [1, 1] array rather than collapsing it to 0-d and failing.
  x = jnp.reshape(inputs, (-1, 1))
  # Standardize, then map into (0, 1) so inputs share the target's [0, 1] scale.
  x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
  # atleast_1d keeps size-1 weight vectors one dimensional after squeezing.
  a = jnp.atleast_1d(jnp.squeeze(weights))
  b = jnp.atleast_1d(jnp.squeeze(target_weights))
  num_targets = b.shape[0]
  y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
  return transport.Transport(x, y, a=a, b=b, **kwargs)
def test_transport_from_point(self):
  """Builds a Transport directly from two point clouds and checks marginals."""
  key_a, key_b = jax.random.split(self.rng, 2)
  n_a, n_b = 23, 48
  src = jax.random.uniform(key_a, (n_a, 4))
  tgt = jax.random.uniform(key_b, (n_b, 4))

  ot = transport.Transport(src, tgt, epsilon=1e-2, threshold=1e-2)
  coupling = ot.matrix
  self.assertEqual(coupling.shape, (n_a, n_b))

  # Row sums must match the source weights, column sums the target weights.
  for axis, marginal in ((1, ot.a), (0, ot.b)):
    self.assertAllClose(jnp.sum(coupling, axis=axis), marginal, atol=1e-3)
def test_transport_from_geom(self):
  """Builds a Transport from a geometry plus custom target weights."""
  key_x, key_y, key_w = jax.random.split(self.rng, 3)
  n_a, n_b = 23, 48
  src = jax.random.uniform(key_x, (n_a, 4))
  tgt = jax.random.uniform(key_y, (n_b, 4))
  geom = pointcloud.PointCloud(src, tgt, epsilon=1e-3, online=True)

  # Random positive target weights, normalized to sum to one.
  raw_b = jax.random.uniform(key_w, (n_b,))
  b = raw_b / jnp.sum(raw_b)

  ot = transport.Transport(geom, b=b, threshold=1e-3)
  coupling = ot.matrix
  self.assertEqual(coupling.shape, (n_a, n_b))

  # Row sums must match the source weights, column sums the target weights.
  for axis, marginal in ((1, ot.a), (0, ot.b)):
    self.assertAllClose(jnp.sum(coupling, axis=axis), marginal, atol=1e-3)
def loss(a, x, implicit):
  """Returns `reg_ot_cost` of the Transport between `x` and the enclosing `y`.

  `y`, `epsilon`, `b`, `tau_a`, `tau_b`, `lse_mode` and `linear_solve_kwargs`
  are read from the enclosing scope; `implicit` toggles implicit
  differentiation.
  """
  solver_opts = dict(
      epsilon=epsilon,
      a=a,
      b=b,
      tau_a=tau_a,
      tau_b=tau_b,
      lse_mode=lse_mode,
      implicit_differentiation=implicit,
      use_danskin=False,
      linear_solve_kwargs=linear_solve_kwargs,
      threshold=1e-5,
  )
  return transport.Transport(x, y, **solver_opts).reg_ot_cost
def fit(self, X, Y):
    """Fit the optimal-transport mixing matrix between X and Y.

    Parameters
    --------------
    X: (n_samples, n_features) nd array
        source data
    Y: (n_samples, n_features) nd array
        target data

    Returns
    -------
    self
        The estimator, with ``self.R`` set to the optimal coupling between
        the columns of X and of Y, multiplied by the number of columns of X.
    """
    from ott.geometry import geometry
    from ott.tools import transport

    # Columns of X and Y are the points being matched.
    src_points, tgt_points = X.T, Y.T
    n_points = len(src_points)
    costs = cdist(src_points, tgt_points, metric=self.metric)

    solver = transport.Transport(
        geometry.Geometry(cost_matrix=costs, epsilon=self.reg),
        max_iterations=self.max_iter,
        threshold=self.tol,
    )
    solver.solve()

    self.R = np.asarray(solver.matrix * n_points)
    return self