Пример #1
0
 def test_geom_vs_point_cloud(self):
     """Sinkhorn on a point cloud matches Sinkhorn on its explicit cost matrix."""
     cloud = pointcloud.PointCloud(self.x, self.y)
     dense = geometry.Geometry(cloud.cost_matrix)
     # First potential `f` from each solve; the point cloud is solved first.
     potentials = [
         sinkhorn.sinkhorn(geom, a=self.a, b=self.b).f
         for geom in (cloud, dense)
     ]
     self.assertAllClose(*potentials)
Пример #2
0
    def test_online_vs_batch_euclidean_point_cloud(self, lse_mode):
        """Online and batch point-cloud geometries yield the same solution.

        Compared for the default cost and for an explicit `costs.Euclidean()`
        cost. The batch plans for the two cost choices are also compared
        against each other.
        """
        threshold = 1e-3
        eps = 0.1

        online_geom = pointcloud.PointCloud(
            self.x, self.y, epsilon=eps, online=True)
        online_geom_euc = pointcloud.PointCloud(
            self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps,
            online=True)
        batch_geom = pointcloud.PointCloud(self.x, self.y, epsilon=eps)
        batch_geom_euc = pointcloud.PointCloud(
            self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps)

        def solve(geom):
            # Same marginals and solver settings for every geometry.
            return sinkhorn.sinkhorn(geom, a=self.a, b=self.b,
                                     threshold=threshold, lse_mode=lse_mode)

        def plan(geom, out):
            # Dense regularized transport matrix from the dual potentials.
            return geom.transport_from_potentials(out.f, out.g)

        out_online, out_batch, out_online_euc, out_batch_euc = map(
            solve, (online_geom, batch_geom, online_geom_euc, batch_geom_euc))

        # Regularized transport costs agree between online and batch modes.
        self.assertAllClose(out_online.reg_ot_cost, out_batch.reg_ot_cost)
        # Transport plans agree between online and batch modes, for both costs.
        self.assertAllClose(plan(online_geom, out_online),
                            plan(batch_geom, out_batch))
        self.assertAllClose(plan(online_geom_euc, out_online_euc),
                            plan(batch_geom_euc, out_batch_euc))
        # The default cost and the explicit Euclidean cost are expected to
        # give the same batch plan.
        self.assertAllClose(plan(batch_geom, out_batch),
                            plan(batch_geom_euc, out_batch_euc))
Пример #3
0
def sinkhorn_for_sort(inputs: jnp.ndarray, weights: jnp.ndarray,
                      target_weights: jnp.ndarray, sinkhorn_kw,
                      pointcloud_kw) -> jnp.ndarray:
    """Runs sinkhorn on a fixed increasing target.

    The inputs are standardized and squashed into (0, 1) with a sigmoid. They
    are then transported onto `num_targets` evenly spaced points in [0, 1].

    Args:
      inputs: jnp.ndarray[num_points]. Must be one dimensional; a trailing
        singleton axis (shape [num_points, 1]) is also accepted.
      weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
      target_weights: jnp.ndarray[num_targets]: the weights of the targets. It
        may be of a different size than the weights.
      sinkhorn_kw: a dictionary holding the sinkhorn keyword arguments. See
        sinkhorn.py for more details.
      pointcloud_kw: a dictionary holding the keyword arguments of the
        PointCloud class. See pointcloud.py for more details.

    Returns:
      A jnp.ndarray<float> representing the transport matrix of the inputs onto
      the underlying sorted target.

    Raises:
      ValueError: if `inputs` has more than two axes, or two axes with a
        second dimension other than 1.
    """
    shape = inputs.shape
    if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
        # f-string so the offending shape actually appears in the message.
        raise ValueError(
            f"Shape ({shape}) not supported. The input should be one-dimensional."
        )

    x = jnp.expand_dims(jnp.squeeze(inputs), axis=1)
    # Standardize, then squash into (0, 1) so the inputs share the target's
    # scale. The 1e-10 guards against a zero standard deviation.
    x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    a = jnp.squeeze(weights)
    b = jnp.squeeze(target_weights)
    num_targets = b.shape[0]
    # Fixed, increasing target: evenly spaced points in [0, 1].
    y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
    geom = pointcloud.PointCloud(x, y, **pointcloud_kw)
    res = sinkhorn.sinkhorn(geom, a, b, **sinkhorn_kw)
    return geom.transport_from_potentials(res.f, res.g)
 def reg_ot(x):
     """Regularized OT cost of (a, b) on a `grid.Grid` built from `x`.

     Uses epsilon=1.0 and threshold=0.1. `a`, `b` and `lse_mode` come from
     the enclosing scope.
     """
     geom = grid.Grid(x=x, epsilon=1.0)
     solution = sinkhorn.sinkhorn(geom, a=a, b=b, threshold=0.1,
                                  lse_mode=lse_mode)
     return solution.reg_ot_cost
