示例#1
0
    def test_discrete_barycenter_grid(self, lse_mode, debiased, epsilon):
        """Tests the discrete barycenter on a 5x5x5 grid.

        Builds two histograms that are uniform except for one heavy entry
        (the first flat index for one, the last flat index for the other),
        then checks that the barycenter returned by
        `db.discrete_barycenter` puts most of its mass (more than 0.7 but
        less than 1) on the middle entry, i.e. flat index (5*5*5 - 1) // 2,
        and that the last finite recorded error is below the threshold.

        Args:
          lse_mode: bool, lse or scaling computations.
          debiased: bool, use (or not) debiasing as proposed in
            https://arxiv.org/abs/2006.02575
          epsilon: float, regularization parameter.
        """
        size = jnp.array([5, 5, 5])
        grid_3d = grid.Grid(grid_size=size, epsilon=epsilon)

        # `.at[...].set(...)` is the functional update that replaces
        # `jax.ops.index_update`, which is no longer available in JAX.
        a = jnp.ones(size).ravel().at[0].set(10000)
        b = jnp.ones(size).ravel().at[-1].set(10000)
        a = a / jnp.sum(a)
        b = b / jnp.sum(b)

        threshold = 1e-2
        _, _, bar, errors = db.discrete_barycenter(grid_3d,
                                                   a=jnp.stack((a, b)),
                                                   threshold=threshold,
                                                   lse_mode=lse_mode,
                                                   debiased=debiased)

        # Flat index of the central entry of the raveled 5x5x5 histogram.
        middle = (jnp.prod(size) - 1) // 2
        self.assertGreater(bar[middle], 0.7)
        self.assertGreater(1, bar[middle])

        # Only finite entries of `errors` are considered; take the last one.
        err = errors[jnp.isfinite(errors)][-1]
        self.assertGreater(threshold, err)
 def reg_ot(x):
     """Return the `reg_ot_cost` of a Sinkhorn run on a grid built from `x`.

     `x` is forwarded as the `x` argument of `grid.Grid` (with epsilon 1.0).
     `a`, `b` and `lse_mode` are not parameters: they are looked up in the
     surrounding scope.
     """
     solver_output = sinkhorn.sinkhorn(
         grid.Grid(x=x, epsilon=1.0),
         a=a,
         b=b,
         threshold=0.1,
         lse_mode=lse_mode,
     )
     return solver_output.reg_ot_cost
