def test_geom_vs_point_cloud(self):
    """Checks a point cloud and its explicit cost matrix yield the same f.

    Sinkhorn is run once on a PointCloud and once on a Geometry built from
    that point cloud's cost matrix; the resulting `f` potentials must agree.
    """
    cloud_geom = pointcloud.PointCloud(self.x, self.y)
    matrix_geom = geometry.Geometry(cloud_geom.cost_matrix)
    # Solve the same problem on both geometries, in the same order.
    potentials = [
        sinkhorn.sinkhorn(g, a=self.a, b=self.b).f
        for g in (cloud_geom, matrix_geom)
    ]
    self.assertAllClose(potentials[0], potentials[1])
def test_online_vs_batch_euclidean_point_cloud(self, lse_mode):
    """Compares online and batch point clouds, with and without explicit cost.

    Four geometries are solved with the same marginals and threshold: online
    and batch, each with the default cost and with an explicit
    costs.Euclidean(). Regularized costs and transport matrices are compared.
    """
    threshold = 1e-3
    eps = 0.1

    def plan(geom, out):
        # Transport matrix rebuilt from the dual potentials of a solution.
        return geom.transport_from_potentials(out.f, out.g)

    def solve(geom):
        return sinkhorn.sinkhorn(
            geom, a=self.a, b=self.b, threshold=threshold, lse_mode=lse_mode)

    online_geom = pointcloud.PointCloud(
        self.x, self.y, epsilon=eps, online=True)
    online_geom_euc = pointcloud.PointCloud(
        self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps, online=True)
    batch_geom = pointcloud.PointCloud(self.x, self.y, epsilon=eps)
    batch_geom_euc = pointcloud.PointCloud(
        self.x, self.y, cost_fn=costs.Euclidean(), epsilon=eps)

    out_online = solve(online_geom)
    out_batch = solve(batch_geom)
    out_online_euc = solve(online_geom_euc)
    out_batch_euc = solve(batch_geom_euc)

    # Regularized transport costs must match between online and batch.
    self.assertAllClose(out_online.reg_ot_cost, out_batch.reg_ot_cost)

    # Regularized transport matrices must match pairwise.
    self.assertAllClose(
        plan(online_geom, out_online), plan(batch_geom, out_batch))
    self.assertAllClose(
        plan(online_geom_euc, out_online_euc),
        plan(batch_geom_euc, out_batch_euc))
    self.assertAllClose(
        plan(batch_geom, out_batch), plan(batch_geom_euc, out_batch_euc))
def sinkhorn_for_sort(inputs: jnp.ndarray,
                      weights: jnp.ndarray,
                      target_weights: jnp.ndarray,
                      sinkhorn_kw,
                      pointcloud_kw) -> jnp.ndarray:
    """Runs sinkhorn on a fixed increasing target.

    Args:
      inputs: jnp.ndarray[num_points]. Must be one dimensional.
      weights: jnp.ndarray[num_points]. The weights 'a' for the inputs.
      target_weights: jnp.ndarray[num_targets]: the weights of the targets. It
        may be of a different size than the weights.
      sinkhorn_kw: a dictionary holding the sinkhorn keyword arguments. See
        sinkhorn.py for more details.
      pointcloud_kw: a dictionary holding the keyword arguments of the
        PointCloud class. See pointcloud.py for more details.

    Returns:
      A jnp.ndarray<float> representing the transport matrix of the inputs
      onto the underlying sorted target.

    Raises:
      ValueError: if `inputs` has more than two dimensions, or has two
        dimensions with a second dimension different from 1.
    """
    shape = inputs.shape
    if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1):
        # f-string so that the offending shape actually shows in the message.
        raise ValueError(
            f"Shape ({shape}) not supported. The input should be "
            "one-dimensional.")

    # Column vector of inputs, standardized then squashed by a sigmoid. The
    # small constant guards against a zero standard deviation.
    x = jnp.expand_dims(jnp.squeeze(inputs), axis=1)
    x = jax.nn.sigmoid((x - jnp.mean(x)) / (jnp.std(x) + 1e-10))
    a = jnp.squeeze(weights)
    b = jnp.squeeze(target_weights)
    num_targets = b.shape[0]
    # Fixed, increasing targets evenly spaced on [0, 1].
    y = jnp.linspace(0.0, 1.0, num_targets)[:, jnp.newaxis]
    geom = pointcloud.PointCloud(x, y, **pointcloud_kw)
    res = sinkhorn.sinkhorn(geom, a, b, **sinkhorn_kw)
    return geom.transport_from_potentials(res.f, res.g)
