Ejemplo n.º 1
0
def _init_geometry_gw(geom_x: geometry.Geometry, geom_y: geometry.Geometry,
                      a: jnp.ndarray, b: jnp.ndarray,
                      epsilon: Union[epsilon_scheduler.Epsilon, float],
                      loss: GWLoss, **kwargs) -> geometry.Geometry:
    """Builds the initial Gromov-Wasserstein geometry from the weights.

    The transport matrix is initialised as the outer product of ``a`` and
    ``b`` (balanced case) and plugged into the cost matrix expression of
    Equation 6, Proposition 1 of
    http://proceedings.mlr.press/v48/peyre16.pdf:

      cost_matrix = marginal_dep_term - cross_term

    where ``cross_term`` is obtained by applying the cost of ``geom_x``
    (through ``loss.left_x``) and then the cost of ``geom_y`` (through
    ``loss.right_y``) to the initial transport matrix.

    Args:
      geom_x: a Geometry object for the first view.
      geom_y: a second Geometry object for the second view.
      a: jnp.ndarray<float>[num_a,], weights.
      b: jnp.ndarray<float>[num_b,], weights.
      epsilon: a regularization parameter or an epsilon_scheduler.Epsilon
        object.
      loss: a GWLoss object providing ``left_x`` and ``right_y``.
      **kwargs: extra keyword arguments forwarded to the returned Geometry.

    Returns:
      A Geometry object for Gromov-Wasserstein.
    """
    # Initial transport matrix: outer product of the two weight vectors.
    coupling = a[:, None] * b[None, :]
    row_mass = coupling.sum(1)
    col_mass = coupling.sum(0)
    marginal_term = _marginal_dependent_cost(row_mass, col_mass,
                                             geom_x, geom_y, loss)

    # Apply the x-side cost first, then the y-side cost on the transpose.
    x_applied = geom_x.apply_cost(coupling, axis=1, fn=loss.left_x)
    cross_term = geom_y.apply_cost(x_applied.T, axis=1, fn=loss.right_y).T

    return geometry.Geometry(cost_matrix=marginal_term - cross_term,
                             epsilon=epsilon,
                             **kwargs)
Ejemplo n.º 2
0
 def test_geom_vs_point_cloud(self):
     """Sinkhorn on a PointCloud vs. on a Geometry built from its cost."""
     cloud_geom = pointcloud.PointCloud(self.x, self.y)
     # Same problem, but described only through the explicit cost matrix.
     matrix_geom = geometry.Geometry(cloud_geom.cost_matrix)
     f_cloud, f_matrix = (
         sinkhorn.sinkhorn(candidate, a=self.a, b=self.b).f
         for candidate in (cloud_geom, matrix_geom)
     )
     self.assertAllClose(f_cloud, f_matrix)
def _ot(X, Y, metric='euclidean', reg=0.1):
    """
    Compute the optimal coupling between X and Y with entropic
    regularization, using OTT as a backend for acceleration.

    Parameters
    ----------
    X : nd array
        source data; the rows of ``X.T`` are the points being matched
    Y : nd array
        target data; the rows of ``Y.T`` are the points being matched
    metric : str (optional)
        metric used to create the transport cost matrix,
        see full list in scipy.spatial.distance.cdist doc
    reg : float (optional)
        level of entropic regularization

    Returns
    -------
    R : nd array
        transport matrix computed by OTT, multiplied by ``len(X.T)``
    transformed : nd array
        ``R.dot(X.T).T``
    """
    source_points = X.T
    n_points = len(source_points)
    pairwise_cost = cdist(source_points, Y.T, metric=metric)
    solver = transport.Transport(
        geometry.Geometry(cost_matrix=pairwise_cost, epsilon=reg),
        max_iterations=1000,
        threshold=1e-3,
    )
    solver.solve()
    # Rescale the coupling by the number of source points.
    R = np.asarray(solver.matrix * n_points)

    return R, R.dot(source_points).T
Ejemplo n.º 4
0
 def loss_g(a, x, implicit=True):
     """Returns the sinkhorn ``reg_ot_cost`` between (a, x) and (self.b, self.y).

     ``self``, ``epsilon`` and ``lse_mode`` are taken from the enclosing
     scope. ``implicit`` is forwarded as ``implicit_differentiation``.
     """
     # Squared Euclidean cost expanded as |x|^2 + |y|^2 - 2 <x, y>.
     x_sq_norms = jnp.sum(x**2, axis=1)[:, jnp.newaxis]
     y_sq_norms = jnp.sum(self.y**2, axis=1)[jnp.newaxis, :]
     cross_term = 2 * jnp.dot(x, self.y.T)
     geom = geometry.Geometry(
         cost_matrix=x_sq_norms + y_sq_norms - cross_term,
         epsilon=epsilon)
     solution = sinkhorn.sinkhorn(
         geom,
         a=a,
         b=self.b,
         tau_a=0.8,
         tau_b=0.87,
         threshold=1e-4,
         lse_mode=lse_mode,
         implicit_differentiation=implicit)
     return solution.reg_ot_cost
