示例#1
0
    def test_transport_wrong_init(self):
        """Check that Transport rejects two malformed positional signatures.

        A geometry followed by an extra array is expected to raise ValueError;
        four raw arrays are expected to raise AttributeError.
        """
        key_x, key_y = jax.random.split(self.rng, 2)
        num_a, num_b = 23, 48
        x = jax.random.uniform(key_x, (num_a, 4))
        y = jax.random.uniform(key_y, (num_b, 4))
        geom = pointcloud.PointCloud(x, y, epsilon=1e-3, online=True)

        # Geometry plus a stray positional array.
        with self.assertRaises(ValueError):
            transport.Transport(geom, x, threshold=1e-3)

        # Four positional arrays instead of a single (x, y) pair.
        with self.assertRaises(AttributeError):
            transport.Transport(x, y, x, y, threshold=1e-3)
示例#2
0
def transport_for_sort(
    inputs: jnp.ndarray,
    weights: jnp.ndarray,
    target_weights: jnp.ndarray,
    kwargs,
) -> 'transport.Transport':
    """Runs sinkhorn on a fixed increasing target.

    The inputs are standardized and squashed into (0, 1) with a sigmoid, then
    transported onto ``num_targets`` evenly spaced points in [0, 1].

    Args:
      inputs: jnp.ndarray of shape [num_points] or [num_points, 1]. Must be
        one dimensional.
      weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
      target_weights: jnp.ndarray[num_targets]: the weights of the targets. It
        may be of a different size than the weights.
      kwargs: a dictionary holding the sinkhorn keyword arguments and the
        pointcloud argument; it is forwarded as ``**kwargs`` to
        ``transport.Transport``.

    Returns:
      The ``transport.Transport`` object for the problem; its ``matrix`` is the
      transport matrix of the inputs onto the underlying sorted target.

    Raises:
      ValueError: if ``inputs`` has more than one non-trivial dimension.
    """
    shape = inputs.shape
    if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
        raise ValueError(
            f'Shape ({shape}) not supported. The input should be one-dimensional.'
        )

    # reshape (rather than squeeze + expand_dims) keeps a single-point input
    # as a [1, 1] column instead of collapsing it to a scalar.
    x = jnp.reshape(inputs, (-1, 1))
    # Standardize, then squash into (0, 1); 1e-10 avoids dividing by a zero std.
    x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    # atleast_1d keeps single-element weight vectors one dimensional so that
    # ``b.shape[0]`` below is always defined.
    a = jnp.atleast_1d(jnp.squeeze(weights))
    b = jnp.atleast_1d(jnp.squeeze(target_weights))
    num_targets = b.shape[0]
    # Fixed, increasing target: num_targets evenly spaced points in [0, 1].
    y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
    return transport.Transport(x, y, a=a, b=b, **kwargs)
def _ot(X, Y, metric='euclidean', reg=0.1, max_iterations=1000,
        threshold=1e-3):
    """
    Compute an entropic-regularized optimal coupling between the columns of
    X and Y, using OTT as a backend for acceleration.

    The cost matrix is computed between ``X.T`` and ``Y.T`` (i.e. between
    columns) with ``scipy.spatial.distance.cdist``.

    Parameters
    ----------
    X : 2-D array
        Source data; its columns are the points being matched.
    Y : 2-D array
        Target data; its columns are the points being matched.
    metric : str (optional)
        Metric used to create the transport cost matrix; see the full list in
        the ``scipy.spatial.distance.cdist`` documentation.
    reg : float (optional)
        Level of entropic regularization (used as the geometry's epsilon).
    max_iterations : int (optional)
        Maximum number of Sinkhorn iterations.
    threshold : float (optional)
        Sinkhorn convergence threshold.

    Returns
    -------
    R : numpy.ndarray
        Coupling matrix multiplied by the number of columns of X.
    mapped : numpy.ndarray
        ``R.dot(X.T).T``, i.e. X with its columns mixed by R.
    """
    n = len(X.T)  # number of columns of X
    cost_matrix = cdist(X.T, Y.T, metric=metric)
    geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=reg)
    # transport.Transport runs the solver when constructed, using the settings
    # passed here. An extra ``P.solve()`` call would re-run it with default
    # settings and discard max_iterations/threshold, so it is not called.
    # NOTE(review): confirm against the installed ott version.
    P = transport.Transport(geom, max_iterations=max_iterations,
                            threshold=threshold)
    R = np.asarray(P.matrix * n)

    return R, R.dot(X.T).T
