def test_euclidean_point_cloud(self):
  """Sinkhorn divergence between random point clouds, then x against itself.

  For two independent clouds the divergence must be positive and three
  potentials must be returned. For x against itself the divergence must
  vanish, and the symmetric solve (`errors[1]`) must record fewer iterations
  than the first solve (`errors[0]`).
  """
  key_x, key_y = jax.random.split(self.rng, 2)
  x = jax.random.uniform(key_x, (self._num_points[0], self._dim))
  y = jax.random.uniform(key_y, (self._num_points[1], self._dim))

  eps = 0.01
  geom_xy = pointcloud.PointCloud(x, y, epsilon=eps)
  geom_xx = pointcloud.PointCloud(x, x, epsilon=eps)
  geom_yy = pointcloud.PointCloud(y, y, epsilon=eps)
  div = sinkhorn_divergence._sinkhorn_divergence(
      geom_xy, geom_xx, geom_yy, self._a, self._b, threshold=1e-2)
  self.assertGreater(div.divergence, 0.0)
  self.assertLen(div.potentials, 3)

  # Symmetric setting: the symmetric evaluation should converge earlier.
  self_div = sinkhorn_divergence.sinkhorn_divergence(
      pointcloud.PointCloud, x, x, epsilon=1e-1,
      sinkhorn_kwargs={'inner_iterations': 1})
  self.assertAllClose(self_div.divergence, 0.0, rtol=1e-5, atol=1e-5)
  # Count the iterations that recorded a positive error in each solve.
  iters_first = jnp.sum(self_div.errors[0] > 0)
  iters_symmetric = jnp.sum(self_div.errors[1] > 0)
  self.assertGreater(iters_first, iters_symmetric)
def test_online_vs_batch_euclidean_point_cloud(self, lse_mode):
  """Comparing online vs batch geometry.

  Solves the same problem with online and batch PointCloud geometries, each
  with the default cost and with an explicit `costs.Euclidean()`. Checks that
  the regularized costs and the transport plans agree.
  """
  threshold = 1e-3
  eps = 0.1

  def solve(**geom_kwargs):
    geom = pointcloud.PointCloud(self.x, self.y, epsilon=eps, **geom_kwargs)
    out = sinkhorn.sinkhorn(
        geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode)
    return geom, out

  def plan(geom, out):
    return geom.transport_from_potentials(out.f, out.g)

  online_geom, out_online = solve(online=True)
  online_geom_euc, out_online_euc = solve(
      cost_fn=costs.Euclidean(), online=True)
  batch_geom, out_batch = solve()
  batch_geom_euc, out_batch_euc = solve(cost_fn=costs.Euclidean())

  # Regularized transport costs must match.
  self.assertAllClose(out_online.reg_ot_cost, out_batch.reg_ot_cost)
  # Transport plans must match: online vs. batch for each cost, then the
  # default cost vs. the explicit Euclidean cost.
  self.assertAllClose(
      plan(online_geom, out_online), plan(batch_geom, out_batch))
  self.assertAllClose(
      plan(online_geom_euc, out_online_euc),
      plan(batch_geom_euc, out_batch_euc))
  self.assertAllClose(
      plan(batch_geom, out_batch), plan(batch_geom_euc, out_batch_euc))
def test_geom_vs_point_cloud(self):
  """Sinkhorn on two point clouds matches Sinkhorn on their cost matrix."""
  cloud_geom = pointcloud.PointCloud(self.x, self.y)
  matrix_geom = geometry.Geometry(cloud_geom.cost_matrix)
  f_cloud, f_matrix = (
      sinkhorn.sinkhorn(g, a=self.a, b=self.b).f
      for g in (cloud_geom, matrix_geom))
  self.assertAllClose(f_cloud, f_matrix)
