Example #1
def _init_geometry_gw(geom_x: geometry.Geometry, geom_y: geometry.Geometry,
                      a: jnp.ndarray, b: jnp.ndarray,
                      epsilon: Union[epsilon_scheduler.Epsilon, float],
                      loss: GWLoss, **kwargs) -> geometry.Geometry:
    """Initialises the cost matrix for the geometry object for GW.

  The equation follows Equation 6, Proposition 1 of
  http://proceedings.mlr.press/v48/peyre16.pdf.

  Args:
    geom_x: a Geometry object for the first view.
    geom_y: a second Geometry object for the second view.
    a: jnp.ndarray<float>[num_a,], weights.
    b: jnp.ndarray<float>[num_b,], weights.
    epsilon: a regularization parameter or a epsilon_scheduler.Epsilon object.
    loss: a GWLossFn object.
    **kwargs: additional kwargs to epsilon.

  Returns:
    A Geometry object for Gromov-Wasserstein.
  """
    # Initialization of the transport matrix in the balanced case, following
    # http://proceedings.mlr.press/v48/peyre16.pdf
    ab = a[:, None] * b[None, :]
    marginal_x = ab.sum(1)
    marginal_y = ab.sum(0)
    marginal_dep_term = _marginal_dependent_cost(marginal_x, marginal_y,
                                                 geom_x, geom_y, loss)

    tmp = geom_x.apply_cost(ab, axis=1, fn=loss.left_x)
    cost_matrix = marginal_dep_term - geom_y.apply_cost(
        tmp.T, axis=1, fn=loss.right_y).T
    return geometry.Geometry(cost_matrix=cost_matrix,
                             epsilon=epsilon,
                             **kwargs)
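
For reference, this initialization instantiates Proposition 1 (Equation 6) of the paper cited above with the product coupling :math:`T = ab^\top`: for a loss :math:`L(a, b) = f_1(a) + f_2(b) - h_1(a)h_2(b)`,

\mathcal{L}(\bar{C}^x, \bar{C}^y) \otimes T = f_1(\bar{C}^x)\, p\, \mathbf{1}_m^\top + \mathbf{1}_n\, q^\top f_2(\bar{C}^y)^\top - h_1(\bar{C}^x)\, T\, h_2(\bar{C}^y)^\top, \qquad p = T \mathbf{1}_m, \; q = T^\top \mathbf{1}_n.

It is assumed here that ``_marginal_dependent_cost`` (not shown) computes the first two terms and that ``loss.left_x`` and ``loss.right_y`` implement :math:`h_1` and :math:`h_2`; for the squared-Euclidean loss, :math:`f_1(a) = a^2`, :math:`f_2(b) = b^2`, :math:`h_1(a) = a`, :math:`h_2(b) = 2b`.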
Example #2
def ent_reg_cost(geom: geometry.Geometry, a: jnp.ndarray, b: jnp.ndarray,
                 tau_a: float, tau_b: float, f: jnp.ndarray,
                 g: jnp.ndarray) -> jnp.ndarray:
    """Computes objective of regularized OT given dual solutions f,g."""
    # In all sums below, jnp.where handle situations in which some coordinates of
    # a and b are zero. For those coordinates, their potential is -inf.
    # This leads to -inf - -inf or -inf x 0 operations which result in NaN.
    # These contributions are discarded when computing the objective.
    if tau_a == 1.0:
        div_a = jnp.sum(
            jnp.where(a > 0, (f - geom.potential_from_scaling(a)) * a, 0))
    else:
        rho_a = geom.epsilon * (tau_a / (1 - tau_a))
        div_a = jnp.sum(
            jnp.where(
                a > 0,
                a * (rho_a - (rho_a + geom.epsilon / 2) *
                     jnp.exp(-(f - geom.potential_from_scaling(a)) / rho_a)),
                0))

    if tau_b == 1.0:
        div_b = jnp.sum(
            jnp.where(b > 0, (g - geom.potential_from_scaling(b)) * b, 0))
    else:
        rho_b = geom.epsilon * (tau_b / (1 - tau_b))
        div_b = jnp.sum(
            jnp.where(
                b > 0,
                b * (rho_b - (rho_b + geom.epsilon / 2) *
                     jnp.exp(-(g - geom.potential_from_scaling(b)) / rho_b)),
                0))

    # Using https://arxiv.org/pdf/1910.12958.pdf Eq. 30
    return div_a + div_b + geom.epsilon * jnp.sum(a) * jnp.sum(b)
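
Written out, the unbalanced branch above computes (assuming, from its use here, that ``geom.potential_from_scaling(a)`` equals :math:`\varepsilon \log a`):

\mathrm{div}_a = \sum_{i : a_i > 0} a_i \Big(\rho_a - \big(\rho_a + \tfrac{\varepsilon}{2}\big)\, e^{-(f_i - \varepsilon \log a_i)/\rho_a}\Big), \qquad \rho_a = \varepsilon \frac{\tau_a}{1 - \tau_a},

and symmetrically for ``div_b``. Note that with :math:`\tau_a = \lambda/(\lambda + \varepsilon)`, the rescaling gives :math:`\rho_a = \lambda`, recovering the original marginal-relaxation strength.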
Example #3
def marginal_error(geom: geometry.Geometry, a: jnp.ndarray, b: jnp.ndarray,
                   tau_a: float, tau_b: float, f_u: jnp.ndarray,
                   g_v: jnp.ndarray, norm_error: int,
                   lse_mode: bool) -> jnp.ndarray:
    """Conputes marginal error, the stopping criterion used to terminate Sinkhorn.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to first
      marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to first
      marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f_u: jnp.ndarray, potential or scaling
    g_v: jnp.ndarray, potential or scaling
    norm_error: int, p-norm used to compute error.
    lse_mode: True if log-sum-exp operations, False if kernel vector producs.

  Returns:
    a positive number quantifying how far from convergence the algorithm stands.

  """
    if tau_b == 1.0:
        err = geom.error(f_u, g_v, b, 0, norm_error, lse_mode)
    elif tau_a == 1.0:
        err = geom.error(f_u, g_v, a, 1, norm_error, lse_mode)
    else:
        # In the unbalanced case, we compute the norm of the gradient.
        # The gradient is equal to the marginal of the current plan minus
        # the gradient of <z, rho_z * (exp(-h/rho_z) - 1)>, where z is either a
        # or b and h is either f or g. Note this is equal to z if rho_z → inf,
        # which is the case when tau_z → 1.0.
        if lse_mode:
            target = grad_of_marginal_fit(a, b, f_u, g_v, tau_a, tau_b, geom)
        else:
            target = grad_of_marginal_fit(a, b,
                                          geom.potential_from_scaling(f_u),
                                          geom.potential_from_scaling(g_v),
                                          tau_a, tau_b, geom)
        err = geom.error(f_u, g_v, target[0], 1, norm_error, lse_mode)
        err += geom.error(f_u, g_v, target[1], 0, norm_error, lse_mode)
    return err
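
A quick check of the limit mentioned in the comment above (this derivation is a sketch, not from the source): the map :math:`h \mapsto \langle z, \rho_z (e^{-h/\rho_z} - 1)\rangle` has gradient

\nabla_h \, \langle z, \rho_z (e^{-h/\rho_z} - 1)\rangle = -\, z \odot e^{-h/\rho_z} \;\longrightarrow\; -\,z \quad \text{as } \rho_z \to \infty,

so when :math:`\tau_z \to 1` the target reduces to the fixed marginal :math:`z`, consistent with the two balanced branches that compare directly against ``a`` and ``b``.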
Example #4
def discrete_barycenter(geom: geometry.Geometry,
                        a: jnp.ndarray,
                        weights: jnp.ndarray = None,
                        dual_initialization: jnp.ndarray = None,
                        threshold: float = 1e-2,
                        norm_error: int = 1,
                        inner_iterations: int = 10,
                        min_iterations: int = 0,
                        max_iterations: int = 2000,
                        lse_mode: bool = True,
                        debiased: bool = False) -> SinkhornBarycenterOutput:
  """Compute discrete barycenter using https://arxiv.org/abs/2006.02575.

  Args:
    geom: a Cost object able to apply kernels with a certain epsilon.
    a: jnp.ndarray<float>[batch, geom.num_a]: batch of histograms.
    weights: jnp.ndarray of weights in the probability simplex.
    dual_initialization: jnp.ndarray, size [batch, num_b], initialization for
     g_v.
    threshold: (float) tolerance to monitor convergence.
    norm_error: int, power used to define p-norm of error for marginal/target.
    inner_iterations: (int32) the Sinkhorn error is not recomputed at each
     iteration but every inner_iterations instead, to avoid computational
     overhead.
    min_iterations: (int32) the minimum number of Sinkhorn iterations carried
     out before the error is computed and monitored.
    max_iterations: (int32) the maximum number of Sinkhorn iterations.
    lse_mode: True for log-sum-exp computations, False for kernel multiply.
    debiased: whether to run the debiased version of the Sinkhorn divergence.

  Returns:
    A ``SinkhornBarycenterOutput``, which contains two arrays of potentials,
    each of size ``batch`` times ``geom.num_a``, summarizing the OT between each
    histogram in the database onto the barycenter, described in ``histogram``,
    as well as a sequence of errors that monitors convergence.
  """
  batch_size, num_a = a.shape
  _, num_b = geom.shape

  if weights is None:
    weights = jnp.ones((batch_size,)) / batch_size
  if not jnp.all(weights > 0) or weights.shape[0] != batch_size:
    raise ValueError(f'weights must have positive values and size {batch_size}')

  if dual_initialization is None:
    # initialization strategy from https://arxiv.org/pdf/1503.02533.pdf, (3.6)
    dual_initialization = geom.apply_cost(a.T, axis=0).T
    dual_initialization -= jnp.average(dual_initialization,
                                       weights=weights,
                                       axis=0)[jnp.newaxis, :]

  if debiased and not geom.is_symmetric:
    raise ValueError('Geometry must be symmetric to use debiased option.')
  norm_error = (norm_error,)
  return _discrete_barycenter(geom, a, weights, dual_initialization, threshold,
                              norm_error, inner_iterations, min_iterations,
                              max_iterations, lse_mode, debiased, num_a, num_b)
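
A minimal usage sketch, assuming only the ``geometry.Geometry`` constructor seen in the other examples and the ``histogram`` field named in the docstring's return description; the 1-D grid and the two Gaussian-like histograms are illustrative assumptions:

import jax.numpy as jnp

n = 64
x = jnp.linspace(0.0, 1.0, n)
cost = (x[:, None] - x[None, :]) ** 2  # squared distances on a 1-D grid
geom = geometry.Geometry(cost_matrix=cost, epsilon=1e-2)

# Batch of two histograms whose barycenter we seek.
a = jnp.stack([jnp.exp(-(x - 0.25) ** 2 / 0.01),
               jnp.exp(-(x - 0.75) ** 2 / 0.01)])
a = a / a.sum(axis=1, keepdims=True)

out = discrete_barycenter(geom, a)  # weights default to uniform
barycenter = out.histogram  # jnp.ndarray<float>[n,]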
Example #5
def ent_reg_cost(geom: geometry.Geometry,
                 a: jnp.ndarray,
                 b: jnp.ndarray,
                 tau_a: float,
                 tau_b: float,
                 f: jnp.ndarray,
                 g: jnp.ndarray,
                 lse_mode: bool) -> jnp.ndarray:
  r"""Computes objective of regularized OT given dual solutions ``f``, ``g``.

  The objective is evaluated for dual solution ``f`` and ``g``, using inputs
  ``geom``, ``a`` and ``b``, in addition to parameters ``tau_a``, ``tau_b``.
  Situations where ``a`` or ``b`` have zero coordinates are reflected in
  minus infinity entries in their corresponding dual potentials. To avoid NaN
  that may result when multiplying 0's by infinity values, ``jnp.where`` is
  used to cancel these contributions.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to first
      marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f: jnp.ndarray, potential.
    g: jnp.ndarray, potential.
    lse_mode: bool, whether to compute total mass in lse or kernel mode.

  Returns:
    a float, the regularized transport cost.
  """
  supp_a = a > 0
  supp_b = b > 0
  if tau_a == 1.0:
    div_a = jnp.sum(
        jnp.where(supp_a, a * (f - geom.potential_from_scaling(a)), 0.0))
  else:
    rho_a = geom.epsilon * (tau_a / (1 - tau_a))
    div_a = - jnp.sum(jnp.where(
        supp_a,
        a * phi_star(-(f - geom.potential_from_scaling(a)), rho_a),
        0.0))

  if tau_b == 1.0:
    div_b = jnp.sum(
        jnp.where(supp_b, b * (g - geom.potential_from_scaling(b)), 0.0))
  else:
    rho_b = geom.epsilon * (tau_b / (1 - tau_b))
    div_b = - jnp.sum(jnp.where(
        supp_b,
        b * phi_star(-(g - geom.potential_from_scaling(b)), rho_b),
        0.0))

  # Using https://arxiv.org/pdf/1910.12958.pdf (24)
  if lse_mode:
    total_sum = jnp.sum(geom.marginal_from_potentials(f, g))
  else:
    total_sum = jnp.sum(geom.marginal_from_scalings(
        geom.scaling_from_potential(f), geom.scaling_from_potential(g)))
  return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
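
The helper ``phi_star`` is not shown in this snippet; its use above is consistent with the Legendre transform of the scaled entropy generator (an assumption about its definition), :math:`\varphi^*_\rho(z) = \rho\,(e^{z/\rho} - 1)`. Under that assumption, and again reading ``geom.potential_from_scaling(a)`` as :math:`\varepsilon \log a`, the unbalanced term expands to

\mathrm{div}_a = -\sum_{i : a_i > 0} a_i\, \varphi^*_{\rho_a}\!\big(-(f_i - \varepsilon \log a_i)\big) = \sum_{i : a_i > 0} a_i\, \rho_a \Big(1 - e^{-(f_i - \varepsilon \log a_i)/\rho_a}\Big).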
Example #6
def _sinkhorn_iterations(
    tau_a: float,
    tau_b: float,
    inner_iterations: int,
    min_iterations: int,
    max_iterations: int,
    momentum_default: float,
    chg_momentum_from: int,
    lse_mode: bool,
    implicit_differentiation: bool,
    linear_solve_kwargs: Mapping[str, Union[Callable, float]],
    parallel_dual_updates: bool,
    init_dual_a: jnp.ndarray,
    init_dual_b: jnp.ndarray,
    threshold: float,
    norm_error: Sequence[int],
    geom: geometry.Geometry,
    a: jnp.ndarray,
    b: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """The jittable Sinkhorn loop, that uses a custom backward or not.

  Args:
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to first
      marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to
      second marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    inner_iterations: (int32) the Sinkhorn error is not recomputed at each
      iteration but every inner_iterations instead.
    min_iterations: (int32) the minimum number of Sinkhorn iterations.
    max_iterations: (int32) the maximum number of Sinkhorn iterations.
    momentum_default: float, a momentum parameter in the open interval (0, 2).
    chg_momentum_from: int, number of iterations after which momentum is
      computed adaptively from previously recorded errors.
    lse_mode: True for log-sum-exp computations, False for kernel
      multiplication.
    implicit_differentiation: if True, do not backprop through the Sinkhorn
      loop, but use the implicit function theorem on the fixed point optimality
      conditions.
    linear_solve_kwargs: parameterization of linear solver when using implicit
      differentiation.
    parallel_dual_updates: updates potentials or scalings in parallel if True,
      sequentially (in Gauss-Seidel fashion) if False.
    init_dual_a: optional initialization for potentials/scalings w.r.t.
      first marginal (``a``) of reg-OT problem.
    init_dual_b: optional initialization for potentials/scalings w.r.t.
      second marginal (``b``) of reg-OT problem.
    threshold: (float) the relative threshold on the Sinkhorn error to stop the
      Sinkhorn iterations.
    norm_error: tuple of int, p-norms of marginal / target errors to track.
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.

  Returns:
    f: potential
    g: potential
    errors: ndarray of errors
  """
  # Initializing solutions
  f_u, g_v = init_dual_a, init_dual_b

  # Delete arguments not used in forward pass.
  del linear_solve_kwargs

  # Defining the Sinkhorn loop, by setting initializations, body/cond.
  errors = -jnp.ones((np.ceil(max_iterations / inner_iterations).astype(int),
                      len(norm_error)))
  const = (geom, a, b, threshold)

  def cond_fn(iteration, const, state):
    threshold = const[-1]
    errors = state[0]
    err = errors[iteration // inner_iterations - 1, 0]

    return jnp.logical_or(iteration == 0,
                          jnp.logical_and(jnp.isfinite(err), err > threshold))

  def get_momentum(errors, idx):
    """momentum formula, https://arxiv.org/pdf/2012.12562v1.pdf, p.7 and (5)."""
    error_ratio = jnp.minimum(errors[idx - 1, -1] / errors[idx - 2, -1], .99)
    power = 1.0 / inner_iterations
    return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power))

  def body_fn(iteration, const, state, compute_error):
    """Carries out sinkhorn iteration.

    Depending on lse_mode, these iterations can be either in:
      - log-space for numerical stability.
      - scaling space, using standard kernel-vector multiply operations.

    Args:
      iteration: iteration number
      const: tuple of constant parameters that do not change throughout the
        loop, here the geometry and the marginals a, b.
      state: potential/scaling variables updated in the loop & error log.
      compute_error: flag to indicate this iteration computes/stores an error

    Returns:
      state variables, i.e. errors and updated f_u, g_v potentials.
    """
    geom, a, b, _ = const
    errors, f_u, g_v = state

    # compute momentum term if needed, using previously seen errors.
    w = jax.lax.stop_gradient(jnp.where(
        iteration >= inner_iterations * chg_momentum_from + min_iterations,
        get_momentum(errors, chg_momentum_from),
        momentum_default))

    # Sinkhorn updates using momentum, in either scaling or potential form.
    if parallel_dual_updates:
      old_g_v = g_v
    if lse_mode:
      new_g_v = tau_b * geom.update_potential(f_u, g_v, jnp.log(b),
                                              iteration, axis=0)
      g_v = (1.0 - w) * jnp.where(jnp.isfinite(g_v), g_v, 0.0) + w * new_g_v
      new_f_u = tau_a * geom.update_potential(
          f_u, old_g_v if parallel_dual_updates else g_v,
          jnp.log(a), iteration, axis=1)
      f_u = (1.0 - w) * jnp.where(jnp.isfinite(f_u), f_u, 0.0) + w * new_f_u
    else:
      new_g_v = geom.update_scaling(f_u, b, iteration, axis=0) ** tau_b
      g_v = jnp.where(g_v > 0, g_v, 1) ** (1.0 - w) * new_g_v ** w
      new_f_u = geom.update_scaling(
          old_g_v if parallel_dual_updates else g_v,
          a, iteration, axis=1) ** tau_a
      f_u = jnp.where(f_u > 0, f_u, 1) ** (1.0 - w) * new_f_u ** w

    # re-computes error if compute_error is True, else set it to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        marginal_error(geom, a, b, tau_a, tau_b, f_u, g_v, norm_error,
                       lse_mode),
        jnp.inf)

    errors = errors.at[iteration // inner_iterations, :].set(err)
    return errors, f_u, g_v

  # Run the Sinkhorn loop: use the standard fixpoint_iter loop if
  # differentiation is implicit, otherwise switch to the backprop-friendly
  # version of that loop when unrolling to differentiate.

  if implicit_differentiation:
    fix_point = fixed_point_loop.fixpoint_iter
  else:
    fix_point = fixed_point_loop.fixpoint_iter_backprop

  errors, f_u, g_v = fix_point(
      cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, const,
      (errors, f_u, g_v))

  f = f_u if lse_mode else geom.potential_from_scaling(f_u)
  g = g_v if lse_mode else geom.potential_from_scaling(g_v)

  return f, g, errors[:, 0]
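
A small numeric check of ``get_momentum`` (the adaptive rule of https://arxiv.org/pdf/2012.12562v1.pdf); the values below are illustrative:

import jax.numpy as jnp

inner_iterations = 1
error_ratio = jnp.minimum(0.81, 0.99)  # ratio of two consecutive recorded errors
w = 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** (1.0 / inner_iterations)))
# w ~ 1.39 > 1: over-relaxation, faster than the default w = 1.0 update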
Example #7
def _update_geometry_gw(geom: geometry.Geometry, geom_x: geometry.Geometry,
                        geom_y: geometry.Geometry, f: jnp.ndarray,
                        g: jnp.ndarray, loss: GWLoss,
                        **kwargs) -> geometry.Geometry:
    """Updates the geometry object for GW by updating the cost matrix.

  The cost matrix equation follows Equation 6, Proposition 1 of
  http://proceedings.mlr.press/v48/peyre16.pdf.

  Let :math:`p` [num_a,] be the marginal of the transport matrix for samples
  from geom_x and :math:`q` [num_b,] be the marginal of the transport matrix for
  samples from geom_y. Let :math:`T` [num_a, num_b] be the transport matrix.
  The cost matrix equation can be written as:

  cost_matrix = marginal_dep_term
              + left_x(cost_x) :math:`T` right_y(cost_y):math:`^T`

  Args:
    geom: a Geometry object carrying the cost matrix of Gromov Wasserstein.
    geom_x: a Geometry object for the first view.
    geom_y: a second Geometry object for the second view.
    f: jnp.ndarray<float>[num_a,], potentials.
    g: jnp.ndarray<float>[num_b,], potentials.
    loss: a GWLossFn object.
    **kwargs: additional kwargs for epsilon.

  Returns:
    A Geometry object for Gromov-Wasserstein.
  """
    def apply_cost_fn(geom):
        condition = is_sqeuclidean(geom) and isinstance(loss, GWSqEuclLoss)
        return geom.vec_apply_cost if condition else geom.apply_cost

    def is_sqeuclidean(geom):
        return (isinstance(geom, pointcloud.PointCloud) and geom.power == 2.0
                and isinstance(geom._cost_fn, costs.Euclidean))

    def is_online(geom):
        return isinstance(geom, pointcloud.PointCloud) and geom._online

    # Computes tmp = cost_matrix_x * transport
    if is_online(geom_x) or is_sqeuclidean(geom_x):
        transport = geom.transport_from_potentials(f, g)
        tmp = apply_cost_fn(geom_x)(transport, axis=1, fn=loss.left_x)
    else:
        tmp = geom.apply_transport_from_potentials(
            f, g, loss.left_x(geom_x.cost_matrix), axis=0)

    # Computes cost_matrix
    marginal_x = geom.marginal_from_potentials(f, g, axis=1)
    marginal_y = geom.marginal_from_potentials(f, g, axis=0)
    marginal_dep_term = _marginal_dependent_cost(marginal_x, marginal_y,
                                                 geom_x, geom_y, loss)
    cost_matrix = marginal_dep_term - apply_cost_fn(geom_y)(
        tmp.T, axis=1, fn=loss.right_y).T
    return geometry.Geometry(cost_matrix=cost_matrix,
                             epsilon=geom._epsilon,
                             **kwargs)
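
The ``is_sqeuclidean`` branch exists because a squared-Euclidean cost factorizes, so applying it to a matrix never requires materializing the full pairwise cost. A minimal sketch of that identity (local names, not library API):

import jax.numpy as jnp

def sq_eucl_apply(x, y, v):
  # Computes C @ v with C_ij = ||x_i - y_j||^2, in O((n + m) d) memory,
  # using C_ij = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>.
  # x: [n, d], y: [m, d], v: [m,].
  return (jnp.sum(x ** 2, axis=1) * jnp.sum(v)
          + jnp.sum(jnp.sum(y ** 2, axis=1) * v)
          - 2.0 * x @ (y.T @ v))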
Example #8
def _discrete_barycenter(geom: geometry.Geometry,
                         a: jnp.ndarray,
                         weights: jnp.ndarray,
                         dual_initialization: jnp.ndarray,
                         threshold: float,
                         norm_error: Sequence[int],
                         inner_iterations: int,
                         min_iterations: int,
                         max_iterations: int,
                         lse_mode: bool,
                         debiased: bool,
                         num_a: int,
                         num_b: int) -> SinkhornBarycenterOutput:
  """Jit'able function to compute discrete barycenters."""
  if lse_mode:
    f_u = jnp.zeros_like(a)
    g_v = dual_initialization
  else:
    f_u = jnp.ones_like(a)
    g_v = geom.scaling_from_potential(dual_initialization)
  # d below is as described in https://arxiv.org/abs/2006.02575. Note that
  # d should be considered to be equal to eps log(d) with those notations
  # if running in log-sum-exp mode.
  d = jnp.zeros((num_b,)) if lse_mode else jnp.ones((num_b,))

  if lse_mode:
    parallel_update = jax.vmap(
        lambda f, g, marginal, iter: geom.update_potential(
            f, g, jnp.log(marginal), axis=1),
        in_axes=[0, 0, 0, None])
    parallel_apply = jax.vmap(
        lambda f_, g_, eps_: geom.apply_lse_kernel(
            f_, g_, eps_, vec=None, axis=0)[0],
        in_axes=[0, 0, None])
  else:
    parallel_update = jax.vmap(
        lambda f, g, marginal, iter: geom.update_scaling(g, marginal, axis=1),
        in_axes=[0, 0, 0, None])
    parallel_apply = jax.vmap(
        lambda f_, g_, eps_: geom.apply_kernel(f_, eps_, axis=0),
        in_axes=[0, 0, None])

  errors_fn = jax.vmap(
      functools.partial(geom.error, axis=1, norm_error=norm_error,
                        lse_mode=lse_mode),
      in_axes=[0, 0, 0])

  errors = - jnp.ones(
      (max_iterations // inner_iterations + 1, len(norm_error)))

  const = (geom, a, weights)
  def cond_fn(iteration, const, state):  # pylint: disable=unused-argument
    errors = state[0]
    return jnp.logical_or(
        iteration == 0,
        errors[iteration // inner_iterations - 1, 0] > threshold)

  def body_fn(iteration, const, state, compute_error):
    geom, a, weights = const
    errors, d, f_u, g_v = state

    eps = geom._epsilon.at(iteration)  # pylint: disable=protected-access
    f_u = parallel_update(f_u, g_v, a, iteration)
    # kernel_f_u stands for K times potential u if running in scaling mode,
    # eps log K exp f / eps in lse mode.
    kernel_f_u = parallel_apply(f_u, g_v, eps)
    # b below is the running estimate for the barycenter if running in scaling
    # mode, eps log b if running in lse mode.
    if lse_mode:
      b = jnp.average(kernel_f_u, weights=weights, axis=0)
    else:
      b = jnp.prod(kernel_f_u ** weights[:, jnp.newaxis], axis=0)

    if debiased:
      if lse_mode:
        b += d
        d = 0.5 * (
            d +
            geom.update_potential(jnp.zeros((num_a,)), d,
                                  b / eps,
                                  iteration=iteration, axis=0))
      else:
        b *= d
        d = jnp.sqrt(
            d *
            geom.update_scaling(d, b,
                                iteration=iteration, axis=0))
    if lse_mode:
      g_v = b[jnp.newaxis, :] - kernel_f_u
    else:
      g_v = b[jnp.newaxis, :] / kernel_f_u

    # re-compute error if compute_error is True, else set to inf.
    err = jnp.where(
        jnp.logical_and(compute_error, iteration >= min_iterations),
        jnp.mean(errors_fn(f_u, g_v, a)),
        jnp.inf)

    errors = errors.at[iteration // inner_iterations, :].set(err)
    return errors, d, f_u, g_v

  state = (errors, d, f_u, g_v)

  state = fixed_point_loop.fixpoint_iter_backprop(cond_fn, body_fn,
                                                  min_iterations,
                                                  max_iterations,
                                                  inner_iterations, const,
                                                  state)

  errors, d, f_u, g_v = state
  kernel_f_u = parallel_apply(f_u, g_v, geom.epsilon)
  if lse_mode:
    b = jnp.average(kernel_f_u, weights=weights, axis=0)
  else:
    b = jnp.prod(kernel_f_u ** weights[:, jnp.newaxis], axis=0)

  if debiased:
    if lse_mode:
      b += d
    else:
      b *= d
  if lse_mode:
    b = jnp.exp(b / geom.epsilon)
  return SinkhornBarycenterOutput(f_u, g_v, b, errors)
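
In scaling mode, the running barycenter estimate assembled in ``body_fn`` is a weighted geometric mean of the kernel applications; in the notation of https://arxiv.org/abs/2006.02575, with :math:`K` the kernel and :math:`u_k` the scaling of the k-th histogram (up to the axis convention of ``apply_kernel``):

b = \prod_k \big(K^\top u_k\big)^{w_k}, \qquad \varepsilon \log b = \sum_k w_k\, \varepsilon \log\big(K^\top u_k\big) \;\; \text{(lse mode)}.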
Example #9
def ent_reg_cost(geom: geometry.Geometry, a: jnp.ndarray, b: jnp.ndarray,
                 tau_a: float, tau_b: float, f: jnp.ndarray,
                 g: jnp.ndarray) -> jnp.ndarray:
    """Computes objective of regularized OT given dual solutions f,g.

  In all sums below, jnp.where handle situations in which some coordinates of
  a and b are zero. For those coordinates, their potential is -inf.
  This leads to -inf - -inf or -inf x 0 operations which result in NaN.
  These contributions are discarded when computing the objective.

  Args:
    geom: a Geometry object.
    a: jnp.ndarray<float>[num_a,] or jnp.ndarray<float>[batch,num_a] weights.
    b: jnp.ndarray<float>[num_b,] or jnp.ndarray<float>[batch,num_b] weights.
    tau_a: float, ratio lam/(lam+eps) between KL divergence regularizer to first
      marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    tau_b: float, ratio lam/(lam+eps) between KL divergence regularizer to first
      marginal and itself + epsilon regularizer used in the unbalanced
      formulation.
    f: jnp.ndarray, potential
    g: jnp.ndarray, potential

  Returns:
    a float, the regularized transport cost.
  """

    if tau_a == 1.0:
        div_a = jnp.sum(
            jnp.where(a > 0, (f - geom.potential_from_scaling(a)) * a, 0.0))
    else:
        rho_a = geom.epsilon * (tau_a / (1 - tau_a))
        div_a = jnp.sum(
            jnp.where(
                a > 0,
                a * (rho_a - (rho_a + geom.epsilon / 2) *
                     jnp.exp(-(f - geom.potential_from_scaling(a)) / rho_a)),
                0.0))

    if tau_b == 1.0:
        div_b = jnp.sum(
            jnp.where(b > 0, (g - geom.potential_from_scaling(b)) * b, 0.0))
    else:
        rho_b = geom.epsilon * (tau_b / (1 - tau_b))
        div_b = jnp.sum(
            jnp.where(
                b > 0,
                b * (rho_b - (rho_b + geom.epsilon / 2) *
                     jnp.exp(-(g - geom.potential_from_scaling(b)) / rho_b)),
                0.0))

    # Using https://arxiv.org/pdf/1910.12958.pdf (30), corrected with (15).
    # The total mass of the coupling is computed in scaling space. This avoids
    # differentiation issues linked with the automatic differentiation of
    # jnp.exp(jnp.logsumexp(...)) when some of those logs appear as -inf.
    # Because we are computing a total mass, underflow is harmless: it simply
    # results in near-zero contributions which, unlike in Sinkhorn iterations,
    # do not subsequently appear in a numerator.
    total_sum = jnp.sum(
        geom.marginal_from_scalings(geom.scaling_from_potential(f),
                                    geom.scaling_from_potential(g)))
    return div_a + div_b + geom.epsilon * (jnp.sum(a) * jnp.sum(b) - total_sum)
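
In scaling space, with :math:`u = e^{f/\varepsilon}` and :math:`v = e^{g/\varepsilon}` (assuming ``scaling_from_potential`` is that exponential map, consistent with its use here) and kernel :math:`K = e^{-C/\varepsilon}`, the total mass computed above is

\text{total\_sum} = \mathbf{1}^\top \mathrm{diag}(u)\, K\, \mathrm{diag}(v)\, \mathbf{1} = \sum_{i,j} u_i\, e^{-C_{ij}/\varepsilon}\, v_j,

which is exactly the mass of the current coupling.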