Example #1
# `generic_sum` comes from KeOps; `double_grad`, `scal` and `swap_axes` are
# helpers defined elsewhere in the source project and are assumed to be in scope.
from pykeops.torch import generic_sum


def kernel_keops(kernel,
                 α,
                 x,
                 β,
                 y,
                 potentials=False,
                 ranges_xx=None,
                 ranges_yy=None,
                 ranges_xy=None):
    """Computes a (block-sparse) kernel dot product, or the associated dual potentials.

    `kernel` is a KeOps formula string; the optional `ranges_*` arguments encode
    block-sparsity patterns for the three convolutions K_xx, K_yy and K_xy.
    """

    D = x.shape[1]  # ambient dimension of the point clouds
    kernel_conv = generic_sum(
        "(" + kernel + " * B)",  # Formula: a_i = sum_j k(x_i, y_j) * b_j
        "A = Vi(1)",  # Output:    a_i
        "X = Vi({})".format(D),  # 1st input: x_i
        "Y = Vj({})".format(D),  # 2nd input: y_j
        "B = Vj(1)")  # 3rd input: b_j

    # Kernel products: a_x = K_xx @ α, b_y = K_yy @ β, b_x = K_xy @ β.
    # `double_grad` doubles the backward gradient, compensating for the
    # detached second occurrence of x (resp. y) in the symmetric terms:
    a_x = kernel_conv(double_grad(x),
                      x.detach(),
                      α.detach().view(-1, 1),
                      ranges=ranges_xx)
    b_y = kernel_conv(double_grad(y),
                      y.detach(),
                      β.detach().view(-1, 1),
                      ranges=ranges_yy)
    b_x = kernel_conv(x, y, β.view(-1, 1), ranges=ranges_xy)

    if potentials:  # Return the dual potentials evaluated on both samples:
        a_y = kernel_conv(y, x, α.view(-1, 1), ranges=swap_axes(ranges_xy))
        return a_x - b_x, b_y - a_y

    else:  # Return the kernel norm. N.B.: we assume that 'kernel' is symmetric:
        return .5 * scal(double_grad(α), a_x) \
             + .5 * scal(double_grad(β), b_y) - scal(α, b_x)
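
A minimal usage sketch (not part of the original snippet): it assumes a Gaussian
kernel written in KeOps formula syntax, uniform weights, and that the helpers
above are in scope. Shapes follow the Vi/Vj declarations: x is (N, D), y is
(M, D), while α and β are the weight vectors of the two samples.

import torch

N, M, D = 100, 150, 3
x = torch.randn(N, D)
y = torch.randn(M, D)
α = torch.full((N,), 1.0 / N)  # uniform weights on the first sample
β = torch.full((M,), 1.0 / M)  # uniform weights on the second sample

# Scalar kernel norm .5*<α, K_xx α> + .5*<β, K_yy β> - <α, K_xy β>:
loss = kernel_keops("Exp(-SqDist(X,Y))", α, x, β, y)

# Pair of potentials (a_x - b_x, b_y - a_y) sampled on x and y:
F_x, G_y = kernel_keops("Exp(-SqDist(X,Y))", α, x, β, y, potentials=True)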
Example #2
import torch
# `from_matrix` and `swap_axes` build and transpose KeOps block-sparse
# "ranges" from a boolean cluster-to-cluster mask:
from pykeops.torch.cluster import from_matrix, swap_axes


def kernel_truncation(C_xy,
                      C_yx,
                      C_xy_,
                      C_yx_,
                      b_x,
                      a_y,
                      ε,
                      truncate=None,
                      cost=None,
                      verbose=False):
    """Prunes out useless parts of the (block-sparse) cost matrices for finer scales.

    This is where our approximation takes place.
    To be mathematically rigorous, we should make several coarse-to-fine passes,
    making sure that we're not forgetting anyone. A good reference here is
    Bernhard Schmitzer's "Stabilized Sparse Scaling Algorithms for Entropy
    Regularized Transport Problems" (2016).
    """
    if truncate is None:
        return C_xy_, C_yx_
    else:
        x, yd, ranges_x, ranges_y, _ = C_xy
        y, xd, _, _, _ = C_yx
        x_, yd_, ranges_x_, ranges_y_, _ = C_xy_
        y_, xd_, _, _, _ = C_yx_

        with torch.no_grad():
            C = cost(x, y)  # coarse, cluster-to-cluster cost matrix
            # Keep the blocks whose dual gap b_i + a_j - C_ij exceeds
            # -truncate * ε, i.e. whose plan entries are non-negligible:
            keep = b_x.view(-1, 1) + a_y.view(1, -1) > C - truncate * ε
            ranges_xy_ = from_matrix(ranges_x, ranges_y, keep)
            if verbose:
                ks, Cs = keep.sum(), C.shape[0] * C.shape[1]
                print(
                    "Keep {}/{} = {:2.1f}% of the coarse cost matrix.".format(
                        ks, Cs, 100 * float(ks) / Cs))

        return (
            (x_, yd_, ranges_x_, ranges_y_, ranges_xy_),
            (y_, xd_, ranges_y_, ranges_x_, swap_axes(ranges_xy_)),
        )
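
A toy illustration of the pruning rule (not part of the original snippet): at
convergence, the entropic transport plan scales like exp((b_i + a_j - C_ij) / ε),
so blocks with b_i + a_j - C_ij < -truncate * ε carry weights below
exp(-truncate) and can safely be dropped. The cluster sizes and dual values
below are made up.

import torch
from pykeops.torch.cluster import from_matrix

ε, truncate = 0.1, 5.0
C = torch.tensor([[0.0, 1.0, 4.0],
                  [1.0, 0.0, 2.0]])      # 2 x-clusters vs. 3 y-clusters
b_x = torch.tensor([0.1, 0.2])           # coarse dual potential on x
a_y = torch.tensor([0.0, 0.3, 0.1])      # coarse dual potential on y

keep = b_x.view(-1, 1) + a_y.view(1, -1) > C - truncate * ε  # [[T, F, F], [F, T, F]]

# Ranges map each cluster to a [start, end) slice of the sorted point clouds:
ranges_x = torch.tensor([[0, 10], [10, 25]], dtype=torch.int32)
ranges_y = torch.tensor([[0, 5], [5, 15], [15, 30]], dtype=torch.int32)
ranges_xy = from_matrix(ranges_x, ranges_y, keep)  # block-sparse pattern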