def sinkhorn_for_sort(inputs: jnp.ndarray,
                      weights: jnp.ndarray,
                      target_weights: jnp.ndarray,
                      sinkhorn_kw,
                      pointcloud_kw) -> jnp.ndarray:
  """Runs sinkhorn on a fixed increasing target.

  The inputs are standardized and squashed into (0, 1) with a sigmoid. They
  are then transported onto `num_targets` points evenly spaced on [0, 1].

  Args:
    inputs: jnp.ndarray[num_points]. Must be one dimensional; a
      [num_points, 1] column is also accepted.
    weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
    target_weights: jnp.ndarray[num_targets]: the weights of the targets. It
      may be of a different size than the weights.
    sinkhorn_kw: a dictionary holding the sinkhorn keyword arguments. See
      sinkhorn.py for more details.
    pointcloud_kw: a dictionary holding the keyword arguments of the
      PointCloud class. See pointcloud.py for more details.

  Returns:
    A jnp.ndarray<float> representing the transport matrix of the inputs onto
    the underlying sorted target.

  Raises:
    ValueError: if `inputs` has more than two dimensions, or is a matrix with
      more than one column.
  """
  shape = inputs.shape
  if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
    # f-string so the offending shape actually appears in the message.
    raise ValueError(
        f"Shape ({shape}) not supported. The input should be one-dimensional."
    )

  # Reshape instead of squeeze + expand_dims: with a single input point,
  # squeeze yields a 0-d array and expand_dims(axis=1) would then fail.
  x = jnp.reshape(inputs, (-1, 1))
  x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))

  # atleast_1d keeps size-1 weight vectors indexable (b.shape[0] below).
  a = jnp.atleast_1d(jnp.squeeze(weights))
  b = jnp.atleast_1d(jnp.squeeze(target_weights))
  num_targets = b.shape[0]
  y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
  geom = pointcloud.PointCloud(x, y, **pointcloud_kw)
  res = sinkhorn.sinkhorn(geom, a, b, **sinkhorn_kw)
  return geom.transport_from_potentials(res.f, res.g)
