Пример #1
0
    def test_euclidean_point_cloud(self):
        """Sinkhorn divergence on two random point clouds.

        First checks that the divergence between distinct clouds is positive
        and that three sets of potentials are returned; then checks that the
        divergence of a cloud with itself vanishes and that the symmetric
        evaluation stops after fewer iterations than the plain one.
        """
        key_x, key_y = jax.random.split(self.rng, 2)
        x = jax.random.uniform(key_x, (self._num_points[0], self._dim))
        y = jax.random.uniform(key_y, (self._num_points[1], self._dim))
        geom_xx = pointcloud.PointCloud(x, x, epsilon=0.01)
        geom_xy = pointcloud.PointCloud(x, y, epsilon=0.01)
        geom_yy = pointcloud.PointCloud(y, y, epsilon=0.01)
        result = sinkhorn_divergence._sinkhorn_divergence(
            geom_xy, geom_xx, geom_yy, self._a, self._b, threshold=1e-2)
        self.assertGreater(result.divergence, 0.0)
        self.assertLen(result.potentials, 3)

        # Symmetric setting: x against itself. The divergence must vanish and
        # the symmetric evaluation should converge earlier/better.
        sym_result = sinkhorn_divergence.sinkhorn_divergence(
            pointcloud.PointCloud, x, x, epsilon=1e-1,
            sinkhorn_kwargs={'inner_iterations': 1})
        self.assertAllClose(sym_result.divergence, 0.0, rtol=1e-5, atol=1e-5)
        num_iters_plain = jnp.sum(sym_result.errors[0] > 0)
        num_iters_sym = jnp.sum(sym_result.errors[1] > 0)
        self.assertGreater(num_iters_plain, num_iters_sym)
Пример #2
0
    def test_online_vs_batch_euclidean_point_cloud(self, lse_mode):
        """Comparing online vs batch geometry.

        Solves the same problem on four point-cloud geometries (online or
        batch, default cost_fn or an explicit ``costs.Euclidean()``) and checks
        that regularized costs and transport matrices agree.
        """
        threshold = 1e-3
        eps = 0.1

        def solve(geom):
            return sinkhorn.sinkhorn(geom,
                                     a=self.a,
                                     b=self.b,
                                     threshold=threshold,
                                     lse_mode=lse_mode)

        def plan(geom, out):
            return geom.transport_from_potentials(out.f, out.g)

        online_geom = pointcloud.PointCloud(
            self.x, self.y, epsilon=eps, online=True)
        online_geom_euc = pointcloud.PointCloud(
            self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps,
            online=True)
        batch_geom = pointcloud.PointCloud(self.x, self.y, epsilon=eps)
        batch_geom_euc = pointcloud.PointCloud(
            self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps)

        out_online = solve(online_geom)
        out_batch = solve(batch_geom)
        out_online_euc = solve(online_geom_euc)
        out_batch_euc = solve(batch_geom_euc)

        # Regularized transport costs match between online and batch.
        self.assertAllClose(out_online.reg_ot_cost, out_batch.reg_ot_cost)
        # Regularized transport matrices match between online and batch ...
        self.assertAllClose(plan(online_geom, out_online),
                            plan(batch_geom, out_batch))
        self.assertAllClose(plan(online_geom_euc, out_online_euc),
                            plan(batch_geom_euc, out_batch_euc))
        # ... and between the default cost_fn and costs.Euclidean().
        self.assertAllClose(plan(batch_geom, out_batch),
                            plan(batch_geom_euc, out_batch_euc))
Пример #3
0
 def test_geom_vs_point_cloud(self):
     """Sinkhorn on a point cloud matches Sinkhorn on its dense cost matrix."""
     cloud_geom = pointcloud.PointCloud(self.x, self.y)
     dense_geom = geometry.Geometry(cloud_geom.cost_matrix)
     f_cloud, f_dense = (
         sinkhorn.sinkhorn(g, a=self.a, b=self.b).f
         for g in (cloud_geom, dense_geom))
     self.assertAllClose(f_cloud, f_dense)