Пример #5
0
 def solve(self):
   """Runs the sinkhorn algorithm to solve the transport problem.

   Stores the dual potentials in `self._f` / `self._g` and the regularized
   transport cost in `self.reg_ot_cost`. A warning is logged when Sinkhorn
   reports that it has not converged; the values are stored regardless.
   """
   import logging

   out = sinkhorn.sinkhorn(self.geom, self.a, self.b, **self._kwargs)
   if not out.converged:
     logging.warning(
         'Sinkhorn has not converged. Please check your setup and '
         'consider increasing max_iterations or epsilon. For more '
         'details, see ott.core.sinkhorn.sinkhorn.')
   # So far we always set the values, even if not converged.
   # TODO(oliviert, cuturi): handles cases where it has not converged.
   self._f = out.f
   self._g = out.g
   self.reg_ot_cost = out.reg_ot_cost
 def loss_fn(x, y):
     """Regularized OT cost between point clouds `x` and `y`.

     Returns the cost together with the auxiliary tuple `(geom, f, g)`.
     `epsilon`, `a`, `b`, `lse_mode` and `implicit_differentiation` come from
     the enclosing scope.
     """
     geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
     out = sinkhorn.sinkhorn(
         geom, a, b,
         lse_mode=lse_mode,
         implicit_differentiation=implicit_differentiation)
     return out.reg_ot_cost, (geom, out.f, out.g)
Пример #7
0
 def test_online_euclidean_point_cloud(self, lse_mode):
   """Sinkhorn with an online point-cloud geometry reaches the threshold."""
   threshold = 1e-3
   online_geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, online=True)
   out = sinkhorn.sinkhorn(online_geom, a=self.a, b=self.b,
                           threshold=threshold, lse_mode=lse_mode)
   # Last error among entries above -1 (presumably the unfilled slots).
   recorded = out.errors[out.errors > -1]
   self.assertGreater(threshold, recorded[-1])
Пример #8
0
 def loss_fn(x, y):
     """Regularized OT cost for point clouds `x`, `y` (epsilon = 0.01).

     Returns the cost together with the auxiliary tuple `(geom, f, g)`.
     `a`, `b`, `momentum_strategy` and `lse_mode` come from the enclosing
     scope.
     """
     geom = pointcloud.PointCloud(x, y, epsilon=0.01)
     result = sinkhorn.sinkhorn(geom, a, b,
                                momentum_strategy=momentum_strategy,
                                lse_mode=lse_mode)
     return result.reg_ot_cost, (geom, result.f, result.g)
Пример #9
0
def _sinkhorn_divergence(geometry_xy: geometry.Geometry,
                         geometry_xx: geometry.Geometry,
                         geometry_yy: Optional[geometry.Geometry],
                         a: jnp.ndarray, b: jnp.ndarray, **kwargs):
    """Computes the (unbalanced) sinkhorn divergence for the wrapper function.

    The divergence is

      reg_ot(a, b) - (reg_ot(a, a) + reg_ot(b, b)) / 2
        + epsilon_xy / 2 * (sum(a) - sum(b))**2,

    where the last term is the total-mass correction defined in
    https://arxiv.org/pdf/1910.12958.pdf (15).

    Args:
      geometry_xy: geometry able to apply kernels with a certain epsilon,
        between the views X and Y.
      geometry_xx: geometry between elements of the view X.
      geometry_yy: geometry between elements of the view Y, or None. If None,
        the (b, b) problem is skipped and contributes a cost of 0.
      a: jnp.ndarray<float>[n]: the weight of each input point.
      b: jnp.ndarray<float>[m]: the weight of each target point. The sum of
        all elements of b must match that of a to converge.
      **kwargs: keyword arguments forwarded to `sinkhorn.sinkhorn`. For the two
        symmetric problems, `parallel_dual_updates` is forced to True and
        `momentum_strategy` to 0.5.

    Returns:
      SinkhornDivergenceOutput named tuple.
    """
    # The symmetric problems override the parallel/momentum arguments.
    symmetric_kwargs = {**kwargs,
                        'parallel_dual_updates': True,
                        'momentum_strategy': 0.5}

    out_xy = sinkhorn.sinkhorn(geometry_xy, a, b, **kwargs)
    out_xx = sinkhorn.sinkhorn(geometry_xx, a, a, **symmetric_kwargs)
    out_yy = (sinkhorn.SinkhornOutput(None, None, 0, None, None)
              if geometry_yy is None
              else sinkhorn.sinkhorn(geometry_yy, b, b, **symmetric_kwargs))
    outputs = (out_xy, out_xx, out_yy)

    mass_gap = jnp.sum(a) - jnp.sum(b)
    div = (out_xy.reg_ot_cost - 0.5 *
           (out_xx.reg_ot_cost + out_yy.reg_ot_cost) +
           0.5 * geometry_xy.epsilon * mass_gap**2)

    potentials = tuple([res.f, res.g] for res in outputs)
    errors = tuple(res.errors for res in outputs)
    converged = tuple(res.converged for res in outputs)
    return SinkhornDivergenceOutput(div, potentials,
                                    (geometry_xy, geometry_xx, geometry_yy),
                                    errors, converged)