def reg_ot(x):
    """Returns the regularized OT cost on a Grid built from `x`.

    `a`, `b` and `lse_mode` are free variables resolved outside this function.
    """
    grid_geom = grid.Grid(x=x, epsilon=1.0)
    solution = sinkhorn.sinkhorn(
        grid_geom, a=a, b=b, threshold=0.1, lse_mode=lse_mode)
    return solution.reg_ot_cost
def solve(self):
    """Solves the transport problem with sinkhorn and stores the result.

    The potentials are written to `self._f` and `self._g`, the regularized
    cost to `self.reg_ot_cost`.
    """
    result = sinkhorn.sinkhorn(self.geom, self.a, self.b, **self._kwargs)
    # TODO(oliviert): figure out how to warn the user if no convergence.
    # So far we always set the values, even if not converged.
    # TODO(oliviert, cuturi): handles cases where it has not converged.
    self._f, self._g = result.f, result.g
    self.reg_ot_cost = result.reg_ot_cost
def loss_fn(x, y):
    """Regularized transport cost between `x` and `y`, plus auxiliary data.

    `a`, `b`, `epsilon`, `lse_mode` and `implicit_differentiation` are free
    variables resolved outside this function.

    Returns:
      A pair (cost, (geom, f, g)).
    """
    geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
    solution = sinkhorn.sinkhorn(
        geom,
        a,
        b,
        lse_mode=lse_mode,
        implicit_differentiation=implicit_differentiation)
    # The solver output unpacks into exactly five fields; the last two are
    # not needed here.
    pot_f, pot_g, reg_cost, _, _ = solution
    return reg_cost, (geom, pot_f, pot_g)
def test_online_euclidean_point_cloud(self, lse_mode):
    """Checks sinkhorn reaches the threshold on an online point cloud."""
    tol = 1e-3
    online_geom = pointcloud.PointCloud(
        self.x, self.y, epsilon=0.1, online=True)
    solution = sinkhorn.sinkhorn(
        online_geom, a=self.a, b=self.b, threshold=tol, lse_mode=lse_mode)
    recorded = solution.errors
    # Entries equal to -1 are skipped; the last remaining one is checked.
    final_error = recorded[recorded > -1][-1]
    self.assertGreater(tol, final_error)
def loss_fn(x, y):
    """Regularized transport cost between `x` and `y` with epsilon=0.01.

    `a`, `b`, `momentum_strategy` and `lse_mode` are free variables resolved
    outside this function.

    Returns:
      A pair (cost, (geom, f, g)).
    """
    geom = pointcloud.PointCloud(x, y, epsilon=0.01)
    solution = sinkhorn.sinkhorn(
        geom, a, b, momentum_strategy=momentum_strategy, lse_mode=lse_mode)
    # The solver output unpacks into exactly five fields; the last two are
    # not needed here.
    pot_f, pot_g, reg_cost, _, _ = solution
    return reg_cost, (geom, pot_f, pot_g)
def _sinkhorn_divergence(geometry_xy: geometry.Geometry,
                         geometry_xx: geometry.Geometry,
                         geometry_yy: Optional[geometry.Geometry],
                         a: jnp.ndarray,
                         b: jnp.ndarray,
                         **kwargs):
    """Computes the (unbalanced) sinkhorn divergence for the wrapper function.

    The value includes a correction depending on the total masses of each
    measure, as defined in https://arxiv.org/pdf/1910.12958.pdf (15).

    Args:
      geometry_xy: a Cost object able to apply kernels with a certain epsilon,
        between the views X and Y.
      geometry_xx: a Cost object able to apply kernels with a certain epsilon,
        between elements of the view X.
      geometry_yy: a Cost object able to apply kernels with a certain epsilon,
        between elements of the view Y. When None, the (y, y) term is taken to
        have zero cost and no solver is run for it.
      a: jnp.ndarray<float>[n]: the weight of each input point. The sum of all
        elements of b must match that of a to converge.
      b: jnp.ndarray<float>[m]: the weight of each target point. The sum of
        all elements of b must match that of a to converge.
      **kwargs: Arguments to sinkhorn.

    Returns:
      SinkhornDivergenceOutput named tuple.
    """
    # The symmetric (x, x) and (y, y) problems override the parallel/momentum
    # arguments supplied by the caller.
    kwargs_symmetric = {
        **kwargs,
        'parallel_dual_updates': True,
        'momentum_strategy': 0.5,
    }

    out_xy = sinkhorn.sinkhorn(geometry_xy, a, b, **kwargs)
    out_xx = sinkhorn.sinkhorn(geometry_xx, a, a, **kwargs_symmetric)
    out_yy = (
        sinkhorn.SinkhornOutput(None, None, 0, None, None)
        if geometry_yy is None
        else sinkhorn.sinkhorn(geometry_yy, b, b, **kwargs_symmetric))

    self_terms = 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost)
    mass_correction = (
        0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b))**2)
    div = out_xy.reg_ot_cost - self_terms + mass_correction

    outputs = (out_xy, out_xx, out_yy)
    potentials = tuple([s.f, s.g] for s in outputs)
    errors = tuple(s.errors for s in outputs)
    converged = tuple(s.converged for s in outputs)
    return SinkhornDivergenceOutput(
        div, potentials, (geometry_xy, geometry_xx, geometry_yy), errors,
        converged)