Ejemplo n.º 5
0
 def potential(a, x, implicit, d):
     """Returns ``sum(f * d)``, with ``f`` the first sinkhorn potential.

     ``y``, ``b``, ``tau_a``, ``tau_b``, ``epsilon`` and ``lse_mode`` are
     taken from the enclosing scope. ``implicit`` is forwarded as
     ``implicit_differentiation``.
     """
     # Squared Euclidean cost expanded as |x|^2 + |y|^2 - 2 <x, y>.
     x_sq_norms = jnp.sum(x**2, axis=1)[:, jnp.newaxis]
     y_sq_norms = jnp.sum(y**2, axis=1)[jnp.newaxis, :]
     cross_term = 2 * jnp.dot(x, y.T)
     geom = geometry.Geometry(
         cost_matrix=x_sq_norms + y_sq_norms - cross_term,
         epsilon=epsilon)
     solution = sinkhorn.sinkhorn(
         geom,
         a=a,
         b=b,
         tau_a=tau_a,
         tau_b=tau_b,
         lse_mode=lse_mode,
         threshold=1e-4,
         implicit_differentiation=implicit,
         inner_iterations=2)
     # Weight the potential by d and reduce to a scalar.
     return jnp.sum(solution.f * d)
Ejemplo n.º 6
0
 def fit(self, X, Y):
     """Solve the regularized transport problem and store it in ``self.R``.

     Parameters
     ----------
     X: (n_samples, n_features) nd array
         source data
     Y: (n_samples, n_features) nd array
         target data

     Returns
     -------
     self
     """
     # OTT is imported inside the method, as in the original code.
     from ott.geometry import geometry
     from ott.tools import transport

     source_points = X.T
     n_points = len(source_points)
     pairwise_cost = cdist(source_points, Y.T, metric=self.metric)
     solver = transport.Transport(
         geometry.Geometry(cost_matrix=pairwise_cost, epsilon=self.reg),
         max_iterations=self.max_iter,
         threshold=self.tol,
     )
     solver.solve()
     # Rescale the coupling by the number of rows of X.T.
     self.R = np.asarray(solver.matrix * n_points)
     return self
  def test_apply_cost_and_kernel(self):
    """Online/offline PointCloud and dense Geometry applications agree."""
    n, m, p, b = 5, 8, 10, 7
    keys = jax.random.split(self.rng, 5)
    x = jax.random.normal(keys[0], (n, p))
    y = jax.random.normal(keys[1], (m, p)) + 1
    vec0 = jax.random.normal(keys[2], (n, b))
    vec1 = jax.random.normal(keys[3], (m, b))
    # Explicit squared Euclidean cost, used as a reference for power=2.
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    tolerances = dict(rtol=1e-03, atol=1e-02)

    def make_cloud(power, online):
      return pointcloud.PointCloud(x, y, power=power, online=online)

    def cost_products(geom):
      # Apply the cost along each axis with the matching test vector.
      return geom.apply_cost(vec0, axis=0), geom.apply_cost(vec1, axis=1)

    def kernel_products(geom):
      # Same as above for the kernel, with 1. as second positional argument.
      return (geom.apply_kernel(vec0, 1., axis=0),
              geom.apply_kernel(vec1, 1., axis=1))

    def assert_pairs_close(actual, expected):
      for got, want in zip(actual, expected):
        self.assertAllClose(got, want, **tolerances)

    # apply_cost, power=2: online vs. offline, then dense vs. offline.
    online_out = cost_products(make_cloud(2, True))
    offline_out = cost_products(make_cloud(2, False))
    dense_out = cost_products(geometry.Geometry(cost))
    assert_pairs_close(online_out, offline_out)
    assert_pairs_close(dense_out, offline_out)

    # apply_cost, power=1: online vs. offline.
    online_out = cost_products(make_cloud(1, True))
    offline_out = cost_products(make_cloud(1, False))
    assert_pairs_close(online_out, offline_out)

    # apply_kernel, power=2: online vs. offline.
    online_out = kernel_products(make_cloud(2, True))
    offline_out = kernel_products(make_cloud(2, False))
    assert_pairs_close(online_out, offline_out)