Пример #10
0
 def test_grid_vs_euclidean(self, lse_mode):
   """A grid geometry and the same points as a point cloud agree on reg. OT cost."""
   grid_size = (5, 6, 7)
   key_a, key_b = jax.random.split(self.rng, 2)
   a = jax.random.uniform(key_a, grid_size)
   b = jax.random.uniform(key_b, grid_size)
   a = a.ravel() / jnp.sum(a)
   b = b.ravel() / jnp.sum(b)
   epsilon = 0.1
   geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)
   # Explicit grid coordinates, each axis rescaled from 0..size-1 to [0, 1].
   coords = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
   xyz = jnp.stack([
       jnp.array(axis_coords.ravel()) / jnp.maximum(1, size - 1)
       for axis_coords, size in zip(coords, grid_size)
   ]).transpose()
   geometry_mat = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon)
   out_mat = sinkhorn.sinkhorn(geometry_mat, a=a, b=b, lse_mode=lse_mode)
   out_grid = sinkhorn.sinkhorn(geometry_grid, a=a, b=b, lse_mode=lse_mode)
   self.assertAllClose(out_mat.reg_ot_cost, out_grid.reg_ot_cost)
Пример #11
0
    def test_apply_transport_geometry_from_scalings(self):
        """Applying transport matrix P on vector without instantiating P.

        Runs every (lse_mode, online) combination. Each is checked against the
        dense transport matrix, and then all combinations are checked to agree
        with one another.
        """
        n, m, d = 160, 230, 6
        keys = jax.random.split(self.rng, 6)
        x = jax.random.uniform(keys[0], (n, d))
        y = jax.random.uniform(keys[1], (m, d))
        a = jax.random.uniform(keys[2], (n, ))
        b = jax.random.uniform(keys[3], (m, ))
        a = a / jnp.sum(a)
        b = b / jnp.sum(b)

        batch_b = 8

        vec_a = jax.random.normal(keys[4], (n, ))
        vec_b = jax.random.normal(keys[5], (batch_b, m))

        transport_t_vec_a = []
        transport_vec_b = []
        # test with lse_mode and online = True / False
        for lse_mode in (True, False):
            for online in (True, False):
                geom = pointcloud.PointCloud(x, y, online=online, epsilon=0.2)
                sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode)

                u = geom.scaling_from_potential(sink.f)
                v = geom.scaling_from_potential(sink.g)

                t_vec_a = geom.apply_transport_from_scalings(
                    u, v, vec_a, axis=0)
                t_vec_b = geom.apply_transport_from_scalings(
                    u, v, vec_b, axis=1)
                transport_t_vec_a.append(t_vec_a)
                transport_vec_b.append(t_vec_b)

                # Reference: the explicitly instantiated transport matrix.
                transport = geom.transport_from_scalings(u, v)

                self.assertAllClose(t_vec_a,
                                    jnp.dot(transport.T, vec_a).T,
                                    rtol=1e-3,
                                    atol=1e-3)
                self.assertAllClose(t_vec_b,
                                    jnp.dot(transport, vec_b.T).T,
                                    rtol=1e-3,
                                    atol=1e-3)
                # Check the truth value: an `is` comparison against True can
                # never fail for a jax array.
                self.assertFalse(jnp.any(jnp.isnan(t_vec_a)))
        # All (lse_mode, online) combinations must agree with the first one.
        for t_vec_b, t_vec_a in zip(transport_vec_b, transport_t_vec_a):
            self.assertAllClose(t_vec_b,
                                transport_vec_b[0],
                                rtol=1e-3,
                                atol=1e-3)
            self.assertAllClose(t_vec_a,
                                transport_t_vec_a[0],
                                rtol=1e-3,
                                atol=1e-3)