def test_grid_vs_euclidean(self, lse_mode):
    """Compares the cost on a Grid with that on an explicit point cloud.

    The point cloud is built from the coordinates of a 5 x 6 x 7 lattice,
    each axis divided by (size - 1), and solved with the same marginals.
    """
    grid_size = (5, 6, 7)
    keys = jax.random.split(self.rng, 2)
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)
    a, b = (w.ravel() / jnp.sum(w) for w in (a, b))
    epsilon = 0.1
    geometry_grid = grid.Grid(grid_size=grid_size, epsilon=epsilon)

    # One coordinate array per axis, flattened and rescaled.
    axes = np.mgrid[tuple(slice(0, size) for size in grid_size)]
    columns = [
        jnp.array(axis.ravel()) / jnp.maximum(1, size - 1)
        for axis, size in zip(axes, grid_size)
    ]
    xyz = jnp.stack(columns).transpose()
    geometry_mat = pointcloud.PointCloud(xyz, xyz, epsilon=epsilon)

    out_mat = sinkhorn.sinkhorn(geometry_mat, a=a, b=b, lse_mode=lse_mode)
    out_grid = sinkhorn.sinkhorn(geometry_grid, a=a, b=b, lse_mode=lse_mode)
    self.assertAllClose(out_mat.reg_ot_cost, out_grid.reg_ot_cost)
def test_apply_transport_geometry_from_scalings(self):
    """Applying transport matrix P on vector without instantiating P.

    For every combination of lse_mode and online, the result of
    apply_transport_from_scalings is compared with an explicit product
    against the instantiated transport matrix, then all four combinations
    are compared with each other.
    """
    n, m, d = 160, 230, 6
    keys = jax.random.split(self.rng, 6)
    x = jax.random.uniform(keys[0], (n, d))
    y = jax.random.uniform(keys[1], (m, d))
    a = jax.random.uniform(keys[2], (n,))
    b = jax.random.uniform(keys[3], (m,))
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)
    transport_t_vec_a = [None, None, None, None]
    transport_vec_b = [None, None, None, None]

    batch_b = 8

    vec_a = jax.random.normal(keys[4], (n,))
    vec_b = jax.random.normal(keys[5], (batch_b, m))

    # test with lse_mode and online = True / False
    for j, lse_mode in enumerate([True, False]):
        for i, online in enumerate([True, False]):
            slot = i + 2 * j
            geom = pointcloud.PointCloud(x, y, online=online, epsilon=0.2)
            sink = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode)

            u = geom.scaling_from_potential(sink.f)
            v = geom.scaling_from_potential(sink.g)

            transport_t_vec_a[slot] = geom.apply_transport_from_scalings(
                u, v, vec_a, axis=0)
            transport_vec_b[slot] = geom.apply_transport_from_scalings(
                u, v, vec_b, axis=1)

            transport = geom.transport_from_scalings(u, v)

            self.assertAllClose(
                transport_t_vec_a[slot],
                jnp.dot(transport.T, vec_a).T,
                rtol=1e-3,
                atol=1e-3)
            self.assertAllClose(
                transport_vec_b[slot],
                jnp.dot(transport, vec_b.T).T,
                rtol=1e-3,
                atol=1e-3)
            # assertIsNot(array, True) can never fail since an array is never
            # the True singleton; test the truth value instead.
            self.assertFalse(jnp.any(jnp.isnan(transport_t_vec_a[slot])))
            self.assertFalse(jnp.any(jnp.isnan(transport_vec_b[slot])))

    for i in range(4):
        self.assertAllClose(
            transport_vec_b[i], transport_vec_b[0], rtol=1e-3, atol=1e-3)
        self.assertAllClose(
            transport_t_vec_a[i], transport_t_vec_a[0], rtol=1e-3, atol=1e-3)
