Exemple #1
0
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray,
                 lse_mode: bool) -> jnp.ndarray:
  r"""Evaluates the regularized OT objective at dual potentials ``f``, ``g``.

  The value is the sum of three pieces: one term per marginal (``a`` with
  ``f``, ``b`` with ``g``) and an ``epsilon``-weighted term comparing the
  product of the total masses of ``a`` and ``b`` with the total mass of the
  coupling induced by ``f`` and ``g``.

  Zero entries in ``a`` or ``b`` are reflected in minus infinity entries in
  the matching dual potential. Multiplying such entries by a zero weight would
  yield NaN, so every per-marginal term is masked with ``jnp.where`` and only
  coordinates with strictly positive weight contribute.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between the KL divergence regularizer on
      the first marginal and itself + epsilon regularizer, as used in the
      unbalanced formulation. The value ``1.0`` selects the linear (balanced)
      expression for the first marginal term.
    tau_b: float, ratio lam/(lam+eps) between the KL divergence regularizer on
      the second marginal and itself + epsilon regularizer, as used in the
      unbalanced formulation. The value ``1.0`` selects the linear (balanced)
      expression for the second marginal term.
    f: jnp.ndarray, potential paired with ``a``.
    g: jnp.ndarray, potential paired with ``b``.
    lse_mode: bool, whether the total mass of the coupling is computed from
      the potentials directly (``True``) or from the scalings obtained from
      them (``False``).

  Returns:
    The regularized transport cost.
  """

  def marginal_term(weights, potential, tau):
    """Contribution of one marginal, restricted to positive weights."""
    on_support = weights > 0
    if tau == 1.0:
      # Balanced marginal: plain weighted sum of the shifted potential.
      linear = weights * (potential - geom.potential_from_scaling(weights))
      return jnp.sum(jnp.where(on_support, linear, 0.0))
    # Relaxed marginal: strength rho derived from tau and epsilon.
    rho = geom.epsilon * (tau / (1 - tau))
    relaxed = weights * phi_star(
        -(potential - geom.potential_from_scaling(weights)), rho)
    return - jnp.sum(jnp.where(on_support, relaxed, 0.0))

  div_a = marginal_term(a, f, tau_a)
  div_b = marginal_term(b, g, tau_b)

  # Using https://arxiv.org/pdf/1910.12958.pdf (24)
  if lse_mode:
    coupling_marginal = geom.marginal_from_potentials(f, g)
  else:
    coupling_marginal = geom.marginal_from_scalings(
        geom.scaling_from_potential(f), geom.scaling_from_potential(g))
  total_sum = jnp.sum(coupling_marginal)

  mass_gap = jnp.sum(a) * jnp.sum(b) - total_sum
  return div_a + div_b + geom.epsilon * mass_gap
Exemple #2
0
def _update_geometry_gw(geom: geometry.Geometry, geom_x: geometry.Geometry,
                        geom_y: geometry.Geometry, f: jnp.ndarray,
                        g: jnp.ndarray, loss: GWLoss,
                        **kwargs) -> geometry.Geometry:
    """Builds the next Gromov-Wasserstein geometry from the current potentials.

    The cost matrix equation follows Equation 6, Proposition 1 of
    http://proceedings.mlr.press/v48/peyre16.pdf.

    Let :math:`p` [num_a,] be the marginal of the transport matrix for samples
    from geom_x and :math:`q` [num_b,] be the marginal of the transport matrix
    for samples from geom_y. Let :math:`T` [num_a, num_b] be the transport
    matrix. The cost matrix assembled here is:

    cost_matrix = marginal_dep_term
                - left_x(cost_x) :math:`T` right_y(cost_y):math:`^T`

    The first product (cost of ``geom_x`` against the transport) is obtained
    either by materializing the transport and letting ``geom_x`` apply its
    cost to it, or by applying the transport of ``geom`` to the transformed
    cost matrix of ``geom_x``; the choice depends on the type of ``geom_x``.

    Args:
      geom: a Geometry object carrying the cost matrix of Gromov Wasserstein.
      geom_x: a Geometry object for the first view.
      geom_y: a second Geometry object for the second view.
      f: jnp.ndarray<float>[num_a,], potentials.
      g: jnp.ndarray<float>[num_b,], potentials.
      loss: a GWLoss object; its ``left_x`` and ``right_y`` attributes are
        used here.
      **kwargs: additional keyword arguments forwarded to the constructor of
        the returned Geometry (documented upstream as kwargs for epsilon).

    Returns:
      A Geometry object for Gromov-Wasserstein, built from the new cost matrix
      and the ``_epsilon`` attribute of ``geom``.
    """
    def _is_point_cloud(view):
        return isinstance(view, pointcloud.PointCloud)

    def _uses_online_mode(view):
        # Only point clouds are queried for the ``_online`` attribute.
        return view._online if _is_point_cloud(view) else False

    def _has_sqeuclidean_cost(view):
        # Point cloud with power 2.0 and a Euclidean cost function.
        if not _is_point_cloud(view):
            return False
        if not view.power == 2.0:
            return False
        return isinstance(view._cost_fn, costs.Euclidean)

    def _select_cost_applier(view):
        # The vectorized variant is used only when both the geometry and the
        # loss are of the squared-Euclidean kind.
        if _has_sqeuclidean_cost(view) and isinstance(loss, GWSqEuclLoss):
            return view.vec_apply_cost
        return view.apply_cost

    # Computes tmp = cost_matrix_x * transport
    if not (_uses_online_mode(geom_x) or _has_sqeuclidean_cost(geom_x)):
        tmp = geom.apply_transport_from_potentials(
            f, g, loss.left_x(geom_x.cost_matrix), axis=0)
    else:
        transport = geom.transport_from_potentials(f, g)
        apply_x = _select_cost_applier(geom_x)
        tmp = apply_x(transport, axis=1, fn=loss.left_x)

    # Computes cost_matrix
    row_marginal = geom.marginal_from_potentials(f, g, axis=1)
    col_marginal = geom.marginal_from_potentials(f, g, axis=0)
    marginal_dep_term = _marginal_dependent_cost(
        row_marginal, col_marginal, geom_x, geom_y, loss)
    apply_y = _select_cost_applier(geom_y)
    cross_term = apply_y(tmp.T, axis=1, fn=loss.right_y).T
    cost_matrix = marginal_dep_term - cross_term

    return geometry.Geometry(
        cost_matrix=cost_matrix, epsilon=geom._epsilon, **kwargs)