Пример #4
0
def sinkhorn_for_sort(inputs: jnp.ndarray, weights: jnp.ndarray,
                      target_weights: jnp.ndarray, sinkhorn_kw,
                      pointcloud_kw) -> jnp.ndarray:
    """Runs sinkhorn on a fixed increasing target.

    The inputs are standardized and squashed into (0, 1) with a sigmoid, then
    transported onto ``num_targets`` evenly spaced points of [0, 1].

    Args:
      inputs: jnp.ndarray[num_points] (or [num_points, 1]). Must be one
        dimensional.
      weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
      target_weights: jnp.ndarray[num_targets]: the weights of the targets. It
        may be of a different size than the weights.
      sinkhorn_kw: a dictionary holding the sinkhorn keyword arguments. See
        sinkhorn.py for more details.
      pointcloud_kw: a dictionary holding the keyword arguments of the
        PointCloud class. See pointcloud.py for more details.

    Returns:
      A jnp.ndarray<float> representing the transport matrix of the inputs onto
      the underlying sorted target.

    Raises:
      ValueError: if ``inputs`` has more than two dimensions, or two
        dimensions with more than one column.
    """
    shape = inputs.shape
    if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
        # f-string so the offending shape actually appears in the message.
        raise ValueError(
            f"Shape ({shape}) not supported. The input should be one-dimensional."
        )

    # reshape (instead of squeeze + expand_dims) keeps a single-point input as
    # a [1, 1] column rather than collapsing it to a scalar.
    x = jnp.reshape(inputs, (-1, 1))
    # Standardize, then map into (0, 1) where the targets live.
    x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    a = jnp.reshape(weights, (-1,))
    b = jnp.reshape(target_weights, (-1,))
    num_targets = b.shape[0]
    y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
    geom = pointcloud.PointCloud(x, y, **pointcloud_kw)
    res = sinkhorn.sinkhorn(geom, a, b, **sinkhorn_kw)
    return geom.transport_from_potentials(res.f, res.g)