def test_apply_transport_grid(self, lse_mode):
    """Compares applying the transport of a Grid and of a PointCloud.

    The point cloud is built from the coordinates of a 5 x 6 x 7 lattice,
    each axis divided by (size - 1). Both problems are solved with the same
    marginals and the transports are applied to the same random vectors
    along both axes.
    """
    grid_size = (5, 6, 7)
    # Four independent keys: a, b, vec_a and vec_b. The previous code split
    # three keys, indexed keys[4], and reused it for both vectors.
    keys = jax.random.split(self.rng, 4)
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)
    a = a.ravel() / jnp.sum(a)
    b = b.ravel() / jnp.sum(b)
    geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1)
    x, y, z = np.mgrid[0:grid_size[0], 0:grid_size[1], 0:grid_size[2]]
    xyz = jnp.stack([
        jnp.array(x.ravel()) / jnp.maximum(1, grid_size[0] - 1),
        jnp.array(y.ravel()) / jnp.maximum(1, grid_size[1] - 1),
        jnp.array(z.ravel()) / jnp.maximum(1, grid_size[2] - 1),
    ]).transpose()
    geom_mat = pointcloud.PointCloud(xyz, xyz, epsilon=0.1)
    sink_mat = sinkhorn.sinkhorn(geom_mat, a=a, b=b, lse_mode=lse_mode)
    sink_grid = sinkhorn.sinkhorn(geom_grid, a=a, b=b, lse_mode=lse_mode)

    batch_a = 3
    batch_b = 4
    num_points = int(np.prod(grid_size))
    vec_a = jax.random.normal(keys[2], [batch_a, num_points])
    vec_b = jax.random.normal(keys[3], [batch_b, num_points])

    vec_a = vec_a / jnp.sum(vec_a, axis=1)[:, jnp.newaxis]
    vec_b = vec_b / jnp.sum(vec_b, axis=1)[:, jnp.newaxis]

    mat_transport_t_vec_a = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_a, axis=0)
    mat_transport_vec_b = geom_mat.apply_transport_from_potentials(
        sink_mat.f, sink_mat.g, vec_b, axis=1)

    grid_transport_t_vec_a = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_a, axis=0)
    grid_transport_vec_b = geom_grid.apply_transport_from_potentials(
        sink_grid.f, sink_grid.g, vec_b, axis=1)

    self.assertAllClose(mat_transport_t_vec_a, grid_transport_t_vec_a)
    self.assertAllClose(mat_transport_vec_b, grid_transport_vec_b)
    # assertIsNot(array, True) can never fail since an array is never the
    # True singleton; test the truth value instead.
    self.assertFalse(jnp.any(jnp.isnan(mat_transport_t_vec_a)))
def test_euclidean_point_cloud(self, lse_mode):
    """Checks convergence on an online point cloud with implicit diff on."""
    tol = 1e-1
    online_geom = pointcloud.PointCloud(
        self.x, self.y, epsilon=1, online=True)
    solution = sinkhorn.sinkhorn(
        online_geom,
        a=self.a,
        b=self.b,
        threshold=tol,
        lse_mode=lse_mode,
        implicit_differentiation=True)
    recorded = solution.errors
    # Entries equal to -1 are skipped; the last remaining one is checked.
    final_error = recorded[recorded > -1][-1]
    self.assertGreater(tol, final_error)
def loss_pcg(a, x, implicit=True):
    """Regularized OT cost from weights `a` and points `x` on a point cloud.

    The problem is solved with tau_a=1.0 and tau_b=0.95. `self`, `epsilon`
    and `lse_mode` are free variables resolved outside this function.

    Args:
      a: weights of the first marginal.
      x: points of the first point cloud.
      implicit: forwarded as `implicit_differentiation`.

    Returns:
      The `reg_ot_cost` of the solver output.
    """
    geom = pointcloud.PointCloud(x, self.y, epsilon=epsilon)
    solution = sinkhorn.sinkhorn(
        geom,
        a=a,
        b=self.b,
        tau_a=1.0,
        tau_b=0.95,
        threshold=1e-4,
        lse_mode=lse_mode,
        implicit_differentiation=implicit)
    return solution.reg_ot_cost
def test_euclidean_point_cloud_min_iter(self):
    """Testing the min_iterations parameter.

    With min_iterations=34 the first three recorded errors are expected to be
    inf and the fourth to be positive; the last valid error must be below the
    threshold.
    """
    tol = 1e-3
    geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
    errors = sinkhorn.sinkhorn(
        geom,
        a=self.a,
        b=self.b,
        threshold=tol,
        min_iterations=34,
        implicit_differentiation=False).errors
    # Keep only entries that are neither -1 nor infinite.
    valid = jnp.logical_and(errors > -1, jnp.isfinite(errors))
    self.assertGreater(tol, errors[valid][-1])
    for position in range(3):
        self.assertEqual(jnp.inf, errors[position])
    self.assertGreater(errors[3], 0)