Пример #12
0
  def test_apply_transport_grid(self, lse_mode):
    """Grid and equivalent point cloud agree when applying transport plans.

    A grid geometry and an explicit point cloud of the same rescaled grid
    points are each solved with Sinkhorn. Applying each transport plan to
    batches of vectors along both axes must give the same result.
    """
    grid_size = (5, 6, 7)
    num_points = np.prod(grid_size)
    # One key per random draw: two histograms and two test vectors. Indexing
    # past the split (e.g. keys[4] of 3 keys) is out of bounds in JAX.
    keys = jax.random.split(self.rng, 4)
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)
    a = a.ravel() / jnp.sum(a)
    b = b.ravel() / jnp.sum(b)
    geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)
    # Explicit grid coordinates, each axis rescaled from 0..size-1 to [0, 1].
    x, y, z = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
    xyz = jnp.stack([
        jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
        jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
        jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
    ]).transpose()
    geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)
    sink_mat = sinkhorn.sinkhorn(geom_mat, a=a, b=b, lse_mode=lse_mode)
    sink_grid = sinkhorn.sinkhorn(geom_grid, a=a, b=b, lse_mode=lse_mode)

    batch_a = 3
    batch_b = 4
    vec_a = jax.random.normal(keys[2], [batch_a, num_points])
    vec_b = jax.random.normal(keys[3], [batch_b, num_points])

    vec_a = vec_a / jnp.sum(vec_a, axis=1)[:, jnp.newaxis]
    vec_b = vec_b / jnp.sum(vec_b, axis=1)[:, jnp.newaxis]

    mat_transport_t_vec_a = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_a, axis=0)
    mat_transport_vec_b = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_b, axis=1)

    grid_transport_t_vec_a = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_a, axis=0)
    grid_transport_vec_b = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_b, axis=1)

    self.assertAllClose(mat_transport_t_vec_a, grid_transport_t_vec_a)
    self.assertAllClose(mat_transport_vec_b, grid_transport_vec_b)
    # Check the truth value: `assertIsNot(jnp.any(...), True)` can never fail
    # because a jax array is not the `True` singleton.
    self.assertFalse(jnp.any(jnp.isnan(mat_transport_t_vec_a)))
    self.assertFalse(jnp.any(jnp.isnan(mat_transport_vec_b)))
Пример #13
0
 def test_euclidean_point_cloud(self, lse_mode):
     """Online point cloud with implicit differentiation reaches threshold 0.1."""
     threshold = 1e-1
     geom = pointcloud.PointCloud(self.x, self.y, epsilon=1, online=True)
     out = sinkhorn.sinkhorn(geom, a=self.a, b=self.b, threshold=threshold,
                             lse_mode=lse_mode, implicit_differentiation=True)
     # Last error among entries above -1 (presumably the unfilled slots).
     recorded = out.errors[out.errors > -1]
     self.assertGreater(threshold, recorded[-1])
Пример #14
0
 def loss_pcg(a, x, implicit=True):
     """Reg. OT cost between `x` and `self.y` with tau_a=1.0, tau_b=0.95.

     `implicit` toggles implicit differentiation through Sinkhorn. `epsilon`,
     `lse_mode` and `self` come from the enclosing scope.
     """
     geom = pointcloud.PointCloud(x, self.y, epsilon=epsilon)
     solution = sinkhorn.sinkhorn(geom, a=a, b=self.b, tau_a=1.0, tau_b=0.95,
                                  threshold=1e-4, lse_mode=lse_mode,
                                  implicit_differentiation=implicit)
     return solution.reg_ot_cost