def test_discrete_barycenter_pointcloud(self, lse_mode, epsilon):
  """Tests the discrete barycenters on pointclouds.

  Two measures are supported on the same set of points (a 1D grid). Their
  barycenter is evaluated on a different set of points (still in 1D).

  Args:
    lse_mode: bool, lse or scaling computations
    epsilon: float, regularization passed to the PointCloud geometry.
  """
  n = 50
  positions = jnp.arange(0, n) / (n - 1)  # n evenly spaced points in [0, 1]

  def narrow_bump(center):
    # Gaussian bump plus a small floor so that no weight is exactly zero.
    weights = jnp.exp(-(positions - center)**2 / .01) + 1e-10
    return weights / jnp.sum(weights)

  a = narrow_bump(0.2)
  b = narrow_bump(0.8)
  x = jnp.atleast_2d(positions).T

  # Barycenter support: half as many points, spread over [0.05, 0.95].
  # Because the support differs from x, debiasing is not used here.
  half = n / 2
  x_support_bar = jnp.atleast_2d(
      (jnp.arange(0, half) / (half - 1) - .5) * .9 + .5).T
  geom = pointcloud.PointCloud(x, x_support_bar, epsilon=epsilon)
  bar = db.discrete_barycenter(
      geom, a=jnp.stack((a, b)), lse_mode=lse_mode).histogram
  # Bumps at 0.2 and 0.8 should give a barycenter with mass in the middle.
  self.assertGreater(bar[n // 4], 0.1)
def test_autograd_sinkhorn(self, lse_mode):
  """Test gradient w.r.t. probability weights.

  The autodiff gradient of `reg_ot_cost` w.r.t. `a` is projected on a random,
  centered direction. The projection is compared with a central finite
  difference. One entry each of `a` and `b` is set to zero to check that
  zero weights are handled.
  """
  d = 3
  n, m = 11, 13
  eps = 1e-3  # perturbation magnitude
  keys = jax.random.split(self.rng, 5)
  x = jax.random.uniform(keys[0], (n, d))
  y = jax.random.uniform(keys[1], (m, d))
  a = jax.random.uniform(keys[2], (n,)) + eps
  b = jax.random.uniform(keys[3], (m,)) + eps
  # Adding zero weights to test proper handling. `jax.ops.index_update` has
  # been removed from JAX; `.at[...].set(...)` is its replacement.
  a = a.at[0].set(0)
  b = b.at[3].set(0)
  a = a / jnp.sum(a)
  b = b / jnp.sum(b)
  geom = pointcloud.PointCloud(x, y, epsilon=0.1)

  def reg_ot(a, b):
    return sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=lse_mode).reg_ot_cost

  reg_ot_and_grad = jax.jit(jax.value_and_grad(reg_ot))
  _, grad_reg_ot = reg_ot_and_grad(a, b)

  delta = jax.random.uniform(keys[4], (n,))
  delta = delta * (a > 0)  # ensures only perturbing non-zero coords.
  delta = delta - jnp.sum(delta) / jnp.sum(a > 0)  # center perturbation
  # Centering shifted every entry, so mask the zero-weight coords again.
  delta = delta * (a > 0)
  reg_ot_delta_plus = reg_ot(a + eps * delta, b)
  reg_ot_delta_minus = reg_ot(a - eps * delta, b)
  # nansum: delta * grad is NaN wherever grad is NaN, even where delta is 0.
  delta_dot_grad = jnp.nansum(delta * grad_reg_ot)
  # Assert on the value itself: `assertIsNot(<array>, True)` can never fail,
  # because a JAX array is never the `True` singleton.
  self.assertFalse(jnp.isnan(delta_dot_grad))
  self.assertAllClose(delta_dot_grad,
                      (reg_ot_delta_plus - reg_ot_delta_minus) / (2 * eps),
                      rtol=1e-03, atol=1e-02)
def test_euclidean_point_cloud(self):
  """Sinkhorn divergence of two random clouds with a 20-iteration budget.

  Even with a loose threshold and capped iterations, the divergence must be
  positive and three potentials must be returned.
  """
  key_x, key_y = jax.random.split(self.rng, 2)
  x = jax.random.uniform(key_x, (self._num_points[0], self._dim))
  y = jax.random.uniform(key_y, (self._num_points[1], self._dim))
  eps = 0.1
  geom_xx = pointcloud.PointCloud(x, x, epsilon=eps)
  geom_xy = pointcloud.PointCloud(x, y, epsilon=eps)
  geom_yy = pointcloud.PointCloud(y, y, epsilon=eps)
  div = sinkhorn_divergence._sinkhorn_divergence(
      geom_xy, geom_xx, geom_yy, self._a, self._b,
      threshold=1e-1, max_iterations=20)
  self.assertGreater(div.divergence, 0.0)
  self.assertLen(div.potentials, 3)
def loss_fn(x, y):
  """Regularized OT cost between clouds `x` and `y`, with auxiliary outputs.

  Reads `epsilon`, `a`, `b`, `lse_mode` and `implicit_differentiation` from
  the enclosing scope.

  Returns:
    A pair `(reg_ot_cost, (geom, f, g))`.
  """
  geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
  out = sinkhorn.sinkhorn(
      geom, a, b,
      lse_mode=lse_mode,
      implicit_differentiation=implicit_differentiation)
  return out.reg_ot_cost, (geom, out.f, out.g)
def test_online_euclidean_point_cloud(self, lse_mode):
  """Sinkhorn converges below threshold on an online PointCloud geometry."""
  threshold = 1e-3
  online_geom = pointcloud.PointCloud(
      self.x, self.y, epsilon=0.1, online=True)
  out = sinkhorn.sinkhorn(
      online_geom, a=self.a, b=self.b, threshold=threshold,
      lse_mode=lse_mode)
  # Keep only entries > -1 (presumably -1 pads iterations that never ran)
  # and check the last one.
  recorded_errors = out.errors[out.errors > -1]
  self.assertGreater(threshold, recorded_errors[-1])
def loss_fn(x, y):
  """Regularized OT cost between clouds `x` and `y` (epsilon=0.01).

  Reads `a`, `b`, `momentum_strategy` and `lse_mode` from the enclosing
  scope.

  Returns:
    A pair `(reg_ot_cost, (geom, f, g))`.
  """
  geom = pointcloud.PointCloud(x, y, epsilon=0.01)
  out = sinkhorn.sinkhorn(
      geom, a, b, momentum_strategy=momentum_strategy, lse_mode=lse_mode)
  return out.reg_ot_cost, (geom, out.f, out.g)
def test_apply_transport_geometry_from_scalings(self):
  """Applying transport matrix P on vector without instantiating P.

  For every (lse_mode, online) combination, the matrix-free products are
  compared with products against the dense transport matrix. Results from all
  four combinations must also agree with each other.
  """
  n, m, d = 160, 230, 6
  keys = jax.random.split(self.rng, 6)
  x = jax.random.uniform(keys[0], (n, d))
  y = jax.random.uniform(keys[1], (m, d))
  a = jax.random.uniform(keys[2], (n,))
  b = jax.random.uniform(keys[3], (m,))
  a = a / jnp.sum(a)
  b = b / jnp.sum(b)
  batch_b = 8
  vec_a = jax.random.normal(keys[4], (n,))
  vec_b = jax.random.normal(keys[5], (batch_b, m))

  transport_t_vec_a = []
  transport_vec_b = []
  # test with lse_mode and online = True / False
  for lse_mode in (True, False):
    for online in (True, False):
      geom = pointcloud.PointCloud(x, y, online=online, epsilon=0.2)
      sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode)
      u = geom.scaling_from_potential(sink.f)
      v = geom.scaling_from_potential(sink.g)
      t_vec_a = geom.apply_transport_from_scalings(u, v, vec_a, axis=0)
      t_vec_b = geom.apply_transport_from_scalings(u, v, vec_b, axis=1)
      transport = geom.transport_from_scalings(u, v)
      self.assertAllClose(t_vec_a, jnp.dot(transport.T, vec_a).T,
                          rtol=1e-3, atol=1e-3)
      self.assertAllClose(t_vec_b, jnp.dot(transport, vec_b.T).T,
                          rtol=1e-3, atol=1e-3)
      # assertFalse on the value: `assertIsNot(<array>, True)` never fails,
      # since an array is never the `True` singleton.
      self.assertFalse(jnp.any(jnp.isnan(t_vec_a)))
      self.assertFalse(jnp.any(jnp.isnan(t_vec_b)))
      transport_t_vec_a.append(t_vec_a)
      transport_vec_b.append(t_vec_b)

  # All four (lse_mode, online) combinations must agree with the first one.
  for t_vec_a, t_vec_b in zip(transport_t_vec_a, transport_vec_b):
    self.assertAllClose(t_vec_b, transport_vec_b[0], rtol=1e-3, atol=1e-3)
    self.assertAllClose(t_vec_a, transport_t_vec_a[0], rtol=1e-3, atol=1e-3)