Пример #5
0
    def test_discrete_barycenter_pointcloud(self, lse_mode, epsilon):
        """Tests the discrete barycenters on pointclouds.

        Two measures supported on the same set of points (a 1D grid); the
        barycenter is evaluated on a different set of points (still in 1D).

        Args:
          lse_mode: bool, lse or scaling computations
          epsilon: float
        """
        n = 50
        grid_1d = jnp.arange(0, n) / (n - 1)

        def bump(center):
            # Narrow Gaussian bump on [0, 1], floored away from 0, normalized.
            weights = jnp.exp(-(grid_1d - center)**2 / .01) + 1e-10
            return weights / jnp.sum(weights)

        a = bump(0.2)
        b = bump(0.8)

        # positions on the real line where weights are supported.
        x = jnp.atleast_2d(grid_1d).T

        # A different support, half the size, for the barycenter; this is the
        # reason why debiasing is not used in this case.
        half = n / 2
        x_support_bar = jnp.atleast_2d(
            (jnp.arange(0, half) / (half - 1) - .5) * .9 + .5).T

        geom = pointcloud.PointCloud(x, x_support_bar, epsilon=epsilon)
        bar = db.discrete_barycenter(geom,
                                     a=jnp.stack((a, b)),
                                     lse_mode=lse_mode).histogram
        # The barycenter should have a bump in the middle of its support.
        self.assertGreater(bar[n // 4], 0.1)
Пример #6
0
  def test_autograd_sinkhorn(self, lse_mode):
    """Test gradient w.r.t. probability weights.

    Projects the autodiff gradient of the regularized OT cost w.r.t. ``a`` on
    a random centered direction and compares it with a central finite
    difference. Zero weights are included to test that they are handled.
    """
    d = 3
    n, m = 11, 13
    eps = 1e-3  # perturbation magnitude
    keys = jax.random.split(self.rng, 5)
    x = jax.random.uniform(keys[0], (n, d))
    y = jax.random.uniform(keys[1], (m, d))
    a = jax.random.uniform(keys[2], (n,)) + eps
    b = jax.random.uniform(keys[3], (m,)) + eps
    # Adding zero weights to test proper handling. `.at[...].set` replaces
    # jax.ops.index_update, which has been removed from JAX.
    a = a.at[0].set(0)
    b = b.at[3].set(0)
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    geom = pointcloud.PointCloud(x, y, epsilon=0.1)

    def reg_ot(a, b):
      return sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=lse_mode).reg_ot_cost

    reg_ot_and_grad = jax.jit(jax.value_and_grad(reg_ot))
    _, grad_reg_ot = reg_ot_and_grad(a, b)
    delta = jax.random.uniform(keys[4], (n,))
    delta = delta * (a > 0)  # ensures only perturbing non-zero coords.
    delta = delta - jnp.sum(delta) / jnp.sum(a > 0)  # center perturbation
    # Centering also shifts the zero coords; mask them out again.
    delta = delta * (a > 0)
    reg_ot_delta_plus = reg_ot(a + eps * delta, b)
    reg_ot_delta_minus = reg_ot(a - eps * delta, b)
    delta_dot_grad = jnp.nansum(delta * grad_reg_ot)
    # assertFalse evaluates the array; assertIsNot(array, True) compared
    # identity with the Python bool and could never fail.
    self.assertFalse(jnp.any(jnp.isnan(delta_dot_grad)))
    self.assertAllClose(delta_dot_grad,
                        (reg_ot_delta_plus - reg_ot_delta_minus) / (2 * eps),
                        rtol=1e-03, atol=1e-02)
Пример #7
0
 def test_euclidean_point_cloud(self):
     """Sinkhorn divergence between two random point clouds is positive."""
     key_x, key_y = jax.random.split(self.rng, 2)
     x = jax.random.uniform(key_x, (self._num_points[0], self._dim))
     y = jax.random.uniform(key_y, (self._num_points[1], self._dim))
     geom_xx = pointcloud.PointCloud(x, x, epsilon=0.1)
     geom_xy = pointcloud.PointCloud(x, y, epsilon=0.1)
     geom_yy = pointcloud.PointCloud(y, y, epsilon=0.1)
     result = sinkhorn_divergence._sinkhorn_divergence(
         geom_xy, geom_xx, geom_yy, self._a, self._b,
         threshold=1e-1, max_iterations=20)
     self.assertGreater(result.divergence, 0.0)
     # Three sets of potentials, presumably one per geometry (xy, xx, yy).
     self.assertLen(result.potentials, 3)
 def loss_fn(x, y):
     """Regularized OT cost between point clouds ``x`` and ``y``.

     Relies on ``epsilon``, ``a``, ``b``, ``lse_mode`` and
     ``implicit_differentiation`` from the enclosing scope.

     Returns:
       A pair ``(reg_ot_cost, (geom, f, g))``.
     """
     geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
     out = sinkhorn.sinkhorn(
         geom,
         a,
         b,
         lse_mode=lse_mode,
         implicit_differentiation=implicit_differentiation)
     # Read fields by name: positional 5-tuple unpacking breaks as soon as the
     # Sinkhorn output gains or reorders fields.
     return out.reg_ot_cost, (geom, out.f, out.g)
Пример #9
0
 def test_online_euclidean_point_cloud(self, lse_mode):
   """Testing the online way to handle geometry."""
   threshold = 1e-3
   online_geom = pointcloud.PointCloud(
       self.x, self.y, epsilon=0.1, online=True)
   result = sinkhorn.sinkhorn(online_geom, a=self.a, b=self.b,
                              threshold=threshold, lse_mode=lse_mode)
   errs = result.errors
   # Skip entries equal to -1 (presumably unused slots) and keep the last one.
   final_error = errs[errs > -1][-1]
   self.assertGreater(threshold, final_error)
Пример #10
0
 def loss_fn(x, y):
     """Regularized OT cost between point clouds ``x`` and ``y`` (eps=0.01).

     Relies on ``a``, ``b``, ``momentum_strategy`` and ``lse_mode`` from the
     enclosing scope.

     Returns:
       A pair ``(reg_ot_cost, (geom, f, g))``.
     """
     geom = pointcloud.PointCloud(x, y, epsilon=0.01)
     out = sinkhorn.sinkhorn(
         geom,
         a,
         b,
         momentum_strategy=momentum_strategy,
         lse_mode=lse_mode)
     # Read fields by name: positional 5-tuple unpacking breaks as soon as the
     # Sinkhorn output gains or reorders fields.
     return out.reg_ot_cost, (geom, out.f, out.g)
Пример #11
0
    def test_apply_transport_geometry_from_scalings(self):
        """Applying transport matrix P on vector without instantiating P.

        For every combination of lse_mode and online, compares matrix-free
        products with products against the instantiated matrix, then checks
        all combinations agree with the first one.
        """
        n, m, d = 160, 230, 6
        keys = jax.random.split(self.rng, 6)
        x = jax.random.uniform(keys[0], (n, d))
        y = jax.random.uniform(keys[1], (m, d))
        a = jax.random.uniform(keys[2], (n, ))
        b = jax.random.uniform(keys[3], (m, ))
        a = a / jnp.sum(a)
        b = b / jnp.sum(b)
        transport_t_vec_a = [None, None, None, None]
        transport_vec_b = [None, None, None, None]

        batch_b = 8

        vec_a = jax.random.normal(keys[4], (n, ))
        vec_b = jax.random.normal(keys[5], (batch_b, m))

        # test with lse_mode and online = True / False
        for j, lse_mode in enumerate([True, False]):
            for i, online in enumerate([True, False]):
                idx = i + 2 * j
                geom = pointcloud.PointCloud(x, y, online=online, epsilon=0.2)
                sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode)

                u = geom.scaling_from_potential(sink.f)
                v = geom.scaling_from_potential(sink.g)

                transport_t_vec_a[idx] = geom.apply_transport_from_scalings(
                    u, v, vec_a, axis=0)
                transport_vec_b[idx] = geom.apply_transport_from_scalings(
                    u, v, vec_b, axis=1)

                transport = geom.transport_from_scalings(u, v)

                self.assertAllClose(transport_t_vec_a[idx],
                                    jnp.dot(transport.T, vec_a).T,
                                    rtol=1e-3,
                                    atol=1e-3)
                self.assertAllClose(transport_vec_b[idx],
                                    jnp.dot(transport, vec_b.T).T,
                                    rtol=1e-3,
                                    atol=1e-3)
                # assertFalse evaluates the array; assertIsNot(array, True)
                # compared identity with the Python bool and never failed.
                self.assertFalse(jnp.any(jnp.isnan(transport_t_vec_a[idx])))
        for t_vec_b, t_t_vec_a in zip(transport_vec_b, transport_t_vec_a):
            self.assertAllClose(t_vec_b,
                                transport_vec_b[0],
                                rtol=1e-3,
                                atol=1e-3)
            self.assertAllClose(t_t_vec_a,
                                transport_t_vec_a[0],
                                rtol=1e-3,
                                atol=1e-3)
