def _discrete_barycenter(geom: geometry.Geometry,
                         a: jnp.ndarray,
                         weights: jnp.ndarray,
                         dual_initialization: jnp.ndarray,
                         threshold: float,
                         norm_error: Sequence[int],
                         inner_iterations: int,
                         min_iterations: int,
                         max_iterations: int,
                         lse_mode: bool,
                         debiased: bool,
                         num_a: int,
                         num_b: int) -> SinkhornBarycenterOutput:
  """Jit'able function to compute discrete barycenters."""
  if lse_mode:
    f_u = jnp.zeros_like(a)
    g_v = dual_initialization
  else:
    f_u = jnp.ones_like(a)
    g_v = geom.scaling_from_potential(dual_initialization)

  # d below is as described in https://arxiv.org/abs/2006.02575. Note that,
  # with the notations of that paper, d should be considered to be equal to
  # eps log(d) when running in log-sum-exp mode.
  d = jnp.zeros((num_b,)) if lse_mode else jnp.ones((num_b,))

  if lse_mode:
    parallel_update = jax.vmap(
        lambda f, g, marginal, iter: geom.update_potential(
            f, g, jnp.log(marginal), axis=1),
        in_axes=[0, 0, 0, None])
    parallel_apply = jax.vmap(
        lambda f_, g_, eps_: geom.apply_lse_kernel(
            f_, g_, eps_, vec=None, axis=0)[0],
        in_axes=[0, 0, None])
  else:
    parallel_update = jax.vmap(
        lambda f, g, marginal, iter: geom.update_scaling(g, marginal, axis=1),
        in_axes=[0, 0, 0, None])
    parallel_apply = jax.vmap(
        lambda f_, g_, eps_: geom.apply_kernel(f_, eps_, axis=0),
        in_axes=[0, 0, None])

  errors_fn = jax.vmap(
      functools.partial(geom.error, axis=1, norm_error=norm_error,
                        lse_mode=lse_mode),
      in_axes=[0, 0, 0])

  errors = -jnp.ones(
      (max_iterations // inner_iterations + 1, len(norm_error)))
  const = (geom, a, weights)

  def cond_fn(iteration, const, state):  # pylint: disable=unused-argument
    errors = state[0]
    return jnp.logical_or(
        iteration == 0,
        errors[iteration // inner_iterations - 1, 0] > threshold)

  def body_fn(iteration, const, state, compute_error):
    geom, a, weights = const
    errors, d, f_u, g_v = state
    eps = geom._epsilon.at(iteration)  # pylint: disable=protected-access
    f_u = parallel_update(f_u, g_v, a, iteration)
    # kernel_f_u stands for K times the scaling u when running in scaling
    # mode, and for eps log(K exp(f / eps)) when running in lse mode.
    kernel_f_u = parallel_apply(f_u, g_v, eps)
    # b below is the running estimate of the barycenter when running in
    # scaling mode, and eps log b when running in lse mode.
    if lse_mode:
      b = jnp.average(kernel_f_u, weights=weights, axis=0)
    else:
      b = jnp.prod(kernel_f_u ** weights[:, jnp.newaxis], axis=0)

    if debiased:
      if lse_mode:
        b += d
        d = 0.5 * (
            d + geom.update_potential(
                jnp.zeros((num_a,)), d, b / eps, iteration=iteration, axis=0))
      else:
        b *= d
        d = jnp.sqrt(
            d * geom.update_scaling(d, b, iteration=iteration, axis=0))

    if lse_mode:
      g_v = b[jnp.newaxis, :] - kernel_f_u
    else:
      g_v = b[jnp.newaxis, :] / kernel_f_u

    # Re-compute the error if compute_error is True, else set it to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        jnp.mean(errors_fn(f_u, g_v, a)),
        jnp.inf)
    errors = errors.at[iteration // inner_iterations, :].set(err)
    return errors, d, f_u, g_v

  state = (errors, d, f_u, g_v)
  state = fixed_point_loop.fixpoint_iter_backprop(
      cond_fn, body_fn, min_iterations, max_iterations, inner_iterations,
      const, state)

  errors, d, f_u, g_v = state
  kernel_f_u = parallel_apply(f_u, g_v, geom.epsilon)
  if lse_mode:
    b = jnp.average(kernel_f_u, weights=weights, axis=0)
  else:
    b = jnp.prod(kernel_f_u ** weights[:, jnp.newaxis], axis=0)
  if debiased:
    if lse_mode:
      b += d
    else:
      b *= d
  if lse_mode:
    b = jnp.exp(b / geom.epsilon)
  return SinkhornBarycenterOutput(f_u, g_v, b, errors)
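
# A minimal usage sketch for _discrete_barycenter (hypothetical, not part of
# the library): it assumes ``pointcloud.PointCloud`` is importable as in the
# rest of this package, that both input histograms share the same support
# ``x``, and that the output namedtuple stores the barycenter histogram as its
# third field. All hyper-parameter values below are illustrative only.
def _example_discrete_barycenter():
  from ott.geometry import pointcloud  # assumed import path

  x = jnp.linspace(0.0, 1.0, 20)[:, jnp.newaxis]  # shared support, size 20
  geom = pointcloud.PointCloud(x, x, epsilon=1e-2)
  # Two histograms on that support, each normalized to sum to 1.
  a = jnp.stack([jnp.ones(20) / 20.0, jnp.arange(1.0, 21.0) / 210.0])
  out = _discrete_barycenter(
      geom, a,
      weights=jnp.array([0.5, 0.5]),
      dual_initialization=jnp.zeros((2, 20)),  # zero potentials in lse mode
      threshold=1e-3,
      norm_error=(1,),
      inner_iterations=10,
      min_iterations=0,
      max_iterations=1000,
      lse_mode=True,
      debiased=True,
      num_a=20,
      num_b=20)
  return out[2]  # barycenter histogram, shape [num_b,] (assumed field order)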
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray,
                 lse_mode: bool) -> jnp.ndarray:
  r"""Computes the objective of regularized OT given dual solutions ``f``, ``g``.

  The objective is evaluated for dual solutions ``f`` and ``g``, using the
  inputs ``geom``, ``a`` and ``b``, in addition to the parameters ``tau_a``
  and ``tau_b``. Situations in which ``a`` or ``b`` have zero coordinates are
  reflected in minus-infinity entries in their corresponding dual potentials.
  To avoid the NaN that may result when multiplying 0 by infinity,
  ``jnp.where`` is used to cancel these contributions.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch, num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch, num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between the KL-divergence regularizer to
      the first marginal and itself plus the epsilon regularizer, as used in
      the unbalanced formulation.
    tau_b: float, ratio lam/(lam+eps) between the KL-divergence regularizer to
      the second marginal and itself plus the epsilon regularizer, as used in
      the unbalanced formulation.
    f: jnp.ndarray, potential.
    g: jnp.ndarray, potential.
    lse_mode: bool, whether to compute the total mass in lse or kernel mode.

  Returns:
    a float, the regularized transport cost.
  """
  supp_a = a > 0
  supp_b = b > 0
  if tau_a == 1.0:
    div_a = jnp.sum(
        jnp.where(supp_a, a * (f - geom.potential_from_scaling(a)), 0.0))
  else:
    rho_a = geom.epsilon * (tau_a / (1 - tau_a))
    div_a = -jnp.sum(
        jnp.where(
            supp_a,
            a * phi_star(-(f - geom.potential_from_scaling(a)), rho_a),
            0.0))

  if tau_b == 1.0:
    div_b = jnp.sum(
        jnp.where(supp_b, b * (g - geom.potential_from_scaling(b)), 0.0))
  else:
    rho_b = geom.epsilon * (tau_b / (1 - tau_b))
    div_b = -jnp.sum(
        jnp.where(
            supp_b,
            b * phi_star(-(g - geom.potential_from_scaling(b)), rho_b),
            0.0))

  # Using https://arxiv.org/pdf/1910.12958.pdf (24).
  if lse_mode:
    total_sum = jnp.sum(geom.marginal_from_potentials(f, g))
  else:
    total_sum = jnp.sum(
        geom.marginal_from_scalings(
            geom.scaling_from_potential(f), geom.scaling_from_potential(g)))
  return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
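
# ``phi_star`` is used above but not defined in this excerpt. As a point of
# reference, here is a minimal sketch consistent with its use here: the
# Legendre transform of the KL divergence with strength ``rho``, following
# https://arxiv.org/pdf/1910.12958.pdf. The library's actual definition may
# differ in details.
def phi_star(delta: jnp.ndarray, rho: float) -> jnp.ndarray:
  """Legendre transform of the KL divergence with strength ``rho``."""
  return rho * (jnp.exp(delta / rho) - 1)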
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray) -> jnp.ndarray:
  """Computes the objective of regularized OT given dual solutions f, g.

  In all sums below, jnp.where handles situations in which some coordinates of
  a and b are zero. For those coordinates, the potential is -inf, which leads
  to -inf - -inf or -inf * 0 operations that result in NaN. These
  contributions are discarded when computing the objective.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch, num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch, num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between the KL-divergence regularizer to
      the first marginal and itself plus the epsilon regularizer, as used in
      the unbalanced formulation.
    tau_b: float, ratio lam/(lam+eps) between the KL-divergence regularizer to
      the second marginal and itself plus the epsilon regularizer, as used in
      the unbalanced formulation.
    f: jnp.ndarray, potential.
    g: jnp.ndarray, potential.

  Returns:
    a float, the regularized transport cost.
  """
  if tau_a == 1.0:
    div_a = jnp.sum(
        jnp.where(a > 0, (f - geom.potential_from_scaling(a)) * a, 0.0))
  else:
    rho_a = geom.epsilon * (tau_a / (1 - tau_a))
    div_a = jnp.sum(
        jnp.where(
            a > 0,
            a * (rho_a - (rho_a + geom.epsilon / 2) *
                 jnp.exp(-(f - geom.potential_from_scaling(a)) / rho_a)),
            0.0))

  if tau_b == 1.0:
    div_b = jnp.sum(
        jnp.where(b > 0, (g - geom.potential_from_scaling(b)) * b, 0.0))
  else:
    rho_b = geom.epsilon * (tau_b / (1 - tau_b))
    div_b = jnp.sum(
        jnp.where(
            b > 0,
            b * (rho_b - (rho_b + geom.epsilon / 2) *
                 jnp.exp(-(g - geom.potential_from_scaling(b)) / rho_b)),
            0.0))

  # Using https://arxiv.org/pdf/1910.12958.pdf (30), corrected with (15).
  # The total mass of the coupling is computed in scaling space. This avoids
  # differentiation issues linked to the automatic differentiation of
  # jnp.exp(jnp.logsumexp(...)) when some of those logs appear as -inf.
  # Because we are computing a total mass, underflow is harmless: it simply
  # results in near-zero contributions which, unlike in Sinkhorn iterations,
  # do not subsequently appear in a numerator.
  total_sum = jnp.sum(
      geom.marginal_from_scalings(
          geom.scaling_from_potential(f), geom.scaling_from_potential(g)))
  return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
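
# A hypothetical usage sketch for ``ent_reg_cost``: the objective accepts any
# candidate duals f, g of the right shapes; optimal duals would typically come
# out of a Sinkhorn run. ``pointcloud.PointCloud`` is assumed importable as in
# the rest of this package, and all numerical values are illustrative.
def _example_ent_reg_cost():
  from ott.geometry import pointcloud  # assumed import path

  x = jnp.array([[0.0], [1.0], [2.0]])
  y = jnp.array([[0.5], [1.5]])
  geom = pointcloud.PointCloud(x, y, epsilon=0.1)
  a = jnp.array([0.3, 0.3, 0.4])
  b = jnp.array([0.5, 0.5])
  f = jnp.zeros(3)  # trivial candidate potentials
  g = jnp.zeros(2)
  # Balanced case: tau_a = tau_b = 1 turns both KL terms into linear terms.
  return ent_reg_cost(geom, a, b, tau_a=1.0, tau_b=1.0, f=f, g=g)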