示例#3
0
    def test_autograd_sinkhorn_x_grid(self, lse_mode):
        """Checks the autodiff gradient of the OT cost w.r.t. the weights `a`.

        The grid is built from explicit per-axis locations. The directional
        derivative obtained from `jax.value_and_grad` along a zero-mean
        direction is compared with a central finite difference.
        """
        step = 1e-4  # perturbation magnitude
        rngs = jax.random.split(self.rng, 3)
        locations = (
            jnp.array([.0, 1.0], dtype=jnp.float32),
            jnp.array([.3, .4, .7], dtype=jnp.float32),
            jnp.array([1.0, 1.3, 2.4, 3.7], dtype=jnp.float32),
        )
        grid_size = tuple(axis.shape[0] for axis in locations)

        def normalized_weights(rng):
            # Entries are shifted by 1 before normalizing to sum to one.
            raw = jax.random.uniform(rng, grid_size) + 1
            return raw.ravel() / jnp.sum(raw)

        a = normalized_weights(rngs[0])
        b = normalized_weights(rngs[1])
        geom = grid.Grid(x=locations, epsilon=1)

        def cost(weights_a, weights_b):
            solver_output = sinkhorn.sinkhorn(geom,
                                              a=weights_a,
                                              b=weights_b,
                                              threshold=0.001,
                                              lse_mode=lse_mode)
            return solver_output.reg_ot_cost

        # Gradient is taken w.r.t. the first argument only.
        _, grad_a = jax.jit(jax.value_and_grad(cost))(a, b)

        # Zero-mean perturbation direction.
        direction = jax.random.uniform(rngs[2], grid_size).ravel()
        direction = direction - jnp.mean(direction)

        # Central finite difference along `direction`.
        cost_plus = cost(a + step * direction, b)
        cost_minus = cost(a - step * direction, b)
        finite_difference = (cost_plus - cost_minus) / (2 * step)

        self.assertAllClose(jnp.sum(direction * grad_a),
                            finite_difference,
                            rtol=1e-03,
                            atol=1e-02)
  def test_autograd_sinkhorn_grid(self, lse_mode):
    """Checks the autodiff gradient of the OT cost w.r.t. the weights `a`.

    Sinkhorn is run with `implicit_differentiation=False` on a (2, 3, 4) grid.
    The directional derivative obtained from `jax.value_and_grad` along a
    zero-mean direction is compared with a central finite difference.
    """
    step = 1e-3  # perturbation magnitude
    grid_size = (2, 3, 4)
    rngs = jax.random.split(self.rng, 3)

    def normalized_weights(rng):
      # Entries are shifted by 1 before normalizing to sum to one.
      raw = jax.random.uniform(rng, grid_size) + 1.0
      return raw.ravel() / jnp.sum(raw)

    a = normalized_weights(rngs[0])
    b = normalized_weights(rngs[1])
    geom = grid.Grid(grid_size=grid_size, epsilon=0.2)

    def cost(weights_a, weights_b):
      solver_output = sinkhorn.sinkhorn(
          geom,
          a=weights_a,
          b=weights_b,
          threshold=0.1,
          lse_mode=lse_mode,
          implicit_differentiation=False)
      return solver_output.reg_ot_cost

    # Gradient is taken w.r.t. the first argument only.
    _, grad_a = jax.jit(jax.value_and_grad(cost))(a, b)

    # Zero-mean perturbation direction.
    direction = jax.random.uniform(rngs[2], grid_size).ravel()
    direction = direction - jnp.mean(direction)

    # Central finite difference along `direction`.
    cost_plus = cost(a + step * direction, b)
    cost_minus = cost(a - step * direction, b)
    finite_difference = (cost_plus - cost_minus) / (2 * step)

    self.assertAllClose(
        jnp.sum(direction * grad_a),
        finite_difference,
        rtol=1e-03,
        atol=1e-02)
示例#5
0
  def test_separable_grid(self, lse_mode):
    """Runs Sinkhorn between two histograms on a 5 x 6 x 7 grid.

    Both histograms contain zero weights (one slice along the first axis is
    zeroed out in each) to test that they are handled properly. The test
    passes when the last finite recorded error is below the threshold.

    Args:
      lse_mode: bool, forwarded to `sinkhorn.sinkhorn`.
    """
    grid_size = (5, 6, 7)
    keys = jax.random.split(self.rng, 2)
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)
    # Add zero weights to test proper handling, then ravel. `.at[i].set(0)`
    # zeroes the whole slice `i` of the first axis; it replaces
    # `jax.ops.index_update`, which is no longer available in JAX.
    a = a.at[0].set(0).ravel()
    a = a / jnp.sum(a)
    b = b.at[3].set(0).ravel()
    b = b / jnp.sum(b)

    threshold = 0.01
    geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
    errors = sinkhorn.sinkhorn(
        geom, a=a, b=b, threshold=threshold, lse_mode=lse_mode).errors
    # Only finite entries of `errors` are considered; take the last one.
    err = errors[jnp.isfinite(errors)][-1]
    self.assertGreater(threshold, err)
示例#6
0
  def test_apply_cost(self):
    """Compares `apply_cost` on a Grid with a PointCloud on the same points.

    The point cloud is built from the integer grid coordinates, each axis
    divided by `max(1, n - 1)`. The grid result is checked against the point
    cloud's `apply_cost` and against an explicit product with its
    `cost_matrix`.
    """
    grid_size = (5, 6, 7)
    geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)

    # One integer coordinate array per axis, each of shape `grid_size`.
    coords = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
    xyz = jnp.stack([
        jnp.array(axis_coords.ravel()) / jnp.maximum(1, n - 1)
        for axis_coords, n in zip(coords, grid_size)
    ]).transpose()
    geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)

    vec = jax.random.uniform(self.rng, grid_size).ravel()
    cost_from_mat = geom_mat.apply_cost(vec)
    cost_from_grid = geom_grid.apply_cost(vec)
    self.assertAllClose(cost_from_mat, cost_from_grid)

    explicit_product = np.dot(geom_mat.cost_matrix, vec)
    self.assertAllClose(cost_from_grid[0, :], explicit_product)