Пример #12
0
    def test_transport_wrong_init(self):
        """Transport rejects positional inputs that are not a valid problem."""
        key_x, key_y = jax.random.split(self.rng, 2)
        num_a, num_b = 23, 48
        x = jax.random.uniform(key_x, (num_a, 4))
        y = jax.random.uniform(key_y, (num_b, 4))
        geom = pointcloud.PointCloud(x, y, epsilon=1e-3, online=True)

        # A geometry followed by an array is expected to raise ValueError.
        with self.assertRaises(ValueError):
            transport.Transport(geom, x, threshold=1e-3)

        # Four positional arrays are expected to raise AttributeError.
        with self.assertRaises(AttributeError):
            transport.Transport(x, y, x, y, threshold=1e-3)
Пример #13
0
 def loss_pcg(a, x, implicit=True):
     """Regularized OT cost between weights ``a`` on ``x`` and ``self.b``.

     Uses tau_a=1.0 and tau_b=0.95 (presumably relaxing the constraint on
     ``b``); ``implicit`` toggles implicit differentiation. Relies on
     ``self``, ``epsilon`` and ``lse_mode`` from the enclosing scope.
     """
     geom = pointcloud.PointCloud(x, self.y, epsilon=epsilon)
     solver_kwargs = dict(tau_a=1.0,
                          tau_b=0.95,
                          threshold=1e-4,
                          lse_mode=lse_mode,
                          implicit_differentiation=implicit)
     result = sinkhorn.sinkhorn(geom, a=a, b=self.b, **solver_kwargs)
     return result.reg_ot_cost
Пример #14
0
 def test_transport_from_geom(self):
     """Transport built from a geometry has the right shape and marginals."""
     key_x, key_y, key_b = jax.random.split(self.rng, 3)
     num_a, num_b = 23, 48
     x = jax.random.uniform(key_x, (num_a, 4))
     y = jax.random.uniform(key_y, (num_b, 4))
     geom = pointcloud.PointCloud(x, y, epsilon=1e-3, online=True)
     b = jax.random.uniform(key_b, (num_b, ))
     b = b / jnp.sum(b)
     ot = transport.Transport(geom, b=b, threshold=1e-3)
     plan = ot.matrix
     self.assertEqual(plan.shape, (num_a, num_b))
     # Row sums should match ot.a and column sums should match ot.b.
     self.assertAllClose(jnp.sum(plan, axis=1), ot.a, atol=1e-3)
     self.assertAllClose(jnp.sum(plan, axis=0), ot.b, atol=1e-3)
