def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray,
                 lse_mode: bool) -> jnp.ndarray:
  r"""Computes objective of regularized OT given dual solutions ``f``, ``g``.

  The objective is evaluated for dual solution ``f`` and ``g``, using inputs
  ``geom``, ``a`` and ``b``, in addition to parameters ``tau_a``, ``tau_b``.
  Situations where ``a`` or ``b`` have zero coordinates are reflected in
  minus infinity entries in their corresponding dual potentials. To avoid NaN
  that may result when multiplying 0's by infinity values, ``jnp.where`` is
  used to cancel these contributions. Masking is applied both to the inputs
  and to the outputs of the per-coordinate terms, so that NaN is kept out of
  the gradient as well as out of the objective value.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f: jnp.ndarray, potential
    g: jnp.ndarray, potential
    lse_mode: bool, whether to compute total mass in lse or kernel mode.

  Returns:
    a float, the regularized transport cost.
  """

  def marginal_term(weights, potential, tau):
    """Divergence contribution of one marginal (``a``/``f`` or ``b``/``g``)."""
    support = weights > 0
    # A lone outer `jnp.where` hides NaN in the forward value only: its VJP
    # still multiplies a zero cotangent by the infinite partials of the
    # discarded branch (log(0), -inf - -inf). Masking the inputs as well
    # keeps gradients finite at zero-weight coordinates. This assumes
    # `potential_from_scaling` acts elementwise, so supported coordinates see
    # exactly the same value as before -- TODO confirm.
    safe_weights = jnp.where(support, weights, 1.0)
    diff = jnp.where(
        support, potential - geom.potential_from_scaling(safe_weights), 0.0)
    if tau == 1.0:
      # Balanced marginal (tau == 1).
      return jnp.sum(jnp.where(support, weights * diff, 0.0))
    rho = geom.epsilon * (tau / (1 - tau))
    return -jnp.sum(jnp.where(support, weights * phi_star(-diff, rho), 0.0))

  div_a = marginal_term(a, f, tau_a)
  div_b = marginal_term(b, g, tau_b)

  # Using https://arxiv.org/pdf/1910.12958.pdf (24)
  if lse_mode:
    total_sum = jnp.sum(geom.marginal_from_potentials(f, g))
  else:
    total_sum = jnp.sum(geom.marginal_from_scalings(
        geom.scaling_from_potential(f), geom.scaling_from_potential(g)))
  return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray) -> jnp.ndarray:
  """Computes objective of regularized OT given dual solutions f,g.

  In all sums below, jnp.where handles situations in which some coordinates of
  a and b are zero. For those coordinates, their potential is -inf. This leads
  to -inf - -inf or -inf x 0 operations which result in NaN. These
  contributions are discarded when computing the objective. The masking is
  applied to the inputs of each per-coordinate term as well as to its output,
  so that the discarded NaN does not reappear in gradients.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f: jnp.ndarray, potential
    g: jnp.ndarray, potential

  Returns:
    a float, the regularized transport cost.
  """

  def marginal_term(weights, potential, tau):
    """Divergence contribution of one marginal (``a``/``f`` or ``b``/``g``)."""
    support = weights > 0
    # A lone outer `jnp.where` removes NaN from the forward value only; its VJP
    # still multiplies a zero cotangent by the infinite partials of the
    # discarded branch. Masking the inputs too keeps gradients finite at
    # zero-weight coordinates. This assumes `potential_from_scaling` acts
    # elementwise, so supported coordinates are unaffected -- TODO confirm.
    safe_weights = jnp.where(support, weights, 1.0)
    diff = jnp.where(
        support, potential - geom.potential_from_scaling(safe_weights), 0.0)
    if tau == 1.0:
      # Balanced marginal (tau == 1).
      return jnp.sum(jnp.where(support, diff * weights, 0.0))
    rho = geom.epsilon * (tau / (1 - tau))
    # NOTE(review): this factor is (rho + epsilon / 2); the lse_mode variant
    # of ent_reg_cost in this file computes the same term via phi_star with
    # no epsilon / 2 offset. Confirm which formula is intended.
    return jnp.sum(
        jnp.where(
            support,
            weights * (rho - (rho + geom.epsilon / 2) * jnp.exp(-diff / rho)),
            0.0))

  div_a = marginal_term(a, f, tau_a)
  div_b = marginal_term(b, g, tau_b)

  # Using https://arxiv.org/pdf/1910.12958.pdf (30), corrected with (15)
  # The total mass of the coupling is computed in scaling space. This avoids
  # differentiation issues linked with the automatic differentiation of
  # jnp.exp(jnp.logsumexp(...)) when some of those logs appear as -inf.
  # Because we are computing total mass it is irrelevant to have underflow since
  # this would simply result in near 0 contributions, which, unlike Sinkhorn
  # iterations, do not appear next in a numerator.
  total_sum = jnp.sum(
      geom.marginal_from_scalings(geom.scaling_from_potential(f),
                                  geom.scaling_from_potential(g)))
  return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)