def test_restart(self, lse_mode):
    """Checks that restarting from a previous solution converges at once.

    A first solve provides dual variables; a second solve initialized with
    them must reach the threshold after a single recorded iteration, with an
    error strictly below the first run's final error.
    """
    tol = 1e-4
    geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.01)
    first = sinkhorn.sinkhorn(
        geom,
        a=self.a,
        b=self.b,
        threshold=tol,
        lse_mode=lse_mode,
        inner_iterations=1)
    first_errors = first.errors
    err = first_errors[first_errors > -1][-1]
    self.assertGreater(tol, err)

    # Recover the solution from the first run: potentials in lse mode,
    # scalings otherwise.
    if lse_mode:
        init_dual_a = first.f
        init_dual_b = first.g
    else:
        init_dual_a = geom.scaling_from_potential(first.f)
        init_dual_b = geom.scaling_from_potential(first.g)

    restarted = sinkhorn.sinkhorn(
        geom,
        a=self.a,
        b=self.b,
        threshold=tol,
        lse_mode=lse_mode,
        init_dual_a=init_dual_a,
        init_dual_b=init_dual_b,
        inner_iterations=1)
    errors_restarted = restarted.errors
    err_restarted = errors_restarted[errors_restarted > -1][-1]
    self.assertGreater(tol, err_restarted)

    num_iter_restarted = jnp.sum(errors_restarted > -1)
    # The restart can only improve on the error.
    self.assertGreater(err, err_restarted)
    # The first error of the restart beats the previous best.
    self.assertGreater(err, errors_restarted[0])
    # A single iteration suffices when restarting with the same data.
    self.assertEqual(num_iter_restarted, 1)
def test_bures_point_cloud(self, lse_mode, online):
    """Two point clouds of Gaussians, tested with various parameters.

    Checks that the last recorded error falls below the threshold when the
    point cloud uses the Bures cost.
    """
    threshold = 1e-3
    geom = pointcloud.PointCloud(
        self.x,
        self.y,
        cost_fn=costs.Bures(dimension=self.dim, regularization=1e-4),
        online=online,
        epsilon=self.eps)
    # Pass the threshold explicitly so the assertion below checks the value
    # the solver was asked to reach, not whatever its default happens to be.
    errors = sinkhorn.sinkhorn(
        geom,
        a=self.a,
        b=self.b,
        threshold=threshold,
        lse_mode=lse_mode).errors
    # Entries equal to -1 are skipped; the last remaining one is checked.
    err = errors[errors > -1][-1]
    self.assertGreater(threshold, err)
def loss_g(a, x, implicit=True):
    """Regularized OT cost from weights `a` and points `x` on a cost matrix.

    The cost matrix is assembled from squared norms and a dot product, then
    solved with tau_a=0.8 and tau_b=0.87. `self`, `epsilon` and `lse_mode`
    are free variables resolved outside this function.

    Args:
      a: weights of the first marginal.
      x: points from which the cost matrix is computed.
      implicit: forwarded as `implicit_differentiation`.

    Returns:
      The `reg_ot_cost` of the solver output.
    """
    sq_norm_x = jnp.sum(x**2, axis=1)[:, jnp.newaxis]
    sq_norm_y = jnp.sum(self.y**2, axis=1)[jnp.newaxis, :]
    cost_matrix = sq_norm_x + sq_norm_y - 2 * jnp.dot(x, self.y.T)
    geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
    solution = sinkhorn.sinkhorn(
        geom,
        a=a,
        b=self.b,
        tau_a=0.8,
        tau_b=0.87,
        threshold=1e-4,
        lse_mode=lse_mode,
        implicit_differentiation=implicit)
    return solution.reg_ot_cost
def potential(a, x, implicit, d):
    """Returns the sum of the `f` potential weighted elementwise by `d`.

    The cost matrix is assembled from squared norms and a dot product. `y`,
    `b`, `epsilon`, `tau_a`, `tau_b` and `lse_mode` are free variables
    resolved outside this function.

    Args:
      a: weights of the first marginal.
      x: points from which the cost matrix is computed.
      implicit: forwarded as `implicit_differentiation`.
      d: array multiplied elementwise with the potential `f`.

    Returns:
      jnp.sum(f * d).
    """
    sq_norm_x = jnp.sum(x**2, axis=1)[:, jnp.newaxis]
    sq_norm_y = jnp.sum(y**2, axis=1)[jnp.newaxis, :]
    cost_matrix = sq_norm_x + sq_norm_y - 2 * jnp.dot(x, y.T)
    geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
    solution = sinkhorn.sinkhorn(
        geom,
        a=a,
        b=b,
        tau_a=tau_a,
        tau_b=tau_b,
        lse_mode=lse_mode,
        threshold=1e-4,
        implicit_differentiation=implicit,
        inner_iterations=2)
    return jnp.sum(solution.f * d)