示例#7
0
 def test_grid_vs_euclidean(self, lse_mode):
   """Checks Sinkhorn gives the same cost on a Grid and on a PointCloud.

   The point cloud is built from the integer grid coordinates, each axis
   divided by `max(1, n - 1)`, and both geometries share the same epsilon.
   """
   grid_size = (5, 6, 7)
   epsilon = 0.1
   rngs = jax.random.split(self.rng, 2)

   def normalized_weights(rng):
     raw = jax.random.uniform(rng, grid_size)
     return raw.ravel() / jnp.sum(raw)

   a = normalized_weights(rngs[0])
   b = normalized_weights(rngs[1])

   geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)

   # One integer coordinate array per axis, each of shape `grid_size`.
   coords = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
   xyz = jnp.stack([
       jnp.array(axis_coords.ravel()) / jnp.maximum(1, n - 1)
       for axis_coords, n in zip(coords, grid_size)
   ]).transpose()
   geometry_mat = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon)

   shared_args = dict(a=a, b=b, lse_mode=lse_mode)
   out_mat = sinkhorn.sinkhorn(geometry_mat, **shared_args)
   out_grid = sinkhorn.sinkhorn(geometry_grid, **shared_args)
   self.assertAllClose(out_mat.reg_ot_cost, out_grid.reg_ot_cost)
示例#8
0
  def test_apply_transport_grid(self, lse_mode):
    """Compares transport application on a Grid and on a PointCloud.

    Sinkhorn is solved on a 5 x 6 x 7 `grid.Grid` and on a
    `pointcloud.PointCloud` built from the integer grid coordinates, each
    axis divided by `max(1, n - 1)`. The resulting potentials are then used
    to apply the transport (both axes) to batches of random vectors, and the
    outputs of the two geometries are required to match and to be NaN-free.

    Args:
      lse_mode: bool, forwarded to `sinkhorn.sinkhorn`.
    """
    grid_size = (5, 6, 7)
    # Four independent keys: two for the marginals and one per batch of test
    # vectors. Only three keys used to be created while index 4 was read
    # (out of range), and both batches were drawn from that same key.
    keys = jax.random.split(self.rng, 4)
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)
    a = a.ravel() / jnp.sum(a)
    b = b.ravel() / jnp.sum(b)
    geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)
    x, y, z = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
    xyz = jnp.stack([
        jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
        jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
        jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
    ]).transpose()
    geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)
    sink_mat = sinkhorn.sinkhorn(geom_mat, a=a, b=b, lse_mode=lse_mode)
    sink_grid = sinkhorn.sinkhorn(geom_grid, a=a, b=b, lse_mode=lse_mode)

    batch_a = 3
    batch_b = 4
    num_points = int(np.prod(grid_size))
    vec_a = jax.random.normal(keys[2], [batch_a, num_points])
    vec_b = jax.random.normal(keys[3], [batch_b, num_points])

    # Rescale every row so that it sums to one.
    vec_a = vec_a / jnp.sum(vec_a, axis=1)[:, jnp.newaxis]
    vec_b = vec_b / jnp.sum(vec_b, axis=1)[:, jnp.newaxis]

    mat_transport_t_vec_a = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_a, axis=0)
    mat_transport_vec_b = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_b, axis=1)

    grid_transport_t_vec_a = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_a, axis=0)
    grid_transport_vec_b = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_b, axis=1)

    self.assertAllClose(mat_transport_t_vec_a, grid_transport_t_vec_a)
    self.assertAllClose(mat_transport_vec_b, grid_transport_vec_b)
    # `assertIsNot(array, True)` compares identities and can never fail;
    # convert to a Python bool so that NaNs actually fail the test.
    self.assertFalse(bool(jnp.any(jnp.isnan(mat_transport_t_vec_a))))