def _init_geometry_gw(geom_x: geometry.Geometry,
                      geom_y: geometry.Geometry,
                      a: jnp.ndarray,
                      b: jnp.ndarray,
                      epsilon: Union[epsilon_scheduler.Epsilon, float],
                      loss: GWLoss,
                      **kwargs) -> geometry.Geometry:
  """Initializes the cost matrix for the geometry object for GW.

  The equation follows Equation 6, Proposition 1 of
  http://proceedings.mlr.press/v48/peyre16.pdf.

  Args:
    geom_x: a Geometry object for the first view.
    geom_y: a Geometry object for the second view.
    a: jnp.ndarray<float>[num_a,], weights.
    b: jnp.ndarray<float>[num_b,], weights.
    epsilon: a regularization parameter or an epsilon_scheduler.Epsilon object.
    loss: a GWLoss object.
    **kwargs: additional kwargs to epsilon.

  Returns:
    A Geometry object for Gromov-Wasserstein.
  """
  # Initialization of the transport matrix in the balanced case, following
  # http://proceedings.mlr.press/v48/peyre16.pdf
  ab = a[:, None] * b[None, :]
  marginal_x = ab.sum(1)
  marginal_y = ab.sum(0)
  marginal_dep_term = _marginal_dependent_cost(marginal_x, marginal_y,
                                               geom_x, geom_y, loss)

  tmp = geom_x.apply_cost(ab, axis=1, fn=loss.left_x)
  cost_matrix = marginal_dep_term - geom_y.apply_cost(
      tmp.T, axis=1, fn=loss.right_y).T
  return geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon, **kwargs)
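# For intuition, a dense reference for the initialization above (a sketch, not
# part of this module). For the squared-Euclidean GW loss, Proposition 1 of
# the paper cited above decomposes L(x, y) = (x - y)^2 as
# f1(x) + f2(y) - h1(x) * h2(y) with f1(x) = x**2, f2(y) = y**2, h1(x) = x and
# h2(y) = 2 * y; with the initial coupling T = a b^T used above, the cost
# matrix can be materialized directly. The helper name is illustrative.
def _init_gw_cost_dense_sketch(cost_x, cost_y, a, b):
  transport = a[:, None] * b[None, :]
  p, q = transport.sum(1), transport.sum(0)
  # Marginal-dependent term: f1(C_x) p 1^T + 1 (f2(C_y) q)^T.
  marginal_dep = (cost_x ** 2) @ p[:, None] + ((cost_y ** 2) @ q)[None, :]
  # Subtract the contraction h1(C_x) T h2(C_y)^T.
  return marginal_dep - cost_x @ transport @ (2 * cost_y).T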
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray) -> jnp.ndarray:
  """Computes the objective of regularized OT given dual solutions f, g."""
  # In all sums below, jnp.where handles situations in which some coordinates
  # of a and b are zero. For those coordinates, their potential is -inf.
  # This leads to (-inf) - (-inf) or (-inf) * 0 operations which result in
  # NaN. These contributions are discarded when computing the objective.
  if tau_a == 1.0:
    div_a = jnp.sum(
        jnp.where(a > 0, (f - geom.potential_from_scaling(a)) * a, 0))
  else:
    rho_a = geom.epsilon * (tau_a / (1 - tau_a))
    div_a = jnp.sum(
        jnp.where(
            a > 0,
            a * (rho_a - (rho_a + geom.epsilon / 2) *
                 jnp.exp(-(f - geom.potential_from_scaling(a)) / rho_a)), 0))

  if tau_b == 1.0:
    div_b = jnp.sum(
        jnp.where(b > 0, (g - geom.potential_from_scaling(b)) * b, 0))
  else:
    rho_b = geom.epsilon * (tau_b / (1 - tau_b))
    div_b = jnp.sum(
        jnp.where(
            b > 0,
            b * (rho_b - (rho_b + geom.epsilon / 2) *
                 jnp.exp(-(g - geom.potential_from_scaling(b)) / rho_b)), 0))

  # Using https://arxiv.org/pdf/1910.12958.pdf Eq. 30
  return div_a + div_b + geom.epsilon * jnp.sum(a) * jnp.sum(b)
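# A tiny, self-contained illustration of the jnp.where guard described above;
# the values are made up for the sketch. Zero-mass coordinates carry a -inf
# potential, and IEEE arithmetic turns (-inf) * 0 into NaN unless masked.
def _nan_guard_sketch():
  a = jnp.array([0.5, 0.5, 0.0])
  f = jnp.array([1.0, 2.0, -jnp.inf])  # zero-mass coordinate gets -inf
  naive = jnp.sum(f * a)  # nan, since (-inf) * 0 is nan
  guarded = jnp.sum(jnp.where(a > 0, f * a, 0.0))  # 1.5
  return naive, guarded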
def marginal_error(geom: geometry.Geometry,
                   a: jnp.ndarray,
                   b: jnp.ndarray,
                   tau_a: float,
                   tau_b: float,
                   f_u: jnp.ndarray,
                   g_v: jnp.ndarray,
                   norm_error: int,
                   lse_mode: bool) -> jnp.ndarray:
  """Computes the marginal error, the stopping criterion used to terminate Sinkhorn.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f_u: jnp.ndarray, potential or scaling.
    g_v: jnp.ndarray, potential or scaling.
    norm_error: int, p-norm used to compute error.
    lse_mode: True if log-sum-exp operations, False if kernel vector products.

  Returns:
    a positive number quantifying how far from convergence the algorithm
    stands.
  """
  if tau_b == 1.0:
    err = geom.error(f_u, g_v, b, 0, norm_error, lse_mode)
  elif tau_a == 1.0:
    err = geom.error(f_u, g_v, a, 1, norm_error, lse_mode)
  else:
    # In the unbalanced case, we compute the norm of the gradient. The
    # gradient is equal to the marginal of the current plan minus the
    # gradient of <z, rho_z * (exp(-h / rho_z) - 1)>, where z is either a or
    # b, and h is either f or g. Note this is equal to z if rho_z → inf,
    # which is the case when tau_z → 1.0.
    if lse_mode:
      target = grad_of_marginal_fit(a, b, f_u, g_v, tau_a, tau_b, geom)
    else:
      target = grad_of_marginal_fit(a, b,
                                    geom.potential_from_scaling(f_u),
                                    geom.potential_from_scaling(g_v),
                                    tau_a, tau_b, geom)
    err = geom.error(f_u, g_v, target[0], 1, norm_error, lse_mode)
    err += geom.error(f_u, g_v, target[1], 0, norm_error, lse_mode)
  return err
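# For intuition, in the balanced case geom.error measures how far the
# coupling's marginal is from its target. A dense sketch of that criterion,
# for a materialized plan (illustrative, not the memory-efficient path that
# geom.error actually takes):
def _marginal_error_dense_sketch(plan, b, norm_error=1):
  gap = plan.sum(axis=0) - b  # column marginal of the plan minus target b
  return jnp.sum(jnp.abs(gap) ** norm_error) ** (1.0 / norm_error)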
def discrete_barycenter(geom: geometry.Geometry,
                        a: jnp.ndarray,
                        weights: jnp.ndarray = None,
                        dual_initialization: jnp.ndarray = None,
                        threshold: float = 1e-2,
                        norm_error: int = 1,
                        inner_iterations: int = 10,
                        min_iterations: int = 0,
                        max_iterations: int = 2000,
                        lse_mode: bool = True,
                        debiased: bool = False) -> SinkhornBarycenterOutput:
  """Computes a discrete barycenter, using https://arxiv.org/abs/2006.02575.

  Args:
    geom: a Geometry object able to apply kernels with a certain epsilon.
    a: jnp.ndarray<float>[batch, geom.num_a]: batch of histograms.
    weights: jnp.ndarray of weights in the probability simplex.
    dual_initialization: jnp.ndarray of size [batch, num_b], initialization
      for g_v.
    threshold: (float) tolerance used to monitor convergence.
    norm_error: int, power used to define the p-norm of the error for
      marginal/target.
    inner_iterations: (int32) the Sinkhorn error is not recomputed at each
      iteration but every inner_iterations instead, to avoid computational
      overhead.
    min_iterations: (int32) the minimum number of Sinkhorn iterations carried
      out before the error is computed and monitored.
    max_iterations: (int32) the maximum number of Sinkhorn iterations.
    lse_mode: True for log-sum-exp computations, False for kernel
      multiplication.
    debiased: whether to run the debiased version of the Sinkhorn divergence.

  Returns:
    A ``SinkhornBarycenterOutput``, which contains two arrays of potentials,
    each of size ``batch`` times ``geom.num_a``, summarizing the OT between
    each histogram in the database onto the barycenter, described in
    ``histogram``, as well as a sequence of errors that monitors convergence.
  """
  batch_size, num_a = a.shape
  _, num_b = geom.shape
  if weights is None:
    weights = jnp.ones((batch_size,)) / batch_size
  if not jnp.alltrue(weights > 0) or weights.shape[0] != batch_size:
    raise ValueError(
        f'weights must have positive values and size {batch_size}')
  if dual_initialization is None:
    # initialization strategy from https://arxiv.org/pdf/1503.02533.pdf, (3.6)
    dual_initialization = geom.apply_cost(a.T, axis=0).T
    dual_initialization -= jnp.average(dual_initialization,
                                       weights=weights,
                                       axis=0)[jnp.newaxis, :]
  if debiased and not geom.is_symmetric:
    raise ValueError('Geometry must be symmetric to use the debiased option.')
  norm_error = (norm_error,)
  return _discrete_barycenter(geom, a, weights, dual_initialization, threshold,
                              norm_error, inner_iterations, min_iterations,
                              max_iterations, lse_mode, debiased, num_a, num_b)
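# A hypothetical end-to-end call (a sketch only): it assumes a square cost
# over a shared support, built with the geometry.Geometry constructor used
# elsewhere in this module; sizes, epsilon, and the output field names
# (``histogram``, ``errors``, per the docstring above) are assumptions.
def _discrete_barycenter_usage_sketch():
  grid = jnp.linspace(0.0, 1.0, 50)
  cost = (grid[:, None] - grid[None, :]) ** 2
  geom = geometry.Geometry(cost_matrix=cost, epsilon=1e-2)
  a = jax.random.uniform(jax.random.PRNGKey(0), (4, 50))
  a = a / a.sum(axis=1, keepdims=True)  # batch of histograms on the simplex
  out = discrete_barycenter(geom, a, threshold=1e-3)
  return out.histogram, out.errors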
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray,
                 lse_mode: bool) -> jnp.ndarray:
  r"""Computes the objective of regularized OT given dual solutions ``f``, ``g``.

  The objective is evaluated for dual solutions ``f`` and ``g``, using inputs
  ``geom``, ``a`` and ``b``, in addition to parameters ``tau_a``, ``tau_b``.
  Situations where ``a`` or ``b`` have zero coordinates are reflected in
  minus infinity entries in their corresponding dual potentials. To avoid the
  NaN that may result when multiplying 0 by infinity, ``jnp.where`` is used
  to cancel these contributions.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f: jnp.ndarray, potential.
    g: jnp.ndarray, potential.
    lse_mode: bool, whether to compute total mass in lse or kernel mode.

  Returns:
    a float, the regularized transport cost.
  """
  supp_a = a > 0
  supp_b = b > 0
  if tau_a == 1.0:
    div_a = jnp.sum(
        jnp.where(supp_a, a * (f - geom.potential_from_scaling(a)), 0.0))
  else:
    rho_a = geom.epsilon * (tau_a / (1 - tau_a))
    div_a = -jnp.sum(
        jnp.where(
            supp_a,
            a * phi_star(-(f - geom.potential_from_scaling(a)), rho_a), 0.0))

  if tau_b == 1.0:
    div_b = jnp.sum(
        jnp.where(supp_b, b * (g - geom.potential_from_scaling(b)), 0.0))
  else:
    rho_b = geom.epsilon * (tau_b / (1 - tau_b))
    div_b = -jnp.sum(
        jnp.where(
            supp_b,
            b * phi_star(-(g - geom.potential_from_scaling(b)), rho_b), 0.0))

  # Using https://arxiv.org/pdf/1910.12958.pdf (24)
  if lse_mode:
    total_sum = jnp.sum(geom.marginal_from_potentials(f, g))
  else:
    total_sum = jnp.sum(
        geom.marginal_from_scalings(
            geom.scaling_from_potential(f), geom.scaling_from_potential(g)))
  return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
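# phi_star is not defined in this excerpt. A definition consistent with the
# reference above (https://arxiv.org/pdf/1910.12958.pdf), namely the Legendre
# transform of the rho-scaled KL penalty, would be the sketch below.
def phi_star(h: jnp.ndarray, rho: float) -> jnp.ndarray:
  return rho * (jnp.exp(h / rho) - 1)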
def _sinkhorn_iterations(
    tau_a: float,
    tau_b: float,
    inner_iterations: int,
    min_iterations: int,
    max_iterations: int,
    momentum_default: float,
    chg_momentum_from: int,
    lse_mode: bool,
    implicit_differentiation: bool,
    linear_solve_kwargs: Mapping[str, Union[Callable, float]],
    parallel_dual_updates: bool,
    init_dual_a: jnp.ndarray,
    init_dual_b: jnp.ndarray,
    threshold: float,
    norm_error: Sequence[int],
    geom: geometry.Geometry,
    a: jnp.ndarray,
    b: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """The jittable Sinkhorn loop, with or without a custom backward pass.

  Args:
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    inner_iterations: (int32) the Sinkhorn error is not recomputed at each
      iteration but every inner_iterations instead.
    min_iterations: (int32) the minimum number of Sinkhorn iterations.
    max_iterations: (int32) the maximum number of Sinkhorn iterations.
    momentum_default: float, a float in the open interval (0, 2).
    chg_momentum_from: int, number of iterations after which momentum is
      computed adaptively.
    lse_mode: True for log-sum-exp computations, False for kernel
      multiplication.
    implicit_differentiation: if True, do not backprop through the Sinkhorn
      loop, but use the implicit function theorem on the fixed point
      optimality conditions instead.
    linear_solve_kwargs: parameterization of the linear solver used when
      applying implicit differentiation.
    parallel_dual_updates: updates potentials or scalings in parallel if True,
      sequentially (in Gauss-Seidel fashion) if False.
    init_dual_a: optional initialization for potentials/scalings w.r.t. first
      marginal (``a``) of the reg-OT problem.
    init_dual_b: optional initialization for potentials/scalings w.r.t. second
      marginal (``b``) of the reg-OT problem.
    threshold: (float) the relative threshold on the Sinkhorn error to stop
      the Sinkhorn iterations.
    norm_error: tuple of int, p-norms of marginal / target errors to track.
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.

  Returns:
    f: potential
    g: potential
    errors: ndarray of errors
  """
  # Initializing solutions
  f_u, g_v = init_dual_a, init_dual_b

  # Delete arguments not used in the forward pass.
  del linear_solve_kwargs

  # Defining the Sinkhorn loop, by setting initializations, body/cond.
  errors = -jnp.ones((np.ceil(max_iterations / inner_iterations).astype(int),
                      len(norm_error)))
  const = (geom, a, b, threshold)

  def cond_fn(iteration, const, state):
    threshold = const[-1]
    errors = state[0]
    err = errors[iteration // inner_iterations - 1, 0]
    return jnp.logical_or(iteration == 0,
                          jnp.logical_and(jnp.isfinite(err), err > threshold))

  def get_momentum(errors, idx):
    """Momentum formula, https://arxiv.org/pdf/2012.12562v1.pdf, p.7 and (5)."""
    error_ratio = jnp.minimum(errors[idx - 1, -1] / errors[idx - 2, -1], .99)
    power = 1.0 / inner_iterations
    return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power))

  def body_fn(iteration, const, state, compute_error):
    """Carries out one Sinkhorn iteration.

    Depending on lse_mode, these iterations are carried out either in:
      - log-space, for numerical stability;
      - scaling space, using standard kernel-vector multiply operations.

    Args:
      iteration: iteration number.
      const: tuple of constant parameters that do not change throughout the
        loop, here the geometry and the marginals a, b.
      state: potential/scaling variables updated in the loop & error log.
      compute_error: flag to indicate this iteration computes/stores an error.

    Returns:
      state variables, i.e. errors and updated f_u, g_v potentials.
    """
    geom, a, b, _ = const
    errors, f_u, g_v = state

    # Compute the momentum term if needed, using previously seen errors.
    w = jax.lax.stop_gradient(
        jnp.where(
            iteration >=
            (inner_iterations * chg_momentum_from + min_iterations),
            get_momentum(errors, chg_momentum_from), momentum_default))

    # Sinkhorn updates using momentum, in either scaling or potential form.
    if parallel_dual_updates:
      old_g_v = g_v
    if lse_mode:
      new_g_v = tau_b * geom.update_potential(f_u, g_v, jnp.log(b),
                                              iteration, axis=0)
      g_v = (1.0 - w) * jnp.where(jnp.isfinite(g_v), g_v, 0.0) + w * new_g_v
      new_f_u = tau_a * geom.update_potential(
          f_u, old_g_v if parallel_dual_updates else g_v,
          jnp.log(a), iteration, axis=1)
      f_u = (1.0 - w) * jnp.where(jnp.isfinite(f_u), f_u, 0.0) + w * new_f_u
    else:
      new_g_v = geom.update_scaling(f_u, b, iteration, axis=0) ** tau_b
      g_v = jnp.where(g_v > 0, g_v, 1) ** (1.0 - w) * new_g_v ** w
      new_f_u = geom.update_scaling(
          old_g_v if parallel_dual_updates else g_v, a,
          iteration, axis=1) ** tau_a
      f_u = jnp.where(f_u > 0, f_u, 1) ** (1.0 - w) * new_f_u ** w

    # Re-compute the error if compute_error is True, else set it to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        marginal_error(geom, a, b, tau_a, tau_b, f_u, g_v, norm_error,
                       lse_mode),
        jnp.inf)

    errors = jax.ops.index_update(
        errors, jax.ops.index[iteration // inner_iterations, :], err)
    return errors, f_u, g_v

  # Run the Sinkhorn loop. Choose either a standard fixpoint_iter loop if
  # differentiation is implicit, otherwise switch to the backprop-friendly
  # version of that loop if unrolling to differentiate.
  if implicit_differentiation:
    fix_point = fixed_point_loop.fixpoint_iter
  else:
    fix_point = fixed_point_loop.fixpoint_iter_backprop
  errors, f_u, g_v = fix_point(
      cond_fn, body_fn, min_iterations, max_iterations, inner_iterations,
      const, (errors, f_u, g_v))

  f = f_u if lse_mode else geom.potential_from_scaling(f_u)
  g = g_v if lse_mode else geom.potential_from_scaling(g_v)
  return f, g, errors[:, 0]
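# For reference, a minimal kernel-mode Sinkhorn without momentum, unbalanced
# parameters, or jit concerns (a sketch, not this module's API). It shows the
# Gauss-Seidel order used above when parallel_dual_updates is False: the
# second scaling is updated first, then the first scaling uses the fresh one.
def _sinkhorn_kernel_sketch(cost, a, b, epsilon=1e-1, num_iter=1000):
  kernel = jnp.exp(-cost / epsilon)
  u = jnp.ones_like(a)
  v = jnp.ones_like(b)
  for _ in range(num_iter):
    v = b / (kernel.T @ u)  # update the second scaling with the current u
    u = a / (kernel @ v)    # then the first scaling with the fresh v
  return u, v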
def _update_geometry_gw(geom: geometry.Geometry,
                        geom_x: geometry.Geometry,
                        geom_y: geometry.Geometry,
                        f: jnp.ndarray,
                        g: jnp.ndarray,
                        loss: GWLoss,
                        **kwargs) -> geometry.Geometry:
  """Updates the geometry object for GW by updating the cost matrix.

  The cost matrix equation follows Equation 6, Proposition 1 of
  http://proceedings.mlr.press/v48/peyre16.pdf.

  Let :math:`p` [num_a,] be the marginal of the transport matrix for samples
  from geom_x and :math:`q` [num_b,] be the marginal of the transport matrix
  for samples from geom_y. Let :math:`T` [num_a, num_b] be the transport
  matrix. The cost matrix equation can be written as:

  cost_matrix = marginal_dep_term
              - left_x(cost_x) :math:`T` right_y(cost_y):math:`^T`

  Args:
    geom: a Geometry object carrying the cost matrix of Gromov-Wasserstein.
    geom_x: a Geometry object for the first view.
    geom_y: a Geometry object for the second view.
    f: jnp.ndarray<float>[num_a,], potentials.
    g: jnp.ndarray<float>[num_b,], potentials.
    loss: a GWLoss object.
    **kwargs: additional kwargs for epsilon.

  Returns:
    A Geometry object for Gromov-Wasserstein.
  """
  def apply_cost_fn(geom):
    condition = is_sqeuclidean(geom) and isinstance(loss, GWSqEuclLoss)
    return geom.vec_apply_cost if condition else geom.apply_cost

  def is_sqeuclidean(geom):
    return (isinstance(geom, pointcloud.PointCloud) and
            geom.power == 2.0 and
            isinstance(geom._cost_fn, costs.Euclidean))

  def is_online(geom):
    return isinstance(geom, pointcloud.PointCloud) and geom._online

  # Computes tmp = cost_matrix_x * transport
  if is_online(geom_x) or is_sqeuclidean(geom_x):
    transport = geom.transport_from_potentials(f, g)
    tmp = apply_cost_fn(geom_x)(transport, axis=1, fn=loss.left_x)
  else:
    tmp = geom.apply_transport_from_potentials(
        f, g, loss.left_x(geom_x.cost_matrix), axis=0)

  # Computes cost_matrix
  marginal_x = geom.marginal_from_potentials(f, g, axis=1)
  marginal_y = geom.marginal_from_potentials(f, g, axis=0)
  marginal_dep_term = _marginal_dependent_cost(marginal_x, marginal_y,
                                               geom_x, geom_y, loss)
  cost_matrix = marginal_dep_term - apply_cost_fn(geom_y)(
      tmp.T, axis=1, fn=loss.right_y).T
  return geometry.Geometry(
      cost_matrix=cost_matrix, epsilon=geom._epsilon, **kwargs)
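# The dense counterpart of geom.transport_from_potentials used above, for
# intuition (a sketch, not the memory-efficient implementation): in lse mode
# the coupling associated with potentials (f, g) on a cost C with regularizer
# eps is P = exp((f ⊕ g - C) / eps), materialized below by broadcasting.
def _transport_from_potentials_dense_sketch(cost, f, g, epsilon):
  return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)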
def _discrete_barycenter(geom: geometry.Geometry,
                         a: jnp.ndarray,
                         weights: jnp.ndarray,
                         dual_initialization: jnp.ndarray,
                         threshold: float,
                         norm_error: Sequence[int],
                         inner_iterations: int,
                         min_iterations: int,
                         max_iterations: int,
                         lse_mode: bool,
                         debiased: bool,
                         num_a: int,
                         num_b: int) -> SinkhornBarycenterOutput:
  """Jit'able function to compute discrete barycenters."""
  if lse_mode:
    f_u = jnp.zeros_like(a)
    g_v = dual_initialization
  else:
    f_u = jnp.ones_like(a)
    g_v = geom.scaling_from_potential(dual_initialization)

  # d below is as described in https://arxiv.org/abs/2006.02575. Note that
  # d should be considered to be equal to eps log(d) with those notations
  # if running in log-sum-exp mode.
  d = jnp.zeros((num_b,)) if lse_mode else jnp.ones((num_b,))

  if lse_mode:
    parallel_update = jax.vmap(
        lambda f, g, marginal, iter: geom.update_potential(
            f, g, jnp.log(marginal), axis=1),
        in_axes=[0, 0, 0, None])
    parallel_apply = jax.vmap(
        lambda f_, g_, eps_: geom.apply_lse_kernel(
            f_, g_, eps_, vec=None, axis=0)[0],
        in_axes=[0, 0, None])
  else:
    parallel_update = jax.vmap(
        lambda f, g, marginal, iter: geom.update_scaling(g, marginal, axis=1),
        in_axes=[0, 0, 0, None])
    parallel_apply = jax.vmap(
        lambda f_, g_, eps_: geom.apply_kernel(f_, eps_, axis=0),
        in_axes=[0, 0, None])

  errors_fn = jax.vmap(
      functools.partial(geom.error, axis=1, norm_error=norm_error,
                        lse_mode=lse_mode),
      in_axes=[0, 0, 0])

  errors = -jnp.ones(
      (max_iterations // inner_iterations + 1, len(norm_error)))
  const = (geom, a, weights)

  def cond_fn(iteration, const, state):  # pylint: disable=unused-argument
    errors = state[0]
    return jnp.logical_or(
        iteration == 0,
        errors[iteration // inner_iterations - 1, 0] > threshold)

  def body_fn(iteration, const, state, compute_error):
    geom, a, weights = const
    errors, d, f_u, g_v = state
    eps = geom._epsilon.at(iteration)  # pylint: disable=protected-access
    f_u = parallel_update(f_u, g_v, a, iteration)

    # kernel_f_u stands for K times potential u if running in scaling mode,
    # eps log K exp f / eps in lse mode.
    kernel_f_u = parallel_apply(f_u, g_v, eps)
    # b below is the running estimate for the barycenter if running in
    # scaling mode, eps log b if running in lse mode.
    if lse_mode:
      b = jnp.average(kernel_f_u, weights=weights, axis=0)
    else:
      b = jnp.prod(kernel_f_u ** weights[:, jnp.newaxis], axis=0)

    if debiased:
      if lse_mode:
        b += d
        d = 0.5 * (
            d + geom.update_potential(
                jnp.zeros((num_a,)), d, b / eps,
                iteration=iteration, axis=0))
      else:
        b *= d
        d = jnp.sqrt(
            d * geom.update_scaling(d, b, iteration=iteration, axis=0))
    if lse_mode:
      g_v = b[jnp.newaxis, :] - kernel_f_u
    else:
      g_v = b[jnp.newaxis, :] / kernel_f_u

    # re-compute error if compute_error is True, else set to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        jnp.mean(errors_fn(f_u, g_v, a)),
        jnp.inf)
    errors = jax.ops.index_update(
        errors, jax.ops.index[iteration // inner_iterations, :], err)
    return errors, d, f_u, g_v

  state = (errors, d, f_u, g_v)
  state = fixed_point_loop.fixpoint_iter_backprop(
      cond_fn, body_fn, min_iterations, max_iterations, inner_iterations,
      const, state)

  errors, d, f_u, g_v = state
  kernel_f_u = parallel_apply(f_u, g_v, geom.epsilon)
  if lse_mode:
    b = jnp.average(kernel_f_u, weights=weights, axis=0)
  else:
    b = jnp.prod(kernel_f_u ** weights[:, jnp.newaxis], axis=0)

  if debiased:
    if lse_mode:
      b += d
    else:
      b *= d
  if lse_mode:
    b = jnp.exp(b / geom.epsilon)
  return SinkhornBarycenterOutput(f_u, g_v, b, errors)
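# The scaling-mode barycenter estimate used above is a weighted geometric
# mean of the K^T u_k vectors, one per histogram. A dense sketch of that
# single step (the helper and argument names are illustrative):
def _barycenter_estimate_sketch(kernel, u_batch, weights):
  kernel_t_u = u_batch @ kernel  # row k holds K^T u_k, shape [batch, num_b]
  return jnp.prod(kernel_t_u ** weights[:, None], axis=0)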
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray) -> jnp.ndarray:
  """Computes the objective of regularized OT given dual solutions f, g.

  In all sums below, jnp.where handles situations in which some coordinates
  of a and b are zero. For those coordinates, their potential is -inf. This
  leads to (-inf) - (-inf) or (-inf) * 0 operations which result in NaN.
  These contributions are discarded when computing the objective.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f: jnp.ndarray, potential.
    g: jnp.ndarray, potential.

  Returns:
    a float, the regularized transport cost.
  """
  if tau_a == 1.0:
    div_a = jnp.sum(
        jnp.where(a > 0, (f - geom.potential_from_scaling(a)) * a, 0.0))
  else:
    rho_a = geom.epsilon * (tau_a / (1 - tau_a))
    div_a = jnp.sum(
        jnp.where(
            a > 0,
            a * (rho_a - (rho_a + geom.epsilon / 2) *
                 jnp.exp(-(f - geom.potential_from_scaling(a)) / rho_a)),
            0.0))

  if tau_b == 1.0:
    div_b = jnp.sum(
        jnp.where(b > 0, (g - geom.potential_from_scaling(b)) * b, 0.0))
  else:
    rho_b = geom.epsilon * (tau_b / (1 - tau_b))
    div_b = jnp.sum(
        jnp.where(
            b > 0,
            b * (rho_b - (rho_b + geom.epsilon / 2) *
                 jnp.exp(-(g - geom.potential_from_scaling(b)) / rho_b)),
            0.0))

  # Using https://arxiv.org/pdf/1910.12958.pdf (30), corrected with (15).
  # The total mass of the coupling is computed in scaling space. This avoids
  # differentiation issues linked with the automatic differentiation of
  # jnp.exp(jnp.logsumexp(...)) when some of those logs appear as -inf.
  # Because we are computing total mass, underflow is harmless: it simply
  # results in near-zero contributions which, unlike in Sinkhorn iterations,
  # do not appear next in a numerator.
  total_sum = jnp.sum(
      geom.marginal_from_scalings(geom.scaling_from_potential(f),
                                  geom.scaling_from_potential(g)))
  return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
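# Dense sketch of the total-mass computation in scaling space referred to in
# the comment above: sum(P) = u^T K v with u = exp(f / eps), v = exp(g / eps)
# and K = exp(-C / eps). Illustrative only; geom computes this without
# materializing K when possible.
def _total_mass_dense_sketch(cost, f, g, epsilon):
  u = jnp.exp(f / epsilon)
  v = jnp.exp(g / epsilon)
  kernel = jnp.exp(-cost / epsilon)
  return u @ kernel @ v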