def kernel_keops(kernel, α, x, β, y, potentials=False,
                 ranges_xx=None, ranges_yy=None, ranges_xy=None):
    """Computes kernel convolutions with a (block-sparse) KeOps reduction.

    Depending on the `potentials` switch, returns either the pair of
    "potentials" (a_x - b_x, b_y - a_y) or the kernel norm
    .5*⟨α, K α⟩ + .5*⟨β, K β⟩ - ⟨α, K β⟩, where K is the kernel matrix.
    """
    D = x.shape[1]
    kernel_conv = generic_sum(
        "(" + kernel + " * B)",  # Formula
        "A = Vi(1)",             # Output:    a_i
        "X = Vi({})".format(D),  # 1st input: x_i
        "Y = Vj({})".format(D),  # 2nd input: y_j
        "B = Vj(1)",             # 3rd input: b_j
    )

    a_x = kernel_conv(double_grad(x), x.detach(), α.detach().view(-1, 1), ranges=ranges_xx)
    b_y = kernel_conv(double_grad(y), y.detach(), β.detach().view(-1, 1), ranges=ranges_yy)
    b_x = kernel_conv(x, y, β.view(-1, 1), ranges=ranges_xy)

    if potentials:
        a_y = kernel_conv(y, x, α.view(-1, 1), ranges=swap_axes(ranges_xy))
        return a_x - b_x, b_y - a_y

    else:  # Return the Kernel norm. N.B.: we assume that 'kernel' is symmetric:
        return .5 * scal(double_grad(α), a_x) \
             + .5 * scal(double_grad(β), b_y) - scal(α, b_x)
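

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustration only, not part of the library API).
# It shows how `kernel_keops` above could be called with a plain Gaussian
# kernel written in the KeOps formula syntax. The helpers `double_grad` and
# `scal`, used inside `kernel_keops`, are assumed to be imported at the top
# of this module; the point clouds, weights and the kernel string below are
# made up for the example.


def _example_gaussian_energy():
    """Tiny sanity check: kernel norm between two random weighted point clouds."""
    import torch

    N, M, D = 100, 150, 3
    x = torch.randn(N, D, requires_grad=True)  # source samples
    y = torch.randn(M, D)                      # target samples
    α = torch.ones(N) / N                      # uniform source weights
    β = torch.ones(M) / M                      # uniform target weights

    # "Exp(-SqDist(X,Y))" is a Gaussian kernel in KeOps syntax; the kernel
    # strings actually used by the library may differ (bandwidth, scaling).
    loss = kernel_keops("Exp(-SqDist(X,Y))", α, x, β, y)
    (g_x,) = torch.autograd.grad(loss, [x])    # gradient w.r.t. the samples
    return loss, g_x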
def kernel_truncation(C_xy, C_yx, C_xy_, C_yx_, b_x, a_y, ε,
                      truncate=None, cost=None, verbose=False):
    """Prunes out useless parts of the (block-sparse) cost matrices for finer scales.

    This is where our approximation takes place.
    To be mathematically rigorous, we should make several coarse-to-fine passes,
    making sure that we're not forgetting anyone. A good reference here is
    Bernhard Schmitzer's work: "Stabilized Sparse Scaling Algorithms for
    Entropy Regularized Transport Problems" (2016).
    """
    if truncate is None:
        return C_xy_, C_yx_
    else:
        x, yd, ranges_x, ranges_y, _ = C_xy
        y, xd, _, _, _ = C_yx
        x_, yd_, ranges_x_, ranges_y_, _ = C_xy_
        y_, xd_, _, _, _ = C_yx_

        with torch.no_grad():
            C = cost(x, y)
            keep = b_x.view(-1, 1) + a_y.view(1, -1) > C - truncate * ε
            ranges_xy_ = from_matrix(ranges_x, ranges_y, keep)

            if verbose:
                ks, Cs = keep.sum(), C.shape[0] * C.shape[1]
                print("Keep {}/{} = {:2.1f}% of the coarse cost matrix.".format(
                    ks, Cs, 100 * float(ks) / Cs))

        return (x_, yd_, ranges_x_, ranges_y_, ranges_xy_), (
            y_, xd_, ranges_y_, ranges_x_, swap_axes(ranges_xy_),
        )
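

# ---------------------------------------------------------------------------
# Hedged illustration (not part of the library): the pruning rule above keeps
# a coarse cluster pair (I, J) whenever b_x[I] + a_y[J] > C(x_I, y_J) - truncate * ε,
# i.e. whenever the dual potentials suggest that the pair may still exchange mass.
# The sketch below replays the same rule with pykeops' `from_matrix` helper on
# made-up cluster ranges, potentials and costs; every value here is fabricated
# for the example, and `_example_truncation_mask` is not a real library routine.


def _example_truncation_mask():
    """Builds a toy block-sparsity pattern, as `kernel_truncation` does at coarse scale."""
    import torch
    from pykeops.torch.cluster import from_matrix

    # Toy centroids for 4 source and 5 target clusters in R^3:
    x_c, y_c = torch.randn(4, 3), torch.randn(5, 3)

    # Fake dual potentials and a halved squared-distance cost on the centroids:
    b_x, a_y = torch.randn(4), torch.randn(5)
    C = ((x_c[:, None, :] - y_c[None, :, :]) ** 2).sum(-1) / 2
    ε, truncate = 0.1, 5.0

    # Same rule as in `kernel_truncation` above:
    keep = b_x.view(-1, 1) + a_y.view(1, -1) > C - truncate * ε

    # Contiguous (start, end) index ranges of the fine points in each cluster,
    # in the format expected by `from_matrix`:
    ranges_x = torch.tensor([[0, 50], [50, 120], [120, 180], [180, 250]], dtype=torch.int32)
    ranges_y = torch.tensor([[0, 40], [40, 110], [110, 160], [160, 230], [230, 300]], dtype=torch.int32)

    # Block-sparse reduction pattern, to be passed as `ranges=...` to a KeOps call:
    ranges_xy = from_matrix(ranges_x, ranges_y, keep)
    return keep, ranges_xy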