Пример #15
0
 def test_euclidean_point_cloud_min_iter(self):
   """Testing the min_iterations parameter.

   With min_iterations=34, errors[0:3] are expected to be inf and errors[3]
   positive, and the last finite error must be below the threshold.
   """
   threshold = 1e-3
   geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
   out = sinkhorn.sinkhorn(geom, a=self.a, b=self.b, threshold=threshold,
                           min_iterations=34, implicit_differentiation=False)
   errors = out.errors
   usable = jnp.logical_and(errors > -1, jnp.isfinite(errors))
   self.assertGreater(threshold, errors[usable][-1])
   for idx in range(3):
     self.assertEqual(jnp.inf, errors[idx])
   self.assertGreater(errors[3], 0)
Пример #16
0
    def test_restart(self, lse_mode):
        """Warm-starting from a converged solution converges immediately."""
        threshold = 1e-4
        geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.01)

        def run(**init_kwargs):
            # Identical solver settings for the cold and the warm start.
            return sinkhorn.sinkhorn(geom,
                                     a=self.a,
                                     b=self.b,
                                     threshold=threshold,
                                     lse_mode=lse_mode,
                                     inner_iterations=1,
                                     **init_kwargs)

        def last_error(errs):
            return errs[errs > -1][-1]

        out = run()
        err = last_error(out.errors)
        self.assertGreater(threshold, err)

        # Warm start: potentials in lse mode, scalings otherwise.
        if lse_mode:
            init_dual_a, init_dual_b = out.f, out.g
        else:
            init_dual_a, init_dual_b = (
                geom.scaling_from_potential(potential)
                for potential in (out.f, out.g))
        out_restarted = run(init_dual_a=init_dual_a, init_dual_b=init_dual_b)
        errors_restarted = out_restarted.errors
        err_restarted = last_error(errors_restarted)
        self.assertGreater(threshold, err_restarted)

        num_iter_restarted = jnp.sum(errors_restarted > -1)
        # Restarting can only improve on the error...
        self.assertGreater(err, err_restarted)
        # ...already at the first recorded error of the restart...
        self.assertGreater(err, errors_restarted[0])
        # ...and one iteration suffices when restarting on the same data.
        self.assertEqual(num_iter_restarted, 1)
Пример #17
0
 def test_bures_point_cloud(self, lse_mode, online):
     """Two point clouds of Gaussians, tested with various parameters.

     Sinkhorn on a Bures-cost point cloud must reach `threshold`.
     """
     threshold = 1e-3
     geom = pointcloud.PointCloud(self.x,
                                  self.y,
                                  cost_fn=costs.Bures(dimension=self.dim,
                                                      regularization=1e-4),
                                  online=online,
                                  epsilon=self.eps)
     # Pass the threshold so the solver targets the same tolerance the
     # assertion below checks, instead of relying on the solver default.
     errors = sinkhorn.sinkhorn(geom, a=self.a, b=self.b,
                                threshold=threshold,
                                lse_mode=lse_mode).errors
     err = errors[errors > -1][-1]
     self.assertGreater(threshold, err)
Пример #18
0
 def loss_g(a, x, implicit=True):
     """Reg. OT cost on an explicit squared-Euclidean cost between `x`, `self.y`.

     Solved with tau_a=0.8, tau_b=0.87 and threshold 1e-4. `implicit` toggles
     implicit differentiation.
     """
     # ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j> == ||x_i - y_j||^2.
     sq_dists = (jnp.sum(x**2, axis=1)[:, jnp.newaxis]
                 + jnp.sum(self.y**2, axis=1)[jnp.newaxis, :]
                 - 2 * jnp.dot(x, self.y.T))
     geom = geometry.Geometry(cost_matrix=sq_dists, epsilon=epsilon)
     solution = sinkhorn.sinkhorn(geom, a=a, b=self.b, tau_a=0.8, tau_b=0.87,
                                  threshold=1e-4, lse_mode=lse_mode,
                                  implicit_differentiation=implicit)
     return solution.reg_ot_cost
Пример #19
0
 def potential(a, x, implicit, d):
     """Returns sum(f * d), the first Sinkhorn potential weighted by `d`."""
     # Squared-Euclidean cost matrix between the rows of `x` and `y`.
     cost = (jnp.sum(x**2, axis=1)[:, jnp.newaxis]
             + jnp.sum(y**2, axis=1)[jnp.newaxis, :]
             - 2 * jnp.dot(x, y.T))
     geom = geometry.Geometry(cost_matrix=cost, epsilon=epsilon)
     out = sinkhorn.sinkhorn(geom, a=a, b=b, tau_a=tau_a, tau_b=tau_b,
                             lse_mode=lse_mode, threshold=1e-4,
                             implicit_differentiation=implicit,
                             inner_iterations=2)
     return jnp.sum(out.f * d)
