Esempio n. 1
0
    def test_generic_point_cloud_wrapper(self):
        """Exercises the divergence wrapper with precomputed cost matrices.

        Three ways of handing cost matrices to a `geometry.Geometry`-based
        divergence are covered:
        - three positional matrices;
        - only two positional matrices;
        - a tuple passed through the `cost_matrix` keyword.

        Every call must produce a non-None divergence, three potentials and
        three geometries.
        """
        key_x, key_y = jax.random.split(self.rng, 2)
        x = jax.random.uniform(key_x, (self._num_points[0], self._dim))
        y = jax.random.uniform(key_y, (self._num_points[1], self._dim))

        def pairwise_sq(p, q):
            # Broadcast (n, 1, d) - (1, m, d), then sum squared magnitudes
            # over the feature axis to obtain an (n, m) cost matrix.
            delta = p[:, jnp.newaxis] - q[jnp.newaxis, :]
            return jnp.sum(jnp.abs(delta)**2, axis=2)

        cxy = pairwise_sq(x, y)
        cxx = pairwise_sq(x, x)
        cyy = pairwise_sq(y, y)

        def run_and_check(*cost_args, **cost_kwargs):
            # A fresh solver-options dict is built on every call, mirroring
            # one literal per call site.
            result = sinkhorn_divergence.sinkhorn_divergence(
                geometry.Geometry,
                *cost_args,
                epsilon=0.1,
                a=self._a,
                b=self._b,
                sinkhorn_kwargs=dict(threshold=1e-2),
                **cost_kwargs)
            self.assertIsNotNone(result.divergence)
            self.assertLen(result.potentials, 3)
            self.assertLen(result.geoms, 3)

        # All three matrices given positionally.
        run_and_check(cxy, cxx, cyy)
        # Only the xy and xx matrices given positionally.
        run_and_check(cxy, cxx)
        # All three matrices given through the `cost_matrix` keyword.
        run_and_check(cost_matrix=(cxy, cxx, cyy))
Esempio n. 2
0
    def test_euclidean_point_cloud(self):
        """Checks Sinkhorn divergences between Euclidean point clouds.

        Part one builds the xy/xx/yy point-cloud geometries by hand, with
        epsilon 0.01. It passes them to the private `_sinkhorn_divergence` and
        expects a strictly positive divergence with three potentials.

        Part two compares `x` with itself through the public wrapper. The
        divergence must be ~0. The second recorded error trace (index 1,
        the symmetric evaluation) must contain fewer positive entries than
        the first, i.e. the symmetric solve converges earlier.
        """
        key_x, key_y = jax.random.split(self.rng, 2)
        x = jax.random.uniform(key_x, (self._num_points[0], self._dim))
        y = jax.random.uniform(key_y, (self._num_points[1], self._dim))

        geom_xx, geom_xy, geom_yy = (
            pointcloud.PointCloud(p, q, epsilon=0.01)
            for p, q in ((x, x), (x, y), (y, y)))
        manual = sinkhorn_divergence._sinkhorn_divergence(
            geom_xy, geom_xx, geom_yy, self._a, self._b, threshold=1e-2)
        self.assertGreater(manual.divergence, 0.0)
        self.assertLen(manual.potentials, 3)

        # Self-comparison through the wrapper, with `inner_iterations` = 1.
        same = sinkhorn_divergence.sinkhorn_divergence(
            pointcloud.PointCloud,
            x,
            x,
            epsilon=1e-1,
            sinkhorn_kwargs={'inner_iterations': 1})
        self.assertAllClose(same.divergence, 0.0, rtol=1e-5, atol=1e-5)
        # Positive error entries serve as a proxy for iterations performed
        # (presumably unused slots are non-positive — confirm in the solver).
        steps_plain = jnp.sum(same.errors[0] > 0)
        steps_symmetric = jnp.sum(same.errors[1] > 0)
        self.assertGreater(steps_plain, steps_symmetric)
 def loss_fn(cloud_a, cloud_b):
     """Returns the Sinkhorn divergence between two point clouds.

     Uses `epsilon=1.0`, the weights `self._a` / `self._b`, and a solver
     threshold of 0.05. `self` is captured from an enclosing scope that
     is not shown here.
     """
     solver_options = {'threshold': 0.05}
     result = sinkhorn_divergence.sinkhorn_divergence(
         pointcloud.PointCloud, cloud_a, cloud_b,
         epsilon=1.0, a=self._a, b=self._b,
         sinkhorn_kwargs=solver_options)
     return result.divergence
Esempio n. 4
0
 def test_euclidean_point_cloud_wrapper_no_weights(self):
     """Runs the point-cloud divergence wrapper without explicit weights.

     Neither `a` nor `b` is passed. The result must be strictly positive
     and expose three potentials and three geometries.
     """
     key_a, key_b = jax.random.split(self.rng, 2)
     cloud_a = jax.random.uniform(key_a, (self._num_points[0], self._dim))
     cloud_b = jax.random.uniform(key_b, (self._num_points[1], self._dim))
     out = sinkhorn_divergence.sinkhorn_divergence(
         pointcloud.PointCloud, cloud_a, cloud_b,
         epsilon=0.1, sinkhorn_kwargs={'threshold': 1e-2})
     self.assertGreater(out.divergence, 0.0)
     self.assertLen(out.potentials, 3)
     self.assertLen(out.geoms, 3)
Esempio n. 5
0
 def test_euclidean_point_cloud_unbalanced_wrapper(self):
     """Runs the point-cloud divergence wrapper in an unbalanced setting.

     The weights are shifted by small constants (0.001 and 0.002), and
     `tau_a=0.8` / `tau_b=0.9` are forwarded to the solver. The result must
     be strictly positive and expose three potentials and three geometries.
     """
     key_a, key_b = jax.random.split(self.rng, 2)
     cloud_a = jax.random.uniform(key_a, (self._num_points[0], self._dim))
     cloud_b = jax.random.uniform(key_b, (self._num_points[1], self._dim))
     solver_options = {'threshold': 1e-2, 'tau_a': 0.8, 'tau_b': 0.9}
     out = sinkhorn_divergence.sinkhorn_divergence(
         pointcloud.PointCloud,
         cloud_a,
         cloud_b,
         epsilon=0.1,
         a=self._a + .001,
         b=self._b + .002,
         sinkhorn_kwargs=solver_options)
     self.assertGreater(out.divergence, 0.0)
     self.assertLen(out.potentials, 3)
     self.assertLen(out.geoms, 3)