def test_transport_wrong_init(self):
  """Transport rejects positional arguments it cannot turn into a geometry."""
  key_x, key_y = jax.random.split(self.rng, 2)
  num_a, num_b = 23, 48
  x = jax.random.uniform(key_x, (num_a, 4))
  y = jax.random.uniform(key_y, (num_b, 4))
  geom = pointcloud.PointCloud(x, y, epsilon=1e-3, online=True)
  # A geometry followed by an array.
  with self.assertRaises(ValueError):
    transport.Transport(geom, x, threshold=1e-3)
  # Four arrays.
  with self.assertRaises(AttributeError):
    transport.Transport(x, y, x, y, threshold=1e-3)
def loss_pcg(a, x, implicit=True):
  """Regularized OT cost from weighted cloud (a, x) to (self.b, self.y).

  Uses tau_a=1.0 and tau_b=0.95, with `implicit` toggling implicit
  differentiation. Reads `self`, `epsilon` and `lse_mode` from the enclosing
  scope.
  """
  geom = pointcloud.PointCloud(x, self.y, epsilon=epsilon)
  solution = sinkhorn.sinkhorn(
      geom,
      a=a,
      b=self.b,
      tau_a=1.0,
      tau_b=0.95,
      threshold=1e-4,
      lse_mode=lse_mode,
      implicit_differentiation=implicit)
  return solution.reg_ot_cost