Пример #20
0
 def test_euclidean_point_cloud(self, lse_mode, momentum_strategy,
                                inner_iterations, norm_error):
     """Sinkhorn reaches the threshold for each parameter combination."""
     threshold = 1e-3
     solver_kwargs = dict(threshold=threshold,
                          momentum_strategy=momentum_strategy,
                          inner_iterations=inner_iterations,
                          norm_error=norm_error,
                          lse_mode=lse_mode)
     geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
     errors = sinkhorn.sinkhorn(geom, a=self.a, b=self.b,
                                **solver_kwargs).errors
     last_err = errors[errors > -1][-1]
     self.assertGreater(threshold, last_err)
Пример #21
0
    def solve(self):
        """Runs the sinkhorn algorithm to solve the transport problem.

        Stores the dual potentials in `self._f` / `self._g` and the
        regularized transport cost in `self.reg_ot_cost`. A warning is logged
        when Sinkhorn reports that it has not converged.
        """
        out = sinkhorn.sinkhorn(self.geom, self.a, self.b, **self._kwargs)
        if not out.converged:
            # TODO(oliviert): Point to the online doc when available.
            logging.warning(
                'Sinkhorn has not converged. Please check your setup and '
                'consider increasing max_iterations or epsilon. For more '
                'details, see ott.core.sinkhorn.sinkhorn.')

        # So far we always set the values, even if not converged.
        # TODO(oliviert, cuturi): handles cases where it has not converged.
        self._f = out.f
        self._g = out.g
        self.reg_ot_cost = out.reg_ot_cost
Пример #22
0
 def test_euclidean_point_cloud_parallel_weights(self, lse_mode):
   """Two point clouds, parallel execution for batched histograms."""
   self.rng, *rngs = jax.random.split(self.rng, 2)
   batch = 4
   # NOTE(review): the batched `a`/`b` below are built but never passed to
   # sinkhorn (the call uses self.a / self.b), and both are drawn from
   # rngs[0]. As written, batched histograms are not exercised. Confirm that
   # sinkhorn supports [batch, n] weights before switching to a=a, b=b.
   a, b = (jax.random.uniform(rngs[0], (batch, size))
           for size in (self.n, self.m))
   a, b = (w / jnp.sum(w, axis=1)[:, jnp.newaxis] for w in (a, b))
   threshold = 1e-3
   geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, online=True)
   out = sinkhorn.sinkhorn(geom, a=self.a, b=self.b, threshold=threshold,
                           lse_mode=lse_mode)
   last_err = out.errors[out.errors > -1][-1]
   self.assertGreater(jnp.min(threshold - last_err), 0)
Пример #23
0
    def test_jit_vs_non_jit_fwd(self):
        """Jitted, non-jitted and user-jitted Sinkhorn agree to atol=1e-6."""
        args = (self.geometry, self.a, self.b)
        jitted_result = sinkhorn.sinkhorn(*args)
        non_jitted_result = non_jitted_sinkhorn(*args)
        # The user wraps the non-jitted solver in jax.jit themselves.
        user_jitted_result = jax.jit(
            lambda g, a, b: non_jitted_sinkhorn(g, a, b))(*args)
        for other in (non_jitted_result, user_jitted_result):
            chex.assert_tree_all_close(jitted_result, other,
                                       atol=1e-6, rtol=0)
Пример #24
0
  def test_separable_grid(self, lse_mode):
    """Two histograms in a grid of size 5 x 6 x 7 in the hypercube^3."""
    grid_size = (5, 6, 7)
    keys = jax.random.split(self.rng, 2)
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)
    # Zero out a whole slice along the first axis of each histogram to test
    # proper handling of zero weights, then ravel. `.at[...].set(...)` is the
    # functional update replacing the removed `jax.ops.index_update`.
    a = a.at[0].set(0).ravel()
    a = a / jnp.sum(a)
    b = b.at[3].set(0).ravel()
    b = b / jnp.sum(b)

    threshold = 0.01
    geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
    errors = sinkhorn.sinkhorn(
        geom, a=a, b=b, threshold=threshold, lse_mode=lse_mode).errors
    err = errors[jnp.isfinite(errors)][-1]
    self.assertGreater(threshold, err)