def test_euclidean_point_cloud(self, lse_mode, momentum_strategy,
                               inner_iterations, norm_error):
    """Checks convergence on a point cloud for the given solver options."""
    tol = 1e-3
    geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1)
    solver_options = dict(
        a=self.a,
        b=self.b,
        threshold=tol,
        momentum_strategy=momentum_strategy,
        inner_iterations=inner_iterations,
        norm_error=norm_error,
        lse_mode=lse_mode)
    recorded = sinkhorn.sinkhorn(geom, **solver_options).errors
    # Entries equal to -1 are skipped; the last remaining one is checked.
    final_error = recorded[recorded > -1][-1]
    self.assertGreater(tol, final_error)
def solve(self):
    """Solves the transport problem with sinkhorn and stores the result.

    A warning is logged when the solver reports no convergence. The
    potentials are written to `self._f` and `self._g`, the regularized cost
    to `self.reg_ot_cost`, whether or not the solver converged.
    """
    result = sinkhorn.sinkhorn(self.geom, self.a, self.b, **self._kwargs)
    has_converged = result.converged
    if not has_converged:
        # TODO(oliviert): Point to the online doc when available.
        logging.warning(
            'Sinkhorn has not converged. Please check your setup and '
            ' consider increasing max_iterations or epsilon. For more'
            ' details, see ott.core.sinkhorn.sinkhorn.')

    # So far we always set the values, even if not converged.
    # TODO(oliviert, cuturi): handles cases where it has not converged.
    self._f, self._g = result.f, result.g
    self.reg_ot_cost = result.reg_ot_cost
def test_euclidean_point_cloud_parallel_weights(self, lse_mode):
    """Two point clouds, parallel execution for batched histograms.

    A batch of marginals is drawn and the solver is mapped over the batch
    with jax.vmap; every problem in the batch must reach the threshold.
    """
    # Three-way split: one key carried forward, one per batch of weights.
    # The previous code drew both batches from the same key.
    self.rng, *rngs = jax.random.split(self.rng, 3)
    batch = 4
    a = jax.random.uniform(rngs[0], (batch, self.n))
    b = jax.random.uniform(rngs[1], (batch, self.m))
    a = a / jnp.sum(a, axis=1)[:, jnp.newaxis]
    b = b / jnp.sum(b, axis=1)[:, jnp.newaxis]
    threshold = 1e-3
    geom = pointcloud.PointCloud(self.x, self.y, epsilon=0.1, online=True)

    def errors_for(a_row, b_row):
        return sinkhorn.sinkhorn(
            geom, a=a_row, b=b_row, threshold=threshold,
            lse_mode=lse_mode).errors

    # Actually solve the batched problems; the previous code built `a` and
    # `b` but then solved with self.a / self.b, leaving them unused.
    errors = jax.vmap(errors_for)(a, b)
    # Last recorded error of each problem, skipping entries equal to -1.
    last_errors = jnp.array([row[row > -1][-1] for row in errors])
    self.assertGreater(jnp.min(threshold - last_errors), 0)
def test_jit_vs_non_jit_fwd(self):
    """Compares sinkhorn.sinkhorn with the non-jitted solver, forward pass.

    The reference result is compared against the non-jitted solver called
    directly and against the same solver wrapped in jax.jit by the caller.
    """
    inputs = (self.geometry, self.a, self.b)
    jitted_result = sinkhorn.sinkhorn(*inputs)
    non_jitted_result = non_jitted_sinkhorn(*inputs)

    user_jitted = jax.jit(lambda g, a, b: non_jitted_sinkhorn(g, a, b))
    user_jitted_result = user_jitted(*inputs)

    for candidate in (non_jitted_result, user_jitted_result):
        chex.assert_tree_all_close(
            jitted_result, candidate, atol=1e-6, rtol=0)