def test_transport_from_geom(self):
  """Transport built from a geometry has the right shape and marginals."""
  key_x, key_y, key_b = jax.random.split(self.rng, 3)
  num_a, num_b = 23, 48
  x = jax.random.uniform(key_x, (num_a, 4))
  y = jax.random.uniform(key_y, (num_b, 4))
  geom = pointcloud.PointCloud(x, y, epsilon=1e-3, online=True)
  b = jax.random.uniform(key_b, (num_b,))
  b = b / jnp.sum(b)
  ot = transport.Transport(geom, b=b, threshold=1e-3)
  self.assertEqual(ot.matrix.shape, (num_a, num_b))
  # Row sums recover `a`; column sums recover `b`.
  for axis, marginal in ((1, ot.a), (0, ot.b)):
    self.assertAllClose(jnp.sum(ot.matrix, axis=axis), marginal, atol=1e-3)
def test_euclidean_point_cloud(self, lse_mode):
  """Two point clouds, tested with various parameters.

  Uses an online geometry with epsilon=1 and implicit differentiation
  enabled. The last recorded error must be below the threshold.
  """
  threshold = 1e-1
  geom = pointcloud.PointCloud(self.x, self.y, epsilon=1, online=True)
  out = sinkhorn.sinkhorn(
      geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode,
      implicit_differentiation=True)
  last_error = out.errors[out.errors > -1][-1]
  self.assertGreater(threshold, last_error)
def test_euclidean_point_cloud_min_iter(self):
  """Testing the min_iterations parameter.

  With min_iterations=34, the first three recorded errors are expected to be
  inf and the fourth a positive, finite error. The last finite error must
  still be below the threshold.
  """
  threshold = 1e-3
  geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
  out = sinkhorn.sinkhorn(
      geom, a=self.a, b=self.b, threshold=threshold, min_iterations=34,
      implicit_differentiation=False)
  errors = out.errors
  finite_recorded = jnp.logical_and(errors > -1, jnp.isfinite(errors))
  self.assertGreater(threshold, errors[finite_recorded][-1])
  for idx in range(3):
    self.assertEqual(jnp.inf, errors[idx])
  self.assertGreater(errors[3], 0)
def test_bures_point_cloud(self, lse_mode, online):
  """Two point clouds of Gaussians, tested with various parameters.

  Uses a Bures cost (dimension `self.dim`, regularization 1e-4). The last
  recorded Sinkhorn error must fall below `threshold`.
  """
  threshold = 1e-3
  geom = pointcloud.PointCloud(
      self.x, self.y,
      cost_fn=costs.Bures(dimension=self.dim, regularization=1e-4),
      online=online,
      epsilon=self.eps)
  # Pass `threshold` to the solver (as the sibling tests do). The assertion
  # below then checks the stopping criterion the solver actually used,
  # rather than silently relying on the solver's default.
  errors = sinkhorn.sinkhorn(
      geom, a=self.a, b=self.b, threshold=threshold,
      lse_mode=lse_mode).errors
  err = errors[errors > -1][-1]
  self.assertGreater(threshold, err)