Пример #15
0
 def test_euclidean_point_cloud(self, lse_mode):
     """Two point clouds, tested with various parameters."""
     threshold = 1e-1
     geom = pointcloud.PointCloud(self.x, self.y, epsilon=1, online=True)
     result = sinkhorn.sinkhorn(geom,
                                a=self.a,
                                b=self.b,
                                threshold=threshold,
                                lse_mode=lse_mode,
                                implicit_differentiation=True)
     errs = result.errors
     final_error = errs[errs > -1][-1]
     self.assertGreater(threshold, final_error)
Пример #16
0
 def test_euclidean_point_cloud_min_iter(self):
   """Testing the min_iterations parameter."""
   threshold = 1e-3
   geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
   errors = sinkhorn.sinkhorn(
       geom, a=self.a, b=self.b, threshold=threshold, min_iterations=34,
       implicit_differentiation=False).errors
   valid = jnp.logical_and(errors > -1, jnp.isfinite(errors))
   self.assertGreater(threshold, errors[valid][-1])
   # The first three recorded errors are expected to be infinite; the fourth
   # is expected to be a genuine positive error.
   for k in range(3):
     self.assertEqual(jnp.inf, errors[k])
   self.assertGreater(errors[3], 0)
Пример #17
0
 def test_bures_point_cloud(self, lse_mode, online):
     """Two point clouds of Gaussians, tested with various parameters."""
     threshold = 1e-3
     geom = pointcloud.PointCloud(self.x,
                                  self.y,
                                  cost_fn=costs.Bures(dimension=self.dim,
                                                      regularization=1e-4),
                                  online=online,
                                  epsilon=self.eps)
     # Pass the threshold the assertion checks against, rather than relying
     # on the solver's default happening to match it.
     errors = sinkhorn.sinkhorn(geom, a=self.a, b=self.b,
                                threshold=threshold,
                                lse_mode=lse_mode).errors
     err = errors[errors > -1][-1]
     self.assertGreater(threshold, err)
Пример #18
0
 def test_euclidean_point_cloud(self, lse_mode, momentum_strategy,
                                inner_iterations, norm_error):
     """Two point clouds, tested with various solver configurations."""
     threshold = 1e-3
     geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
     solver_options = {
         'threshold': threshold,
         'momentum_strategy': momentum_strategy,
         'inner_iterations': inner_iterations,
         'norm_error': norm_error,
         'lse_mode': lse_mode,
     }
     errs = sinkhorn.sinkhorn(geom, a=self.a, b=self.b,
                              **solver_options).errors
     final_error = errs[errs > -1][-1]
     self.assertGreater(threshold, final_error)
Пример #19
0
 def test_euclidean_point_cloud_parallel_weights(self, lse_mode):
   """Two point clouds, parallel execution for batched histograms."""
   # Distinct keys so that a and b are drawn independently.
   self.rng, key_a, key_b = jax.random.split(self.rng, 3)
   batch = 4
   a = jax.random.uniform(key_a, (batch, self.n))
   b = jax.random.uniform(key_b, (batch, self.m))
   a = a / jnp.sum(a, axis=1)[:, jnp.newaxis]
   b = b / jnp.sum(b, axis=1)[:, jnp.newaxis]
   threshold = 1e-3
   geom = pointcloud.PointCloud(
       self.x, self.y, epsilon=0.1, online=True)
   # Feed the batched histograms so the batched path is actually exercised.
   errors = sinkhorn.sinkhorn(
       geom, a=a, b=b, threshold=threshold, lse_mode=lse_mode).errors
   err = errors[errors > -1][-1]
   self.assertGreater(jnp.min(threshold - err), 0)