Пример #25
0
def _sinkhorn_divergence(geometry_xy: geometry.Geometry,
                         geometry_xx: geometry.Geometry,
                         geometry_yy: Optional[geometry.Geometry],
                         a: jnp.ndarray, b: jnp.ndarray, **kwargs):
    """Computes the (unbalanced) sinkhorn divergence for the wrapper function.

    The divergence is

      reg_ot(a, b) - (reg_ot(a, a) + reg_ot(b, b)) / 2
        + epsilon_xy / 2 * (sum(a) - sum(b))**2,

    where the last term is the total-mass correction defined in
    https://arxiv.org/pdf/1910.12958.pdf (15).

    Args:
      geometry_xy: geometry able to apply kernels with a certain epsilon,
        between the views X and Y.
      geometry_xx: geometry between elements of the view X.
      geometry_yy: geometry between elements of the view Y, or None. Any None
        geometry is skipped and contributes a cost of 0.
      a: jnp.ndarray<float>[n]: the weight of each input point.
      b: jnp.ndarray<float>[m]: the weight of each target point. The sum of
        all elements of b must match that of a to converge.
      **kwargs: keyword arguments forwarded unchanged to `sinkhorn.sinkhorn`
        for all three problems.

    Returns:
      SinkhornDivergenceOutput named tuple.
    """
    geoms = (geometry_xy, geometry_xx, geometry_yy)
    marginal_pairs = ((a, b), (a, a), (b, b))
    out = []
    for geom, (left, right) in zip(geoms, marginal_pairs):
        if geom is None:
            # Placeholder output whose reg_ot_cost (0) drops out of the sum.
            out.append(sinkhorn.SinkhornOutput(None, None, 0, None, None))
        else:
            out.append(sinkhorn.sinkhorn(geom, left, right, **kwargs))

    cost_xy, cost_xx, cost_yy = (res.reg_ot_cost for res in out)
    mass_gap = jnp.sum(a) - jnp.sum(b)
    div = (cost_xy - 0.5 * (cost_xx + cost_yy) +
           0.5 * geometry_xy.epsilon * mass_gap**2)

    potentials = tuple([res.f, res.g] for res in out)
    errors = tuple(res.errors for res in out)
    converged = tuple(res.converged for res in out)
    return SinkhornDivergenceOutput(div, potentials, geoms, errors, converged)
Пример #26
0
 def loss_fn(cm):
   """Reg. OT cost (epsilon 0.5) between uniform marginals on cost matrix `cm`.

   Returns the cost and the auxiliary tuple `(geom, f, g)`.
   """
   num_rows, num_cols = cm.shape[0], cm.shape[1]
   uniform_rows = jnp.ones(num_rows) / num_rows
   uniform_cols = jnp.ones(num_cols) / num_cols
   geom = geometry.Geometry(cm, epsilon=0.5)
   solution = sinkhorn.sinkhorn(geom, uniform_rows, uniform_cols,
                                lse_mode=lse_mode)
   return solution.reg_ot_cost, (geom, solution.f, solution.g)
 def reg_ot(a, b):
   """Reg. OT cost of (a, b) on `geom`, without implicit differentiation."""
   solution = sinkhorn.sinkhorn(geom, a=a, b=b, threshold=0.1,
                                lse_mode=lse_mode,
                                implicit_differentiation=False)
   return solution.reg_ot_cost
