def compute(kernel_name, x, y):
    """Evaluate the kernel selected by ``kernel_name`` on ``x`` and ``y``.

    Supported names and the ``kernels`` function each one maps to:

    * ``'hik'``    -> ``kernels.histogram_intersection``
    * ``'linear'`` -> ``kernels.linear``
    * ``'chi2'``   -> ``kernels.chi_square``

    Returns:
        Whatever the selected kernel function returns for ``(x, y)``.

    Raises:
        ValueError: if ``kernel_name`` is not a supported name; the
            offending name is passed as the exception's sole argument.
    """
    # Ordered (name, attribute-on-kernels) pairs. Compared with ``==`` one by
    # one so matching behaves exactly like an if/elif chain (no hashing of
    # kernel_name, same comparison order).
    lookup = (
        ('hik', 'histogram_intersection'),
        ('linear', 'linear'),
        ('chi2', 'chi_square'),
    )
    for candidate, attr_name in lookup:
        if kernel_name == candidate:
            selected = getattr(kernels, attr_name)
            return selected(x, y)
    raise ValueError(kernel_name)
def dist_spm(a, b):
    """Return a weighted multi-level histogram-intersection score of ``a`` and ``b``.

    At each level both arrays are flattened to a single row and compared with
    ``kernels.histogram_intersection``. Both are then passed through
    ``resize_mask`` with a target size of half their first dimension, and the
    weight for the next level is divided by 4 (levels weigh 1, 1/4, 1/16, ...).
    Accumulation stops once ``a`` has at most one element; that final level is
    not scored.

    Args:
        a: array exposing ``.size``, ``.shape`` and ``.reshape`` (presumably a
            NumPy array — TODO confirm against callers).
        b: array of the same kind, presumably with the same shape as ``a``
            (not checked here).

    Returns:
        The weighted sum of per-level intersections, or ``0.0`` when ``a``
        has at most one element.
    """
    def intersect(x, y):
        # Compare the two levels as single-row vectors and keep the first
        # element of the kernel's result (presumably a 1x1 matrix).
        return kernels.histogram_intersection(
            x.reshape(1, -1), y.reshape(1, -1)).flat[0]

    def halve(x):
        # Floor division keeps the target size an int on Python 3, matching
        # what `/` produced for ints on Python 2.
        return resize_mask(x, x.shape[0] // 2)

    total = 0.0
    weight = 1.0
    # `> 1` instead of `!= 1`: an empty array (size 0), or a resize that
    # skips from >1 straight to 0, would otherwise loop forever.
    while a.size > 1:
        total += intersect(a, b) * weight
        a, b = halve(a), halve(b)
        weight /= 4
    return total