def test_euclidean_point_cloud(self, lse_mode, momentum_strategy,
                               inner_iterations, norm_error):
  """Two point clouds, tested with various parameters.

  Sinkhorn must end below the threshold for each parameterized combination
  of lse_mode, momentum strategy, inner iterations and error norm.
  """
  threshold = 1e-3
  geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
  solver_options = {
      'threshold': threshold,
      'momentum_strategy': momentum_strategy,
      'inner_iterations': inner_iterations,
      'norm_error': norm_error,
      'lse_mode': lse_mode,
  }
  out = sinkhorn.sinkhorn(geom, a=self.a, b=self.b, **solver_options)
  last_error = out.errors[out.errors > -1][-1]
  self.assertGreater(threshold, last_error)
def test_euclidean_point_cloud_parallel_weights(self, lse_mode):
  """Two point clouds, parallel execution for batched histograms.

  NOTE(review): relies on sinkhorn accepting [batch, num_points] marginals,
  as the test name states. Confirm against the installed sinkhorn version.
  """
  # Split into three so that `a` and `b` get independent keys. Splitting into
  # two left a single key in `rngs`, which was reused for both histograms.
  self.rng, *rngs = jax.random.split(self.rng, 3)
  batch = 4
  a = jax.random.uniform(rngs[0], (batch, self.n))
  b = jax.random.uniform(rngs[1], (batch, self.m))
  a = a / jnp.sum(a, axis=1)[:, jnp.newaxis]
  b = b / jnp.sum(b, axis=1)[:, jnp.newaxis]
  threshold = 1e-3
  geom = pointcloud.PointCloud(
      self.x, self.y, epsilon=0.1, online=True)
  # Solve with the batched histograms built above. Passing self.a / self.b
  # (as before) left the batched code path untested.
  errors = sinkhorn.sinkhorn(
      geom, a=a, b=b, threshold=threshold, lse_mode=lse_mode).errors
  err = errors[errors > -1][-1]
  self.assertGreater(jnp.min(threshold - err), 0)
def test_apply_cost(self):
  """Grid.apply_cost agrees with an explicit point cloud of the grid nodes."""
  grid_size = (5, 6, 7)
  geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)
  # Node coordinates rescaled to [0, 1] along each axis. A size-1 axis is
  # divided by 1 (jnp.maximum) to avoid dividing by zero.
  axes = np.mgrid[tuple(slice(0, size) for size in grid_size)]
  xyz = jnp.stack([
      jnp.array(coords.ravel()) / jnp.maximum(1, size - 1)
      for coords, size in zip(axes, grid_size)
  ]).transpose()
  geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)
  vec = jax.random.uniform(self.rng, grid_size).ravel()
  grid_cost_vec = geom_grid.apply_cost(vec)
  self.assertAllClose(geom_mat.apply_cost(vec), grid_cost_vec)
  self.assertAllClose(
      grid_cost_vec[0, :], np.dot(geom_mat.cost_matrix, vec))
def test_grid_vs_euclidean(self, lse_mode):
  """Grid geometry and a point cloud of its nodes give equal reg-OT costs."""
  grid_size = (5, 6, 7)
  key_a, key_b = jax.random.split(self.rng, 2)
  a = jax.random.uniform(key_a, grid_size)
  b = jax.random.uniform(key_b, grid_size)
  a = a.ravel() / jnp.sum(a)
  b = b.ravel() / jnp.sum(b)
  epsilon = 0.1
  geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)
  # Node coordinates rescaled to [0, 1] along each axis.
  axes = np.mgrid[tuple(slice(0, size) for size in grid_size)]
  xyz = jnp.stack([
      jnp.array(coords.ravel()) / jnp.maximum(1, size - 1)
      for coords, size in zip(axes, grid_size)
  ]).transpose()
  geometry_mat = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon)
  cost_mat, cost_grid = (
      sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=lse_mode).reg_ot_cost
      for geom in (geometry_mat, geometry_grid))
  self.assertAllClose(cost_mat, cost_grid)