Пример #20
0
  def test_apply_cost(self):
    """Grid geometry applies the same cost as the equivalent point cloud."""
    grid_size = (5, 6, 7)

    geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)
    # Grid node coordinates rescaled to [0, 1] per axis, one row per node.
    axes = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
    xyz = jnp.stack([
        jnp.array(coords.ravel()) / jnp.maximum(1, size - 1)
        for coords, size in zip(axes, grid_size)
    ]).transpose()
    geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)

    vec = jax.random.uniform(self.rng, grid_size).ravel()
    grid_cost_vec = geom_grid.apply_cost(vec)
    self.assertAllClose(geom_mat.apply_cost(vec), grid_cost_vec)
    self.assertAllClose(grid_cost_vec[0, :],
                        np.dot(geom_mat.cost_matrix, vec))
Пример #21
0
 def test_grid_vs_euclidean(self, lse_mode):
   """Sinkhorn on a grid matches Sinkhorn on the equivalent point cloud."""
   grid_size = (5, 6, 7)
   key_a, key_b = jax.random.split(self.rng, 2)
   a = jax.random.uniform(key_a, grid_size)
   b = jax.random.uniform(key_b, grid_size)
   a = a.ravel() / jnp.sum(a)
   b = b.ravel() / jnp.sum(b)
   epsilon = 0.1
   geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)
   # Grid node coordinates rescaled to [0, 1] per axis, one row per node.
   axes = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
   xyz = jnp.stack([
       jnp.array(coords.ravel()) / jnp.maximum(1, size - 1)
       for coords, size in zip(axes, grid_size)
   ]).transpose()
   geometry_mat = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon)
   out_mat, out_grid = (
       sinkhorn.sinkhorn(g, a=a, b=b, lse_mode=lse_mode)
       for g in (geometry_mat, geometry_grid))
   self.assertAllClose(out_mat.reg_ot_cost, out_grid.reg_ot_cost)
Пример #22
0
  def test_apply_transport_grid(self, lse_mode):
    """Grid and point-cloud geometries apply the same transport to vectors."""
    grid_size = (5, 6, 7)
    # One key per random draw: a, b, vec_a, vec_b.
    keys = jax.random.split(self.rng, 4)
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)
    a = a.ravel() / jnp.sum(a)
    b = b.ravel() / jnp.sum(b)
    geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)
    x, y, z = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
    xyz = jnp.stack([
        jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
        jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
        jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
    ]).transpose()
    geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)
    sink_mat = sinkhorn.sinkhorn(geom_mat, a=a, b=b, lse_mode=lse_mode)
    sink_grid = sinkhorn.sinkhorn(geom_grid, a=a, b=b, lse_mode=lse_mode)

    batch_a = 3
    batch_b = 4
    num_points = int(np.prod(grid_size))
    vec_a = jax.random.normal(keys[2], [batch_a, num_points])
    vec_b = jax.random.normal(keys[3], [batch_b, num_points])

    vec_a = vec_a / jnp.sum(vec_a, axis=1)[:, jnp.newaxis]
    vec_b = vec_b / jnp.sum(vec_b, axis=1)[:, jnp.newaxis]

    mat_transport_t_vec_a = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_a, axis=0)
    mat_transport_vec_b = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_b, axis=1)

    grid_transport_t_vec_a = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_a, axis=0)
    grid_transport_vec_b = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_b, axis=1)

    self.assertAllClose(mat_transport_t_vec_a, grid_transport_t_vec_a)
    self.assertAllClose(mat_transport_vec_b, grid_transport_vec_b)
    # assertFalse evaluates the array; assertIsNot(array, True) compared
    # identity with the Python bool and could never fail.
    self.assertFalse(jnp.any(jnp.isnan(mat_transport_t_vec_a)))