def test_separable_grid(self, lse_mode):
    """Two histograms in a grid of size 5 x 6 x 7 in the hypercube^3.

    Some weights are set to zero before normalization to exercise the
    handling of zero weights; the last valid recorded error must be below
    the threshold.
    """
    grid_size = (5, 6, 7)
    keys = jax.random.split(self.rng, 2)
    a = jax.random.uniform(keys[0], grid_size)
    b = jax.random.uniform(keys[1], grid_size)
    # adding zero weights to test proper handling, then ravel.
    # `.at[...].set` replaces jax.ops.index_update, which JAX has removed.
    a = a.at[0].set(0).ravel()
    a = a / jnp.sum(a)
    b = b.at[3].set(0).ravel()
    b = b / jnp.sum(b)

    threshold = 0.01
    geom = grid.Grid(grid_size=grid_size, epsilon=0.1)
    errors = sinkhorn.sinkhorn(
        geom, a=a, b=b, threshold=threshold, lse_mode=lse_mode).errors
    # Skip the -1 entries as well as non-finite ones: -1 is finite, so a mask
    # on isfinite alone would select it and make the assertion trivially true.
    valid = jnp.logical_and(errors > -1, jnp.isfinite(errors))
    err = errors[valid][-1]
    self.assertGreater(threshold, err)
def _sinkhorn_divergence(geometry_xy: geometry.Geometry,
                         geometry_xx: geometry.Geometry,
                         geometry_yy: Optional[geometry.Geometry],
                         a: jnp.ndarray,
                         b: jnp.ndarray,
                         **kwargs):
    """Computes the (unbalanced) sinkhorn divergence for the wrapper function.

    The value includes a correction depending on the total masses of each
    measure, as defined in https://arxiv.org/pdf/1910.12958.pdf (15).

    Args:
      geometry_xy: a Cost object able to apply kernels with a certain epsilon,
        between the views X and Y.
      geometry_xx: a Cost object able to apply kernels with a certain epsilon,
        between elements of the view X.
      geometry_yy: a Cost object able to apply kernels with a certain epsilon,
        between elements of the view Y.
      a: jnp.ndarray<float>[n]: the weight of each input point. The sum of all
        elements of b must match that of a to converge.
      b: jnp.ndarray<float>[m]: the weight of each target point. The sum of
        all elements of b must match that of a to converge.
      **kwargs: Arguments forwarded to every sinkhorn call.

    Returns:
      SinkhornDivergenceOutput named tuple.
    """
    geoms = (geometry_xy, geometry_xx, geometry_yy)
    marginal_pairs = ((a, b), (a, a), (b, b))

    out = []
    for geom, (first, second) in zip(geoms, marginal_pairs):
        if geom is None:
            # A missing geometry contributes a zero-cost placeholder output.
            solved = sinkhorn.SinkhornOutput(None, None, 0, None, None)
        else:
            solved = sinkhorn.sinkhorn(geom, first, second, **kwargs)
        out.append(solved)

    out_xy, out_xx, out_yy = out
    self_terms = 0.5 * (out_xx.reg_ot_cost + out_yy.reg_ot_cost)
    mass_correction = (
        0.5 * geometry_xy.epsilon * (jnp.sum(a) - jnp.sum(b))**2)
    div = out_xy.reg_ot_cost - self_terms + mass_correction

    potentials = tuple([s.f, s.g] for s in out)
    errors = tuple(s.errors for s in out)
    converged = tuple(s.converged for s in out)
    return SinkhornDivergenceOutput(div, potentials, geoms, errors, converged)
def loss_fn(cm):
    """Regularized OT cost of cost matrix `cm` under uniform marginals.

    `lse_mode` is a free variable resolved outside this function.

    Returns:
      A pair (cost, (geom, f, g)).
    """
    num_rows = cm.shape[0]
    num_cols = cm.shape[1]
    # Uniform weights on both sides.
    a = jnp.ones(num_rows) / num_rows
    b = jnp.ones(num_cols) / num_cols
    geom = geometry.Geometry(cm, epsilon=0.5)
    solution = sinkhorn.sinkhorn(geom, a, b, lse_mode=lse_mode)
    aux = (geom, solution.f, solution.g)
    return solution.reg_ot_cost, aux
def reg_ot(a, b):
    """Regularized OT cost between `a` and `b`, implicit differentiation off.

    `geom` and `lse_mode` are free variables resolved outside this function.
    """
    solution = sinkhorn.sinkhorn(
        geom,
        a=a,
        b=b,
        threshold=0.1,
        lse_mode=lse_mode,
        implicit_differentiation=False)
    return solution.reg_ot_cost
