Esempio n. 1
0
def _discrete_barycenter(geom: geometry.Geometry,
                         a: jnp.ndarray,
                         weights: jnp.ndarray,
                         dual_initialization: jnp.ndarray,
                         threshold: float,
                         norm_error: Sequence[int],
                         inner_iterations: int,
                         min_iterations: int,
                         max_iterations: int,
                         lse_mode: bool,
                         debiased: bool,
                         num_a: int,
                         num_b: int) -> SinkhornBarycenterOutput:
  """Jit'able function to compute discrete barycenters.

  The iterations are driven by ``fixed_point_loop.fixpoint_iter_backprop`` and
  operate either on potentials (``lse_mode=True``) or on scalings. The updates
  of ``f_u`` / ``g_v`` are mapped with ``jax.vmap`` over the leading axis of
  ``a``, ``f_u`` and ``g_v``, all slices sharing the same ``geom``.

  Args:
    geom: geometry shared by all slices; provides the kernel applications, the
      potential / scaling updates, the error function and epsilon.
    a: marginals, stacked along their first axis (one slice per problem).
    weights: weights used to average the slices into the barycenter estimate,
      one entry per slice of ``a``.
    dual_initialization: initial value for ``g_v``, given as a potential. It is
      converted with ``geom.scaling_from_potential`` when not in lse mode.
    threshold: the loop keeps iterating while the first column of the most
      recently recorded error is larger than this value.
    norm_error: norms forwarded to ``geom.error``; ``errors`` holds one column
      per entry.
    inner_iterations: an error row is written at index
      ``iteration // inner_iterations``.
    min_iterations: forwarded to the fixed point loop; errors are recorded as
      ``inf`` for iterations below this value.
    max_iterations: forwarded to the fixed point loop; also determines the
      number of rows of ``errors``.
    lse_mode: if True, run in log-sum-exp mode (potentials), otherwise in
      scaling mode (kernel multiplications).
    debiased: if True, apply the debiasing term ``d`` described in
      https://arxiv.org/abs/2006.02575.
    num_a: size of the zero potential used in the lse-mode debiasing update.
    num_b: size of the debiasing vector ``d``.

  Returns:
    A ``SinkhornBarycenterOutput`` built from the final ``f_u``, ``g_v``, the
    barycenter estimate ``b`` and the ``errors`` array. Rows of ``errors`` that
    were never written keep their initial value of -1.
  """
  if lse_mode:
    f_u = jnp.zeros_like(a)
    g_v = dual_initialization
  else:
    f_u = jnp.ones_like(a)
    g_v = geom.scaling_from_potential(dual_initialization)
  # d below is as described in https://arxiv.org/abs/2006.02575. Note that
  # d should be considered to be equal to eps log(d) with those notations
  # if running in log-sum-exp mode.
  d = jnp.zeros((num_b,)) if lse_mode else jnp.ones((num_b,))

  # The lambdas below keep a fourth (unused) positional argument so that both
  # modes share the same call signature; it is not mapped over (in_axes None).
  if lse_mode:
    parallel_update = jax.vmap(
        lambda f, g, marginal, unused_iteration: geom.update_potential(
            f, g, jnp.log(marginal), axis=1),
        in_axes=[0, 0, 0, None])
    parallel_apply = jax.vmap(
        lambda f_, g_, eps_: geom.apply_lse_kernel(
            f_, g_, eps_, vec=None, axis=0)[0],
        in_axes=[0, 0, None])
  else:
    parallel_update = jax.vmap(
        lambda f, g, marginal, unused_iteration: geom.update_scaling(
            g, marginal, axis=1),
        in_axes=[0, 0, 0, None])
    parallel_apply = jax.vmap(
        lambda f_, g_, eps_: geom.apply_kernel(f_, eps_, axis=0),
        in_axes=[0, 0, None])

  errors_fn = jax.vmap(
      functools.partial(geom.error, axis=1, norm_error=norm_error,
                        lse_mode=lse_mode),
      in_axes=[0, 0, 0])

  # One row per error evaluation, one column per requested norm; -1 marks rows
  # that have not been written.
  errors = - jnp.ones(
      (max_iterations // inner_iterations + 1, len(norm_error)))

  const = (geom, a, weights)
  def cond_fn(iteration, const, state):  # pylint: disable=unused-argument
    errors = state[0]
    # Always run the first iteration; afterwards continue while the previously
    # recorded error (first norm only) is above the threshold.
    return jnp.logical_or(
        iteration == 0,
        errors[iteration // inner_iterations - 1, 0] > threshold)

  def body_fn(iteration, const, state, compute_error):
    geom, a, weights = const
    errors, d, f_u, g_v = state

    eps = geom._epsilon.at(iteration)  # pylint: disable=protected-access
    f_u = parallel_update(f_u, g_v, a, iteration)
    # kernel_f_u stands for K times potential u if running in scaling mode,
    # eps log K exp f / eps in lse mode.
    kernel_f_u = parallel_apply(f_u, g_v, eps)
    # b below is the running estimate for the barycenter if running in scaling
    # mode, eps log b if running in lse mode.
    if lse_mode:
      b = jnp.average(kernel_f_u, weights=weights, axis=0)
    else:
      b = jnp.prod(kernel_f_u ** weights[:, jnp.newaxis], axis=0)

    if debiased:
      if lse_mode:
        b += d
        d = 0.5 * (
            d +
            geom.update_potential(jnp.zeros((num_a,)), d,
                                  b / eps,
                                  iteration=iteration, axis=0))
      else:
        b *= d
        d = jnp.sqrt(
            d *
            geom.update_scaling(d, b,
                                iteration=iteration, axis=0))
    if lse_mode:
      g_v = b[jnp.newaxis, :] - kernel_f_u
    else:
      g_v = b[jnp.newaxis, :] / kernel_f_u

    # re-compute error if compute_error is True, else set to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        jnp.mean(errors_fn(f_u, g_v, a)),
        jnp.inf)

    # Functional in-place update; replaces the removed jax.ops.index_update.
    errors = errors.at[iteration // inner_iterations, :].set(err)
    return errors, d, f_u, g_v

  state = (errors, d, f_u, g_v)

  state = fixed_point_loop.fixpoint_iter_backprop(cond_fn, body_fn,
                                                  min_iterations,
                                                  max_iterations,
                                                  inner_iterations, const,
                                                  state)

  errors, d, f_u, g_v = state
  # Recompute the barycenter estimate from the final state, using geom.epsilon
  # rather than the per-iteration value used inside the loop.
  kernel_f_u = parallel_apply(f_u, g_v, geom.epsilon)
  if lse_mode:
    b = jnp.average(kernel_f_u, weights=weights, axis=0)
  else:
    b = jnp.prod(kernel_f_u ** weights[:, jnp.newaxis], axis=0)

  if debiased:
    if lse_mode:
      b += d
    else:
      b *= d
  if lse_mode:
    # In lse mode b holds eps log b; map it back to the barycenter itself.
    b = jnp.exp(b / geom.epsilon)
  return SinkhornBarycenterOutput(f_u, g_v, b, errors)
Esempio n. 2
0
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray,
                 lse_mode: bool) -> jnp.ndarray:
  r"""Evaluates the regularized OT objective at dual potentials ``f``, ``g``.

  The value is the sum of one term per marginal (``a`` with ``f`` and
  ``tau_a``, ``b`` with ``g`` and ``tau_b``) and of ``geom.epsilon`` times the
  difference between ``sum(a) * sum(b)`` and the total mass of the coupling
  defined by ``f`` and ``g``.

  Entries where ``a`` or ``b`` are zero are excluded from the marginal terms
  with ``jnp.where``: their potentials may be minus infinity, and multiplying
  those by zero would otherwise produce NaN.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) for the marginal ``a`` in the unbalanced
      formulation. The value 1.0 selects the linear term ``a * (f - ...)``.
    tau_b: float, ratio lam/(lam+eps) for the marginal ``b`` in the unbalanced
      formulation. The value 1.0 selects the linear term ``b * (g - ...)``.
    f: jnp.ndarray, potential paired with ``a``.
    g: jnp.ndarray, potential paired with ``b``.
    lse_mode: bool, whether the total mass is computed from potentials (True)
      or from scalings (False).

  Returns:
    a float, the regularized transport cost.
  """

  def marginal_term(weights, potential, tau):
    """Contribution of one marginal, restricted to its support."""
    on_support = weights > 0
    shifted = potential - geom.potential_from_scaling(weights)
    if tau == 1.0:
      return jnp.sum(jnp.where(on_support, weights * shifted, 0.0))
    rho = geom.epsilon * (tau / (1 - tau))
    return - jnp.sum(
        jnp.where(on_support, weights * phi_star(-shifted, rho), 0.0))

  div_a = marginal_term(a, f, tau_a)
  div_b = marginal_term(b, g, tau_b)

  # Using https://arxiv.org/pdf/1910.12958.pdf (24)
  if lse_mode:
    coupling_marginal = geom.marginal_from_potentials(f, g)
  else:
    scaling_f = geom.scaling_from_potential(f)
    scaling_g = geom.scaling_from_potential(g)
    coupling_marginal = geom.marginal_from_scalings(scaling_f, scaling_g)
  total_sum = jnp.sum(coupling_marginal)

  mass_gap = jnp.sum(a) * jnp.sum(b) - total_sum
  return div_a + div_b + geom.epsilon * mass_gap
Esempio n. 3
0
def ent_reg_cost(geom: geometry.Geometry, a: jnp.ndarray, b: jnp.ndarray,
                 tau_a: float, tau_b: float, f: jnp.ndarray,
                 g: jnp.ndarray) -> jnp.ndarray:
    """Evaluates the regularized OT objective at dual potentials f, g.

    Each marginal contributes a sum that is masked with jnp.where on the
    entries where the marginal is positive. Where a or b is zero the matching
    potential is -inf, which would yield -inf - -inf or -inf x 0, i.e. NaN;
    those entries contribute 0 instead.

    Args:
      geom: a Geometry object.
      a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
      b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
      tau_a: float, ratio lam/(lam+eps) for the marginal a in the unbalanced
        formulation; 1.0 selects the linear term (f - ...) * a.
      tau_b: float, ratio lam/(lam+eps) for the marginal b in the unbalanced
        formulation; 1.0 selects the linear term (g - ...) * b.
      f: jnp.ndarray, potential paired with a.
      g: jnp.ndarray, potential paired with b.

    Returns:
      a float, the regularized transport cost.
    """

    def marginal_term(weights, potential, tau):
        """Masked sum contributed by one marginal and its potential."""
        centered = potential - geom.potential_from_scaling(weights)
        if tau == 1.0:
            summand = centered * weights
        else:
            rho = geom.epsilon * (tau / (1 - tau))
            summand = weights * (
                rho - (rho + geom.epsilon / 2) * jnp.exp(-centered / rho))
        return jnp.sum(jnp.where(weights > 0, summand, 0.0))

    div_a = marginal_term(a, f, tau_a)
    div_b = marginal_term(b, g, tau_b)

    # Using https://arxiv.org/pdf/1910.12958.pdf (30), corrected with (15)
    # The coupling's total mass is evaluated from scalings rather than from
    # jnp.exp(jnp.logsumexp(...)): automatic differentiation of the latter is
    # problematic when some of the logs are -inf. Underflow is harmless here,
    # since it only produces near-0 contributions to the mass and, unlike in
    # Sinkhorn iterations, these values are not later used in a numerator.
    scaling_f = geom.scaling_from_potential(f)
    scaling_g = geom.scaling_from_potential(g)
    total_sum = jnp.sum(geom.marginal_from_scalings(scaling_f, scaling_g))

    mass_gap = jnp.sum(a) * jnp.sum(b) - total_sum
    return div_a + div_b + geom.epsilon * mass_gap