def ent_reg_cost(geom: geometry.Geometry, a: jnp.ndarray, b: jnp.ndarray,
                 tau_a: float, tau_b: float, f: jnp.ndarray, g: jnp.ndarray,
                 lse_mode: bool) -> jnp.ndarray:
  r"""Computes objective of regularized OT given dual solutions ``f``, ``g``.

  The objective is evaluated for dual solution ``f`` and ``g``, using inputs
  ``geom``, ``a`` and ``b``, in addition to parameters ``tau_a``, ``tau_b``.
  Situations where ``a`` or ``b`` have zero coordinates are reflected in
  minus infinity entries in their corresponding dual potentials. To avoid NaN
  that may result when multiplying 0's by infinity values, ``jnp.where`` is
  used to cancel these contributions.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to
      first marginal and itself + epsilon regularizer used in the unbalanced
      formulation. ``tau_a == 1.0`` means the first marginal is enforced
      exactly (balanced case).
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation. ``tau_b == 1.0`` means the second marginal is enforced
      exactly (balanced case).
    f: jnp.ndarray, potential associated with ``a``.
    g: jnp.ndarray, potential associated with ``b``.
    lse_mode: bool, whether to compute total mass in lse or kernel mode.

  Returns:
    a scalar jnp.ndarray, the regularized transport cost.
  """

  def marginal_term(weights, potential, tau):
    """Dual contribution of one marginal, masked outside ``weights > 0``."""
    supp = weights > 0
    # Shift of the dual potential relative to the potential induced by the
    # weights themselves; entries outside the support may be +/-inf and are
    # zeroed out by the `jnp.where` below.
    delta = potential - geom.potential_from_scaling(weights)
    if tau == 1.0:
      # Balanced marginal: linear term <weights, delta>.
      return jnp.sum(jnp.where(supp, weights * delta, 0.0))
    # Unbalanced marginal: KL penalty of strength rho, via its conjugate.
    rho = geom.epsilon * (tau / (1 - tau))
    return -jnp.sum(
        jnp.where(supp, weights * phi_star(-delta, rho), 0.0))

  div_a = marginal_term(a, f, tau_a)
  div_b = marginal_term(b, g, tau_b)

  # Total mass of the transport plan induced by (f, g). This enters the
  # objective as epsilon * (|a| * |b| - total mass); presumably eq. (24) of
  # Sejourne et al. (2019), arXiv:1910.12958 -- TODO confirm reference.
  if lse_mode:
    total_sum = jnp.sum(geom.marginal_from_potentials(f, g))
  else:
    total_sum = jnp.sum(geom.marginal_from_scalings(
        geom.scaling_from_potential(f), geom.scaling_from_potential(g)))
  return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
def _update_geometry_gw(geom: geometry.Geometry, geom_x: geometry.Geometry,
                        geom_y: geometry.Geometry, f: jnp.ndarray,
                        g: jnp.ndarray, loss: GWLoss,
                        **kwargs) -> geometry.Geometry:
  """Updates the geometry object for GW by updating the cost matrix.

  The cost matrix equation follows Equation 6, Proposition 1 of
  Peyre et al. (2016), "Gromov-Wasserstein Averaging of Kernel and Distance
  Matrices" (http://proceedings.mlr.press/v48/peyre16.pdf).

  Let :math:`p` [num_a,] be the marginal of the transport matrix for samples
  from geom_x and :math:`q` [num_b,] be the marginal of the transport matrix
  for samples from geom_y. Let :math:`T` [num_a, num_b] be the transport
  matrix. The cost matrix equation can be written as:

  cost_matrix = marginal_dep_term
                - left_x(cost_x) :math:`T` right_y(cost_y):math:`^T`

  Args:
    geom: a Geometry object carrying the cost matrix of Gromov Wasserstein.
    geom_x: a Geometry object for the first view.
    geom_y: a second Geometry object for the second view.
    f: jnp.ndarray<float>[num_a,], potentials.
    g: jnp.ndarray<float>[num_b,], potentials.
    loss: a GWLoss object providing ``left_x`` / ``right_y``.
    **kwargs: additional keyword arguments forwarded to the
      ``geometry.Geometry`` constructor of the returned object.

  Returns:
    A Geometry object for Gromov-Wasserstein.
  """

  def is_sqeuclidean(geom):
    return (isinstance(geom, pointcloud.PointCloud)
            and geom.power == 2.0
            and isinstance(geom._cost_fn, costs.Euclidean))

  def is_online(geom):
    return isinstance(geom, pointcloud.PointCloud) and geom._online

  def apply_cost_fn(geom):
    # The vectorized application is only used when both the geometry is a
    # squared-Euclidean point cloud and the loss is the squared-Euclidean GW
    # loss; otherwise the generic `apply_cost` is used.
    condition = is_sqeuclidean(geom) and isinstance(loss, GWSqEuclLoss)
    return geom.vec_apply_cost if condition else geom.apply_cost

  # Computes tmp = cost_matrix_x * transport.
  if is_online(geom_x) or is_sqeuclidean(geom_x):
    # Apply geom_x's cost to the dense transport matrix through the geometry
    # (presumably avoids materializing geom_x.cost_matrix -- confirm).
    transport = geom.transport_from_potentials(f, g)
    tmp = apply_cost_fn(geom_x)(transport, axis=1, fn=loss.left_x)
  else:
    # Otherwise apply the transport (defined by f, g) to left_x(cost_x).
    tmp = geom.apply_transport_from_potentials(
        f, g, loss.left_x(geom_x.cost_matrix), axis=0)

  # Computes cost_matrix.
  marginal_x = geom.marginal_from_potentials(f, g, axis=1)
  marginal_y = geom.marginal_from_potentials(f, g, axis=0)
  marginal_dep_term = _marginal_dependent_cost(marginal_x, marginal_y,
                                               geom_x, geom_y, loss)
  # Note the minus sign: the cross term is subtracted from the
  # marginal-dependent term.
  cost_matrix = marginal_dep_term - apply_cost_fn(geom_y)(
      tmp.T, axis=1, fn=loss.right_y).T
  return geometry.Geometry(
      cost_matrix=cost_matrix, epsilon=geom._epsilon, **kwargs)