Ejemplo n.º 8
0
    def setUp(self):
        """Creates a fixed-seed transport problem and its Geometry."""
        super().setUp()
        self.dim = 3
        self.n = 10
        self.m = 11
        self.epsilon = 0.05

        # Split the seed key in 10: the first is kept, the rest are subkeys.
        self.rng, *subkeys = jax.random.split(jax.random.PRNGKey(0), 10)
        self.rngs = subkeys

        self.x = jax.random.uniform(subkeys[0], (self.n, self.dim))
        self.y = jax.random.uniform(subkeys[1], (self.m, self.dim))

        # Shift by .1 before normalising so that each weight sums to one.
        raw_a = jax.random.uniform(subkeys[2], (self.n, )) + .1
        raw_b = jax.random.uniform(subkeys[3], (self.m, )) + .1
        self.a = raw_a / jnp.sum(raw_a)
        self.b = raw_b / jnp.sum(raw_b)

        # Squared Euclidean cost expanded as |x|^2 + |y|^2 - 2 <x, y>.
        x_sq_norms = jnp.sum(self.x**2, axis=1)[:, jnp.newaxis]
        y_sq_norms = jnp.sum(self.y**2, axis=1)[jnp.newaxis, :]
        cross_term = 2 * jnp.dot(self.x, self.y.T)
        self.geometry = geometry.Geometry(
            cost_matrix=(x_sq_norms + y_sq_norms - cross_term),
            epsilon=self.epsilon)
Ejemplo n.º 9
0
 def loss_fn(cm):
   """Runs sinkhorn on cost matrix ``cm`` with uniform weights.

   ``lse_mode`` is taken from the enclosing scope. Returns the
   ``reg_ot_cost`` together with the tuple ``(geom, f, g)``.
   """
   num_rows = cm.shape[0]
   num_cols = cm.shape[1]
   uniform_a = jnp.ones(num_rows) / num_rows
   uniform_b = jnp.ones(num_cols) / num_cols
   geom = geometry.Geometry(cm, epsilon=0.5)
   solution = sinkhorn.sinkhorn(geom, uniform_a, uniform_b, lse_mode=lse_mode)
   aux = (geom, solution.f, solution.g)
   return solution.reg_ot_cost, aux
Ejemplo n.º 10
0
def _update_geometry_gw(geom: geometry.Geometry, geom_x: geometry.Geometry,
                        geom_y: geometry.Geometry, f: jnp.ndarray,
                        g: jnp.ndarray, loss: GWLoss,
                        **kwargs) -> geometry.Geometry:
    """Rebuilds the Gromov-Wasserstein geometry from the current potentials.

    The cost matrix follows Equation 6, Proposition 1 of
    http://proceedings.mlr.press/v48/peyre16.pdf.

    Let :math:`T` [num_a, num_b] be the transport matrix recovered from the
    potentials ``f`` and ``g`` on ``geom``, and let its two marginals be
    taken along ``axis=1`` and ``axis=0``. The new cost matrix is computed as:

      cost_matrix = marginal_dep_term
                  - left_x(cost_x) :math:`T` right_y(cost_y):math:`^T`

    Args:
      geom: a Geometry object carrying the cost matrix of Gromov Wasserstein.
      geom_x: a Geometry object for the first view.
      geom_y: a second Geometry object for the second view.
      f: jnp.ndarray<float>[num_a,], potentials.
      g: jnp.ndarray<float>[num_b,], potentials.
      loss: a GWLoss object providing ``left_x`` and ``right_y``.
      **kwargs: extra keyword arguments forwarded to the returned Geometry.

    Returns:
      A Geometry object for Gromov-Wasserstein, reusing ``geom._epsilon``.
    """
    def _is_sqeuclidean_cloud(candidate):
        # PointCloud with power 2 and a Euclidean cost function.
        if not isinstance(candidate, pointcloud.PointCloud):
            return False
        return (candidate.power == 2.0
                and isinstance(candidate._cost_fn, costs.Euclidean))

    def _is_online_cloud(candidate):
        return (isinstance(candidate, pointcloud.PointCloud)
                and candidate._online)

    def _pick_cost_applier(candidate):
        # vec_apply_cost is only used for a squared Euclidean point cloud
        # combined with the squared Euclidean GW loss.
        if _is_sqeuclidean_cloud(candidate) and isinstance(loss,
                                                           GWSqEuclLoss):
            return candidate.vec_apply_cost
        return candidate.apply_cost

    # Computes x_applied = cost_matrix_x * transport
    if _is_online_cloud(geom_x) or _is_sqeuclidean_cloud(geom_x):
        # Materialise the transport and let geom_x apply its cost to it.
        plan = geom.transport_from_potentials(f, g)
        x_applied = _pick_cost_applier(geom_x)(plan, axis=1, fn=loss.left_x)
    else:
        # Otherwise apply the transport to the transformed x cost matrix.
        transformed_cost_x = loss.left_x(geom_x.cost_matrix)
        x_applied = geom.apply_transport_from_potentials(
            f, g, transformed_cost_x, axis=0)

    # Computes cost_matrix
    row_mass = geom.marginal_from_potentials(f, g, axis=1)
    col_mass = geom.marginal_from_potentials(f, g, axis=0)
    marginal_term = _marginal_dependent_cost(row_mass, col_mass,
                                             geom_x, geom_y, loss)
    cross_term = _pick_cost_applier(geom_y)(
        x_applied.T, axis=1, fn=loss.right_y).T
    return geometry.Geometry(cost_matrix=marginal_term - cross_term,
                             epsilon=geom._epsilon,
                             **kwargs)