def test_apply_transport_grid(self, lse_mode):
  """Grid and equivalent point-cloud geometries apply the same transport.

  Both geometries are solved with Sinkhorn. Matrix-free transport products
  are then compared on batches of random vectors along both axes.
  """
  grid_size = (5, 6, 7)
  # Four independent keys: two histograms and two batches of vectors.
  # Previously the keys were split in three while `keys[4]` was read (out of
  # range), and that same key was reused for both vector batches.
  keys = jax.random.split(self.rng, 4)
  a = jax.random.uniform(keys[0], grid_size)
  b = jax.random.uniform(keys[1], grid_size)
  a = a.ravel() / jnp.sum(a)
  b = b.ravel() / jnp.sum(b)
  geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)
  x, y, z = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
  # Node coordinates rescaled to [0, 1] along each axis.
  xyz = jnp.stack([
      jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
      jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
      jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
  ]).transpose()
  geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)
  sink_mat = sinkhorn.sinkhorn(geom_mat, a=a, b=b, lse_mode=lse_mode)
  sink_grid = sinkhorn.sinkhorn(geom_grid, a=a, b=b, lse_mode=lse_mode)

  batch_a = 3
  batch_b = 4
  num_points = int(np.prod(grid_size))
  vec_a = jax.random.normal(keys[2], (batch_a, num_points))
  vec_b = jax.random.normal(keys[3], (batch_b, num_points))
  vec_a = vec_a / jnp.sum(vec_a, axis=1)[:, jnp.newaxis]
  vec_b = vec_b / jnp.sum(vec_b, axis=1)[:, jnp.newaxis]

  mat_transport_t_vec_a = geom_mat.apply_transport_from_potentials(
      sink_mat.f, sink_mat.g, vec_a, axis=0)
  mat_transport_vec_b = geom_mat.apply_transport_from_potentials(
      sink_mat.f, sink_mat.g, vec_b, axis=1)
  grid_transport_t_vec_a = geom_grid.apply_transport_from_potentials(
      sink_grid.f, sink_grid.g, vec_a, axis=0)
  grid_transport_vec_b = geom_grid.apply_transport_from_potentials(
      sink_grid.f, sink_grid.g, vec_b, axis=1)

  self.assertAllClose(mat_transport_t_vec_a, grid_transport_t_vec_a)
  self.assertAllClose(mat_transport_vec_b, grid_transport_vec_b)
  # assertFalse on the value: `assertIsNot(<array>, True)` never fails.
  self.assertFalse(jnp.any(jnp.isnan(mat_transport_t_vec_a)))
  self.assertFalse(jnp.any(jnp.isnan(mat_transport_vec_b)))
def __init__(self, *args, a=None, b=None, **kwargs):
  """Builds the transport problem and solves it immediately.

  Args:
    *args: either a single geometry.Geometry instance, or (for convenience)
      the positional arguments of pointcloud.PointCloud, typically two
      jnp.ndarray<float> point clouds. In the latter case the regularization
      epsilon should be given in the kwargs.
    a: the weights of the source. Defaults to uniform weights.
    b: the weights of the target. Defaults to uniform weights.
    **kwargs: the keyword arguments passed to the sinkhorn algorithm. When
      the geometry is built from point clouds, the keys 'epsilon', 'cost_fn',
      'power' and 'online' are removed from kwargs. Those that are not None
      are forwarded to pointcloud.PointCloud.

  Raises:
    ValueError: if a single positional argument is not a geometry.Geometry.
      Errors raised by pointcloud.PointCloud for other inputs propagate.
  """
  if len(args) == 1:
    candidate = args[0]
    if not isinstance(candidate, geometry.Geometry):
      raise ValueError(
          'A transport problem must be defined by either a '
          'single geometry, or two arrays.')
    self.geom = candidate
  else:
    # Every point-cloud option is popped, even when None, so that none of
    # them reach the sinkhorn kwargs stored below.
    popped = {key: kwargs.pop(key, None)
              for key in ('epsilon', 'cost_fn', 'power', 'online')}
    pc_kw = {key: val for key, val in popped.items() if val is not None}
    self.geom = pointcloud.PointCloud(*args, **pc_kw)

  num_a, num_b = self.geom.shape
  self.a = jnp.ones((num_a,)) / num_a if a is None else a
  self.b = jnp.ones((num_b,)) / num_b if b is None else b
  self._f = None
  self._g = None
  self._kwargs = kwargs
  self.reg_ot_cost = None
  self.solve()