Пример #23
0
    def __init__(self, *args, a=None, b=None, **kwargs):
        """Builds the transport problem and solves it immediately.

        Args:
          *args: either a single geometry.Geometry instance or, for
            convenience, positional arguments forwarded to
            pointcloud.PointCloud (two point clouds). In the latter case the
            regularization parameter epsilon must be set in the kwargs.
          a: the weights of the source; uniform over the rows when None.
          b: the weights of the target; uniform over the columns when None.
          **kwargs: keyword arguments passed to the sinkhorn algorithm. When
            the geometry is built from arrays, the keys 'epsilon', 'cost_fn',
            'power' and 'online' are removed from kwargs and, if not None,
            forwarded to pointcloud.PointCloud instead.

        Raises:
          ValueError: if a single positional argument is given and it is not
            a geometry.Geometry instance.
        """
        if len(args) != 1:
            geom_keys = ('epsilon', 'cost_fn', 'power', 'online')
            # Pop every geometry key so none of them reaches sinkhorn.
            popped = {key: kwargs.pop(key, None) for key in geom_keys}
            pc_kw = {key: val for key, val in popped.items() if val is not None}
            self.geom = pointcloud.PointCloud(*args, **pc_kw)
        else:
            (candidate,) = args
            if not isinstance(candidate, geometry.Geometry):
                raise ValueError(
                    'A transport problem must be defined by either a '
                    'single geometry, or two arrays.')
            self.geom = candidate

        num_a, num_b = self.geom.shape
        self.a = a if a is not None else jnp.ones((num_a, )) / num_a
        self.b = b if b is not None else jnp.ones((num_b, )) / num_b
        # Potentials and cost are filled in by solve().
        self._f = None
        self._g = None
        self._kwargs = kwargs
        self.reg_ot_cost = None
        self.solve()
Пример #24
0
    def test_restart(self, lse_mode):
        """Warm-starting from a converged solution converges immediately."""
        threshold = 1e-4
        geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.01)

        def run(**init_kwargs):
            return sinkhorn.sinkhorn(geom,
                                     a=self.a,
                                     b=self.b,
                                     threshold=threshold,
                                     lse_mode=lse_mode,
                                     inner_iterations=1,
                                     **init_kwargs)

        def last_error(errs):
            return errs[errs > -1][-1]

        out = run()
        err = last_error(out.errors)
        self.assertGreater(threshold, err)

        # Recover the previous solution (potentials in lse mode, scalings
        # otherwise) and restart from it.
        if not lse_mode:
            init_dual_a = geom.scaling_from_potential(out.f)
            init_dual_b = geom.scaling_from_potential(out.g)
        else:
            init_dual_a, init_dual_b = out.f, out.g
        out_restarted = run(init_dual_a=init_dual_a, init_dual_b=init_dual_b)
        errors_restarted = out_restarted.errors
        err_restarted = last_error(errors_restarted)
        self.assertGreater(threshold, err_restarted)

        # Restarting can only improve on the error ...
        self.assertGreater(err, err_restarted)
        # ... its first error is at least as good as the previous best ...
        self.assertGreater(err, errors_restarted[0])
        # ... and a single iteration suffices when restarting on the same data.
        self.assertEqual(jnp.sum(errors_restarted > -1), 1)
  def test_apply_cost_and_kernel(self):
    """Test consistency of cost/kernel apply to vec.

    Online and dense point clouds (and, for power=2, an explicit cost matrix)
    must give the same cost/kernel products along both axes.
    """
    n, m, p, b = 5, 8, 10, 7
    keys = jax.random.split(self.rng, 5)
    x = jax.random.normal(keys[0], (n, p))
    y = jax.random.normal(keys[1], (m, p)) + 1
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    vec0 = jax.random.normal(keys[2], (n, b))
    vec1 = jax.random.normal(keys[3], (m, b))

    def cost_products(geom):
      return geom.apply_cost(vec0, axis=0), geom.apply_cost(vec1, axis=1)

    def kernel_products(geom):
      return (geom.apply_kernel(vec0, 1., axis=0),
              geom.apply_kernel(vec1, 1., axis=1))

    def assert_pairs_close(lhs, rhs):
      for left, right in zip(lhs, rhs):
        self.assertAllClose(left, right, rtol=1e-03, atol=1e-02)

    # power=2 cost: online vs. dense point cloud vs. explicit cost matrix.
    online = cost_products(pointcloud.PointCloud(x, y, power=2, online=True))
    dense = cost_products(pointcloud.PointCloud(x, y, power=2, online=False))
    explicit = cost_products(geometry.Geometry(cost))
    assert_pairs_close(online, dense)
    assert_pairs_close(explicit, dense)

    # power=1 cost: online vs. dense point cloud.
    assert_pairs_close(
        cost_products(pointcloud.PointCloud(x, y, power=1, online=True)),
        cost_products(pointcloud.PointCloud(x, y, power=1, online=False)))

    # power=2 kernel: online vs. dense point cloud.
    assert_pairs_close(
        kernel_products(pointcloud.PointCloud(x, y, power=2, online=True)),
        kernel_products(pointcloud.PointCloud(x, y, power=2, online=False)))