def _gw_iterations(
    geom_x: geometry.Geometry,
    geom_y: geometry.Geometry,
    a: jnp.ndarray,
    b: jnp.ndarray,
    epsilon: Union[epsilon_scheduler.Epsilon, float],
    loss: GWLoss,
    max_iterations: int,
    warm_start: bool,
    sinkhorn_kwargs: Optional[Dict[str, Any]],
    **kwargs
) -> Tuple[jnp.ndarray, jnp.ndarray, geometry.Geometry, jnp.ndarray,
           jnp.ndarray, jnp.ndarray]:
    """Fits Gromov Wasserstein.

    Args:
      geom_x: a Geometry object for the first view.
      geom_y: a second Geometry object for the second view.
      a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a]
        weights.
      b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b]
        weights.
      epsilon: a regularization parameter or a epsilon_scheduler.Epsilon
        object.
      loss: GWLoss object.
      max_iterations: int, the maximum number of outer iterations for
        Gromov Wasserstein.
      warm_start: bool, optional initialisation of the potentials/scalings
        w.r.t. first and second marginals between each call to sinkhorn.
      sinkhorn_kwargs: Optionally a dictionary containing the keywords
        arguments for calls to the sinkhorn function. None is treated as an
        empty dictionary.
      **kwargs: additional kwargs for epsilon.

    Returns:
      f: potential.
      g: potential.
      geom_gw: a Geometry object for Gromov-Wasserstein (GW).
      reg_gw_cost_arr: ndarray of regularised GW costs.
      errors_sinkhorn: ndarray [max_iterations, p], where p depends on
        sinkhorn_kwargs, of errors for the Sinkhorn algorithm for each gromov
        iteration (axis 0) and regularly spaced sinkhorn iterations (axis 1).
      converged_sinkhorn: ndarray [max_iterations,] of flags indicating that
        the sinkhorn algorithm converged.
    """
    # The argument is declared Optional: without this guard, None would fail
    # on both `.get` and `**` unpacking below.
    sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs
    lse_mode = sinkhorn_kwargs.get('lse_mode', True)

    # The initial geometry is built from the weights with gradients stopped.
    geom_gw = _init_geometry_gw(
        geom_x, geom_y, jax.lax.stop_gradient(a), jax.lax.stop_gradient(b),
        epsilon, loss, **kwargs)
    f, g, reg_gw_cost, errors_sinkhorn, converged_sinkhorn = sinkhorn.sinkhorn(
        geom_gw, a, b, **sinkhorn_kwargs)
    carry = geom_gw, f, g
    update_geom_partial = functools.partial(
        _update_geometry_gw, geom_x=geom_x, geom_y=geom_y, loss=loss, **kwargs)
    sinkhorn_partial = functools.partial(
        sinkhorn.sinkhorn, a=a, b=b, **sinkhorn_kwargs)

    def body_fn(carry=carry, x=None):
        del x
        geom_gw, f, g = carry
        geom_gw = update_geom_partial(geom=geom_gw, f=f, g=g)
        # With warm start, the previous solution initializes the next solve:
        # potentials in lse mode, scalings otherwise.
        init_dual_a = (
            (f if lse_mode else geom_gw.scaling_from_potential(f))
            if warm_start else None)
        init_dual_b = (
            (g if lse_mode else geom_gw.scaling_from_potential(g))
            if warm_start else None)
        f, g, reg_gw_cost, errors_sinkhorn, converged_sinkhorn = (
            sinkhorn_partial(
                geom=geom_gw, init_dual_a=init_dual_a,
                init_dual_b=init_dual_b))
        return (geom_gw, f, g), (
            reg_gw_cost, errors_sinkhorn, converged_sinkhorn)

    # The first outer iteration was run above, hence max_iterations - 1 here.
    carry, out = jax.lax.scan(
        f=body_fn, init=carry, xs=None, length=max_iterations - 1)

    geom_gw, f, g = carry
    # Prepend the results of the first iteration to the scanned ones.
    reg_gw_cost_arr = jnp.concatenate((jnp.array([reg_gw_cost]), out[0]))
    errors_sinkhorn = jnp.concatenate((jnp.array([errors_sinkhorn]), out[1]))
    converged_sinkhorn = jnp.concatenate(
        (jnp.array([converged_sinkhorn]), out[2]))
    return (f, g, geom_gw, reg_gw_cost_arr, errors_sinkhorn,
            converged_sinkhorn)
def reg_ot(a, b):
    """Regularized OT cost between `a` and `b`, solved to threshold 0.001.

    `geom` and `lse_mode` are free variables resolved outside this function.
    """
    solution = sinkhorn.sinkhorn(
        geom, a=a, b=b, threshold=0.001, lse_mode=lse_mode)
    return solution.reg_ot_cost
def reg_ot(a, b):
    """Regularized OT cost between `a` and `b` with the default threshold.

    `geom` and `lse_mode` are free variables resolved outside this function.
    """
    solution = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=lse_mode)
    return solution.reg_ot_cost