Пример #28
0
def _gw_iterations(
    geom_x: geometry.Geometry, geom_y: geometry.Geometry, a: jnp.ndarray,
    b: jnp.ndarray, epsilon: Union[epsilon_scheduler.Epsilon,
                                   float], loss: GWLoss, max_iterations: int,
    warm_start: bool, sinkhorn_kwargs: Optional[Dict[str, Any]], **kwargs
) -> Tuple[jnp.ndarray, jnp.ndarray, geometry.Geometry, jnp.ndarray,
           jnp.ndarray, jnp.ndarray]:
    """Fits Gromov Wasserstein.

    Solves Sinkhorn once on the initial GW geometry. It then runs
    `max_iterations - 1` further rounds with `jax.lax.scan`; each round
    rebuilds the GW geometry from the current potentials via
    `_update_geometry_gw` and solves Sinkhorn again.

    Args:
      geom_x: a Geometry object for the first view.
      geom_y: a second Geometry object for the second view.
      a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
      b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
      epsilon: a regularization parameter or a epsilon_scheduler.Epsilon object.
      loss: GWLoss object.
      max_iterations: int, the maximum number of outer iterations for
        Gromov Wasserstein.
      warm_start: bool, optional initialisation of the potentials/scalings
        w.r.t. first and second marginals between each call to sinkhorn.
      sinkhorn_kwargs: Optionally a dictionary containing the keywords
        arguments for calls to the sinkhorn function. None is treated as an
        empty dictionary.
      **kwargs: additional kwargs for epsilon, forwarded to
        `_init_geometry_gw` and `_update_geometry_gw`.

    Returns:
      f: potential.
      g: potential.
      geom_gw: a Geometry object for Gromov-Wasserstein (GW).
      reg_gw_cost_arr: ndarray of regularised GW costs.
      errors_sinkhorn: ndarray [max_iterations, p], where p depends on
        sinkhorn_kwargs, of errors for the Sinkhorn algorithm for each gromov
        iteration (axis 0) and regularly spaced sinkhorn iterations (axis 1).
      converged_sinkhorn: ndarray [max_iterations,] of flags indicating
        that the sinkhorn algorithm converged.
    """
    # The parameter is Optional: treat None as "no extra arguments" instead of
    # failing on `.get` / `**` below.
    if sinkhorn_kwargs is None:
        sinkhorn_kwargs = {}
    lse_mode = sinkhorn_kwargs.get('lse_mode', True)

    geom_gw = _init_geometry_gw(geom_x, geom_y, jax.lax.stop_gradient(a),
                                jax.lax.stop_gradient(b), epsilon, loss,
                                **kwargs)
    f, g, reg_gw_cost, errors_sinkhorn, converged_sinkhorn = sinkhorn.sinkhorn(
        geom_gw, a, b, **sinkhorn_kwargs)
    carry = geom_gw, f, g
    update_geom_partial = functools.partial(_update_geometry_gw,
                                            geom_x=geom_x,
                                            geom_y=geom_y,
                                            loss=loss,
                                            **kwargs)
    sinkhorn_partial = functools.partial(sinkhorn.sinkhorn,
                                         a=a,
                                         b=b,
                                         **sinkhorn_kwargs)

    def body_fn(carry, x):
        """One outer GW round: update the geometry, then re-solve Sinkhorn."""
        del x  # scan runs with xs=None.
        geom_gw, f, g = carry
        geom_gw = update_geom_partial(geom=geom_gw, f=f, g=g)
        # Warm start with potentials in lse mode, scalings otherwise.
        init_dual_a = ((f if lse_mode else geom_gw.scaling_from_potential(f))
                       if warm_start else None)
        init_dual_b = ((g if lse_mode else geom_gw.scaling_from_potential(g))
                       if warm_start else None)
        f, g, reg_gw_cost, errors_sinkhorn, converged_sinkhorn = sinkhorn_partial(
            geom=geom_gw, init_dual_a=init_dual_a, init_dual_b=init_dual_b)
        return (geom_gw, f, g), (reg_gw_cost, errors_sinkhorn,
                                 converged_sinkhorn)

    carry, out = jax.lax.scan(f=body_fn,
                              init=carry,
                              xs=None,
                              length=max_iterations - 1)

    geom_gw, f, g = carry
    # Prepend the results of the initial solve to the per-round results.
    reg_gw_cost_arr = jnp.concatenate((jnp.array([reg_gw_cost]), out[0]))
    errors_sinkhorn = jnp.concatenate((jnp.array([errors_sinkhorn]), out[1]))
    converged_sinkhorn = jnp.concatenate(
        (jnp.array([converged_sinkhorn]), out[2]))
    return (f, g, geom_gw, reg_gw_cost_arr, errors_sinkhorn,
            converged_sinkhorn)
Пример #29
0
 def reg_ot(a, b):
     """Regularized OT cost of (a, b) on `geom`, solved to threshold 1e-3."""
     solution = sinkhorn.sinkhorn(geom, a=a, b=b, threshold=0.001,
                                  lse_mode=lse_mode)
     return solution.reg_ot_cost
Пример #30
0
 def reg_ot(a, b):
   """Regularized OT cost of (a, b) on `geom`; only `lse_mode` is set."""
   solution = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=lse_mode)
   return solution.reg_ot_cost