示例#4
0
def transport_for_sort(
    inputs: jnp.ndarray,
    weights: jnp.ndarray,
    target_weights: jnp.ndarray,
    kwargs,
) -> 'transport.Transport':
    """Solves reg. OT, from inputs to a weighted family of increasing values.

    The inputs are standardized and squashed into (0, 1) with a sigmoid; the
    target is ``num_targets`` evenly spaced points in [0, 1].

    Args:
      inputs: jnp.ndarray of shape [num_points] or [num_points, 1]. Must be
        one dimensional.
      weights: jnp.ndarray[num_points]. Weight vector `a` for input values.
      target_weights: jnp.ndarray[num_targets]: Weight vector of the target
        measure. It may be of different size than `weights`.
      kwargs: a dictionary holding the keyword arguments for `sinkhorn` and
        `PointCloud`; forwarded as ``**kwargs`` to ``transport.Transport``.

    Returns:
      The ``transport.Transport`` object for the problem; its ``matrix`` is the
      transport matrix of the inputs onto the underlying sorted target.

    Raises:
      ValueError: if ``inputs`` has more than one non-trivial dimension.
    """
    shape = inputs.shape
    if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
        raise ValueError(
            f'Shape ({shape}) not supported. The input should be one-dimensional.'
        )

    # reshape (rather than squeeze + expand_dims) keeps a single-point input
    # as a [1, 1] column instead of collapsing it to a scalar.
    x = jnp.reshape(inputs, (-1, 1))
    # Standardize, then squash into (0, 1); 1e-10 avoids dividing by a zero std.
    x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    # atleast_1d keeps single-element weight vectors one dimensional so that
    # ``b.shape[0]`` below is always defined.
    a = jnp.atleast_1d(jnp.squeeze(weights))
    b = jnp.atleast_1d(jnp.squeeze(target_weights))
    num_targets = b.shape[0]
    # Fixed, increasing target: num_targets evenly spaced points in [0, 1].
    y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
    return transport.Transport(x, y, a=a, b=b, **kwargs)
示例#5
0
 def test_transport_from_point(self):
     """Transport built from raw point clouds matches its marginals."""
     key_x, key_y = jax.random.split(self.rng, 2)
     num_a, num_b = 23, 48
     x = jax.random.uniform(key_x, (num_a, 4))
     y = jax.random.uniform(key_y, (num_b, 4))
     ot = transport.Transport(x, y, epsilon=1e-2, threshold=1e-2)
     coupling = ot.matrix
     self.assertEqual(coupling.shape, (num_a, num_b))
     # Row sums should reproduce ot.a and column sums ot.b.
     row_sums = jnp.sum(coupling, axis=1)
     col_sums = jnp.sum(coupling, axis=0)
     self.assertAllClose(row_sums, ot.a, atol=1e-3)
     self.assertAllClose(col_sums, ot.b, atol=1e-3)
示例#6
0
 def test_transport_from_geom(self):
     """Transport built from a geometry with custom target weights."""
     key_x, key_y, key_b = jax.random.split(self.rng, 3)
     num_a, num_b = 23, 48
     x = jax.random.uniform(key_x, (num_a, 4))
     y = jax.random.uniform(key_y, (num_b, 4))
     geom = pointcloud.PointCloud(x, y, epsilon=1e-3, online=True)
     # Random positive target weights normalized to sum to one.
     raw_b = jax.random.uniform(key_b, (num_b, ))
     b = raw_b / jnp.sum(raw_b)
     ot = transport.Transport(geom, b=b, threshold=1e-3)
     coupling = ot.matrix
     self.assertEqual(coupling.shape, (num_a, num_b))
     row_sums = jnp.sum(coupling, axis=1)
     col_sums = jnp.sum(coupling, axis=0)
     self.assertAllClose(row_sums, ot.a, atol=1e-3)
     self.assertAllClose(col_sums, ot.b, atol=1e-3)
示例#7
0
 def loss(a, x, implicit):
     """Return the regularized OT cost of a Transport between x and y.

     ``y``, ``b``, ``epsilon``, ``tau_a``, ``tau_b``, ``lse_mode`` and
     ``linear_solve_kwargs`` come from the surrounding scope (not visible
     here). ``implicit`` is forwarded as ``implicit_differentiation``, and
     Danskin is disabled.
     """
     return transport.Transport(
         x,
         y,
         epsilon=epsilon,
         a=a,
         b=b,
         tau_a=tau_a,
         tau_b=tau_b,
         lse_mode=lse_mode,
         implicit_differentiation=implicit,
         use_danskin=False,
         linear_solve_kwargs=linear_solve_kwargs,
         threshold=1e-5,
     ).reg_ot_cost
示例#8
0
 def fit(self, X, Y):
     '''Fit the entropic OT coupling between the columns of X and Y.

     Parameters
     --------------
     X: (n_samples, n_features) nd array
         source data
     Y: (n_samples, n_features) nd array
         target data

     Returns
     --------------
     self
         The estimator, with the coupling multiplied by the number of
         columns of X stored in ``self.R``.
     '''
     from ott.geometry import geometry
     from ott.tools import transport
     n = len(X.T)  # number of columns of X
     # Costs are computed between columns (X.T / Y.T), not between samples.
     cost_matrix = cdist(X.T, Y.T, metric=self.metric)
     geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=self.reg)
     # transport.Transport runs the solver when constructed, using these
     # settings. Calling P.solve() afterwards would re-run it with default
     # settings and silently discard self.max_iter / self.tol, so it is not
     # called. NOTE(review): confirm against the installed ott version.
     P = transport.Transport(geom,
                             max_iterations=self.max_iter,
                             threshold=self.tol)
     self.R = np.asarray(P.matrix * n)
     return self