def marginal_error(geom: geometry.Geometry,
                   a: jnp.ndarray,
                   b: jnp.ndarray,
                   tau_a: float,
                   tau_b: float,
                   f_u: jnp.ndarray,
                   g_v: jnp.ndarray,
                   norm_error: int,
                   lse_mode: bool) -> jnp.ndarray:
  """Computes marginal error, the stopping criterion used to terminate Sinkhorn.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f_u: jnp.ndarray, potential or scaling
    g_v: jnp.ndarray, potential or scaling
    norm_error: int, p-norm used to compute error.
    lse_mode: True if log-sum-exp operations, False if kernel vector products.

  Returns:
    a positive number quantifying how far from convergence the algorithm stands.
  """
  if tau_a == 1.0 and tau_b == 1.0:
    err = geom.error(f_u, g_v, b, 0, norm_error, lse_mode)
  else:
    # In the unbalanced case, we compute the norm of the gradient.
    # The gradient is equal to the marginal of the current plan minus
    # the gradient of <z, rho_z * (exp(-h / rho_z) - 1)>, where z is either
    # a or b, and h is the corresponding potential f or g. Note this is
    # equal to z if rho_z → inf, which is the case when tau_z → 1.0.
    if lse_mode:
      grad_a = grad_of_marginal_fit(a, f_u, tau_a, geom.epsilon)
      grad_b = grad_of_marginal_fit(b, g_v, tau_b, geom.epsilon)
    else:
      grad_a = grad_of_marginal_fit(a, geom.potential_from_scaling(f_u),
                                    tau_a, geom.epsilon)
      grad_b = grad_of_marginal_fit(b, geom.potential_from_scaling(g_v),
                                    tau_b, geom.epsilon)
    err = geom.error(f_u, g_v, grad_a, 1, norm_error, lse_mode)
    err += geom.error(f_u, g_v, grad_b, 0, norm_error, lse_mode)
  return err
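# ``grad_of_marginal_fit`` is referenced above but defined elsewhere in the
# module. Below is a minimal sketch consistent with the comment above (the
# gradient of <z, rho_z * (exp(-h / rho_z) - 1)> with respect to h); the
# helper name ``derivative_phi_star`` and the exact signatures are
# assumptions for exposition, not a confirmed API.
def derivative_phi_star(h: jnp.ndarray, rho: float) -> jnp.ndarray:
  """Derivative of the Legendre transform of KL, see ``phi_star``."""
  return jnp.exp(h / rho)


def grad_of_marginal_fit(c: jnp.ndarray, h: jnp.ndarray,
                         tau: float, epsilon: float) -> jnp.ndarray:
  """Gradient of the term penalizing deviation from marginal ``c``.

  Reduces to ``c`` itself in the balanced case (tau = 1, i.e. rho → inf).
  """
  if tau == 1.0:
    return c
  rho = epsilon * tau / (1 - tau)
  # jnp.where discards coordinates where c is zero, whose potentials are
  # -inf and would otherwise produce NaN.
  return jnp.where(c > 0, c * derivative_phi_star(-h, rho), 0.0)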
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray,
                 lse_mode: bool) -> jnp.ndarray:
  r"""Computes objective of regularized OT given dual solutions ``f``, ``g``.

  The objective is evaluated for dual solutions ``f`` and ``g``, using inputs
  ``geom``, ``a`` and ``b``, in addition to parameters ``tau_a``, ``tau_b``.
  Situations where ``a`` or ``b`` have zero coordinates are reflected in minus
  infinity entries in their corresponding dual potentials. To avoid the NaN
  that may result when multiplying 0's by infinity values, ``jnp.where`` is
  used to cancel these contributions.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f: jnp.ndarray, potential
    g: jnp.ndarray, potential
    lse_mode: bool, whether to compute total mass in lse or kernel mode.

  Returns:
    a float, the regularized transport cost.
  """
  supp_a = a > 0
  supp_b = b > 0
  if tau_a == 1.0:
    div_a = jnp.sum(
        jnp.where(supp_a, a * (f - geom.potential_from_scaling(a)), 0.0))
  else:
    rho_a = geom.epsilon * (tau_a / (1 - tau_a))
    div_a = -jnp.sum(
        jnp.where(
            supp_a,
            a * phi_star(-(f - geom.potential_from_scaling(a)), rho_a), 0.0))

  if tau_b == 1.0:
    div_b = jnp.sum(
        jnp.where(supp_b, b * (g - geom.potential_from_scaling(b)), 0.0))
  else:
    rho_b = geom.epsilon * (tau_b / (1 - tau_b))
    div_b = -jnp.sum(
        jnp.where(
            supp_b,
            b * phi_star(-(g - geom.potential_from_scaling(b)), rho_b), 0.0))

  # Using https://arxiv.org/pdf/1910.12958.pdf (24).
  if lse_mode:
    total_sum = jnp.sum(geom.marginal_from_potentials(f, g))
  else:
    # In kernel mode, the total mass of the coupling is computed in scaling
    # space. Underflow is harmless here: near-zero contributions simply drop
    # out of the sum and, unlike in Sinkhorn iterations, never reappear in a
    # numerator.
    total_sum = jnp.sum(
        geom.marginal_from_scalings(
            geom.scaling_from_potential(f), geom.scaling_from_potential(g)))
  return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
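# ``phi_star`` (used above) is defined elsewhere in the module. A minimal
# sketch, assuming it is the Legendre transform of the entropy generator
# phi(x) = x log x - x + 1 scaled by rho, consistent with the unbalanced
# terms of https://arxiv.org/pdf/1910.12958.pdf and with the gradient
# <z, rho_z * (exp(-h / rho_z) - 1)> quoted in ``marginal_error``:
def phi_star(h: jnp.ndarray, rho: float) -> jnp.ndarray:
  """Legendre transform of KL: maps ``h`` to ``rho * (exp(h / rho) - 1)``."""
  return rho * (jnp.exp(h / rho) - 1)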
def _sinkhorn_iterations(
    tau_a: float,
    tau_b: float,
    inner_iterations: int,
    min_iterations: int,
    max_iterations: int,
    momentum_default: float,
    chg_momentum_from: int,
    lse_mode: bool,
    implicit_differentiation: bool,
    linear_solve_kwargs: Mapping[str, Union[Callable, float]],
    parallel_dual_updates: bool,
    init_dual_a: jnp.ndarray,
    init_dual_b: jnp.ndarray,
    threshold: float,
    norm_error: Sequence[int],
    geom: geometry.Geometry,
    a: jnp.ndarray,
    b: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """The jittable Sinkhorn loop, run with or without a custom backward pass.

  Args:
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    inner_iterations: (int32) the Sinkhorn error is not recomputed at each
      iteration but every inner_iterations instead.
    min_iterations: (int32) the minimum number of Sinkhorn iterations.
    max_iterations: (int32) the maximum number of Sinkhorn iterations.
    momentum_default: float, a momentum parameter strictly between 0 and 2.
    chg_momentum_from: int, number of iterations after which momentum is
      computed adaptively from past errors.
    lse_mode: True for log-sum-exp computations, False for kernel
      multiplication.
    implicit_differentiation: if True, do not backprop through the Sinkhorn
      loop, but use the implicit function theorem on the fixed point
      optimality conditions instead.
    linear_solve_kwargs: parameterization of the linear solver used when
      applying implicit differentiation.
    parallel_dual_updates: updates potentials or scalings in parallel if True,
      sequentially (in Gauss-Seidel fashion) if False.
    init_dual_a: optional initialization for potentials/scalings w.r.t. first
      marginal (``a``) of reg-OT problem.
    init_dual_b: optional initialization for potentials/scalings w.r.t. second
      marginal (``b``) of reg-OT problem.
    threshold: (float) the relative threshold on the Sinkhorn error to stop
      the Sinkhorn iterations.
    norm_error: tuple of ints, p-norms of marginal / target errors to track.
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.

  Returns:
    f: potential
    g: potential
    errors: ndarray of errors
  """
  # Initializing solutions.
  f_u, g_v = init_dual_a, init_dual_b

  # Delete arguments not used in the forward pass.
  del linear_solve_kwargs

  # Defining the Sinkhorn loop, by setting initializations, body/cond.
  errors = -jnp.ones((np.ceil(max_iterations / inner_iterations).astype(int),
                      len(norm_error)))
  const = (geom, a, b, threshold)

  def cond_fn(iteration, const, state):
    threshold = const[-1]
    errors = state[0]
    err = errors[iteration // inner_iterations - 1, 0]
    return jnp.logical_or(iteration == 0,
                          jnp.logical_and(jnp.isfinite(err), err > threshold))

  def get_momentum(errors, idx):
    """Momentum formula of https://arxiv.org/pdf/2012.12562v1.pdf, p.7 and (5)."""
    error_ratio = jnp.minimum(errors[idx - 1, -1] / errors[idx - 2, -1], .99)
    power = 1.0 / inner_iterations
    return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power))

  def body_fn(iteration, const, state, compute_error):
    """Carries out one Sinkhorn iteration.

    Depending on lse_mode, these iterations are carried out either in:
      - log-space, for numerical stability.
      - scaling space, using standard kernel-vector multiply operations.

    Args:
      iteration: iteration number
      const: tuple of constant parameters that do not change throughout the
        loop, here the geometry and the marginals a, b.
      state: potential/scaling variables updated in the loop & error log.
      compute_error: flag to indicate this iteration computes/stores an error.

    Returns:
      state variables, i.e. errors and updated f_u, g_v potentials.
    """
    geom, a, b, _ = const
    errors, f_u, g_v = state

    # Compute momentum term if needed, using previously seen errors.
    w = jax.lax.stop_gradient(
        jnp.where(
            iteration >= (
                inner_iterations * chg_momentum_from + min_iterations),
            get_momentum(errors, chg_momentum_from),
            momentum_default))

    # Sinkhorn updates using momentum, in either scaling or potential form.
    if parallel_dual_updates:
      old_g_v = g_v
    if lse_mode:
      new_g_v = tau_b * geom.update_potential(
          f_u, g_v, jnp.log(b), iteration, axis=0)
      g_v = (1.0 - w) * jnp.where(jnp.isfinite(g_v), g_v, 0.0) + w * new_g_v
      new_f_u = tau_a * geom.update_potential(
          f_u, old_g_v if parallel_dual_updates else g_v,
          jnp.log(a), iteration, axis=1)
      f_u = (1.0 - w) * jnp.where(jnp.isfinite(f_u), f_u, 0.0) + w * new_f_u
    else:
      new_g_v = geom.update_scaling(f_u, b, iteration, axis=0) ** tau_b
      g_v = jnp.where(g_v > 0, g_v, 1) ** (1.0 - w) * new_g_v ** w
      new_f_u = geom.update_scaling(
          old_g_v if parallel_dual_updates else g_v,
          a, iteration, axis=1) ** tau_a
      f_u = jnp.where(f_u > 0, f_u, 1) ** (1.0 - w) * new_f_u ** w

    # Re-computes error if compute_error is True, else sets it to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        marginal_error(geom, a, b, tau_a, tau_b, f_u, g_v, norm_error,
                       lse_mode),
        jnp.inf)

    errors = jax.ops.index_update(
        errors, jax.ops.index[iteration // inner_iterations, :], err)
    return errors, f_u, g_v

  # Run the Sinkhorn loop: choose a standard fixpoint_iter loop if
  # differentiation is implicit, otherwise switch to the backprop-friendly
  # version of that loop if unrolling to differentiate.
  if implicit_differentiation:
    fix_point = fixed_point_loop.fixpoint_iter
  else:
    fix_point = fixed_point_loop.fixpoint_iter_backprop

  errors, f_u, g_v = fix_point(
      cond_fn, body_fn, min_iterations, max_iterations, inner_iterations,
      const, (errors, f_u, g_v))

  f = f_u if lse_mode else geom.potential_from_scaling(f_u)
  g = g_v if lse_mode else geom.potential_from_scaling(g_v)
  return f, g, errors[:, 0]
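# A minimal, illustrative sketch of how the loop above might be invoked
# directly, in the balanced case and in lse mode. This is an assumption for
# exposition only: the loop constants below (iteration counts, threshold,
# momentum settings) are arbitrary placeholder values, and callers would
# normally go through the module's public entry point rather than this
# private function.
def _example_sinkhorn_run(geom: geometry.Geometry,
                          a: jnp.ndarray,
                          b: jnp.ndarray):
  """Runs the jittable loop with zero dual initializations."""
  return _sinkhorn_iterations(
      tau_a=1.0, tau_b=1.0,      # balanced OT: no marginal relaxation.
      inner_iterations=10,       # recompute the error every 10 updates.
      min_iterations=0,
      max_iterations=2000,
      momentum_default=1.0,      # plain Sinkhorn updates, no over-relaxation.
      chg_momentum_from=20,      # adaptive momentum after 20 error logs.
      lse_mode=True,
      implicit_differentiation=True,
      linear_solve_kwargs={},
      parallel_dual_updates=False,
      init_dual_a=jnp.zeros_like(a),  # zero potentials as starting point.
      init_dual_b=jnp.zeros_like(b),
      threshold=1e-3,
      norm_error=(1,),           # track the 1-norm of the marginal error.
      geom=geom, a=a, b=b)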