def test_restart(self, lse_mode):
  """Warm-starting Sinkhorn from its own solution converges immediately.

  A first solve must reach the threshold. Restarting from its output (as
  potentials in lse mode, as scalings otherwise) must improve the error,
  start at least as well as the previous best, and stop after one recorded
  iteration.
  """
  threshold = 1e-4
  geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.01)
  solver_options = {
      'a': self.a,
      'b': self.b,
      'threshold': threshold,
      'lse_mode': lse_mode,
      'inner_iterations': 1,
  }
  out = sinkhorn.sinkhorn(geom, **solver_options)
  err = out.errors[out.errors > -1][-1]
  self.assertGreater(threshold, err)

  # Recover the solution from the previous run in the form the solver expects.
  init_dual_a, init_dual_b = (
      potential if lse_mode else geom.scaling_from_potential(potential)
      for potential in (out.f, out.g))
  out_restarted = sinkhorn.sinkhorn(
      geom, init_dual_a=init_dual_a, init_dual_b=init_dual_b,
      **solver_options)
  errors_restarted = out_restarted.errors
  recorded = errors_restarted > -1
  err_restarted = errors_restarted[recorded][-1]
  self.assertGreater(threshold, err_restarted)
  num_iter_restarted = jnp.sum(recorded)
  # The restart may only improve on the error.
  self.assertGreater(err, err_restarted)
  # Its first error is at least as good as the previous best.
  self.assertGreater(err, errors_restarted[0])
  # One iteration suffices when restarting on the same data.
  self.assertEqual(num_iter_restarted, 1)
def test_apply_cost_and_kernel(self):
  """Test consistency of cost/kernel apply to vec.

  Online and materialized PointCloud geometries, plus an explicit squared
  Euclidean cost matrix, must agree when applying the cost (power 1 and 2)
  and the kernel (power 2) along both axes.
  """
  n, m, p, b = 5, 8, 10, 7
  keys = jax.random.split(self.rng, 5)
  x = jax.random.normal(keys[0], (n, p))
  y = jax.random.normal(keys[1], (m, p)) + 1
  cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
  vec0 = jax.random.normal(keys[2], (n, b))
  vec1 = jax.random.normal(keys[3], (m, b))
  tolerances = {'rtol': 1e-03, 'atol': 1e-02}

  def cost_products(geom):
    return geom.apply_cost(vec0, axis=0), geom.apply_cost(vec1, axis=1)

  def kernel_products(geom):
    return (geom.apply_kernel(vec0, 1., axis=0),
            geom.apply_kernel(vec1, 1., axis=1))

  def assert_pairwise_close(first, second):
    for lhs, rhs in zip(first, second):
      self.assertAllClose(lhs, rhs, **tolerances)

  # power=2: online vs. materialized vs. explicit cost matrix.
  online = cost_products(pointcloud.PointCloud(x, y, power=2, online=True))
  batch = cost_products(pointcloud.PointCloud(x, y, power=2, online=False))
  dense = cost_products(geometry.Geometry(cost))
  assert_pairwise_close(online, batch)
  assert_pairwise_close(dense, batch)

  # power=1: online vs. materialized.
  online = cost_products(pointcloud.PointCloud(x, y, power=1, online=True))
  batch = cost_products(pointcloud.PointCloud(x, y, power=1, online=False))
  assert_pairwise_close(online, batch)

  # Kernel application with power=2: online vs. materialized.
  online = kernel_products(pointcloud.PointCloud(x, y, power=2, online=True))
  batch = kernel_products(
      pointcloud.PointCloud(x, y, power=2, online=False))
  assert_pairwise_close(online, batch)