def linear_mmd_test(X, Y, null_samples=1000):
    '''
    Run a quadratic-time MMD two-sample test using a linear kernel.

    Parameters
    ----------
    X : row-instance feature array
    Y : row-instance feature array

    null_samples : int
        How many times to sample from the null distribution.

    Returns
    -------
    p_val : float
        Fraction of the null samples that are at least as large as the
        test statistic.

    stat : float
        The test statistic.

    null_samples : array of length null_samples
        The samples from the null distribution.
    '''
    test = sg.QuadraticTimeMMD()

    # The inputs are transposed and cast to float64 before being wrapped
    # as Shogun features.
    feats_p = sg.RealFeatures(X.T.astype(np.float64))
    test.set_p(feats_p)
    feats_q = sg.RealFeatures(Y.T.astype(np.float64))
    test.set_q(feats_q)

    test.set_kernel(sg.LinearKernel())
    test.set_num_null_samples(null_samples)

    null_draws = test.sample_null()
    statistic = test.compute_statistic()

    p_value = np.mean(statistic <= null_draws)
    return p_value, statistic, null_draws
def shogun_mmd(X, Y, kernel_width, null_samples=1000, median_samples=1000,
               cache_size=32):
    '''
    Run an MMD test using a Gaussian kernel.

    Parameters
    ----------
    X : row-instance feature array
    Y : row-instance feature array

    kernel_width : float
        The width handed (after conversion with ``float``) directly to
        ``sg.GaussianKernel``. NOTE: this is Shogun's width parameter, not
        the RBF sigma; `rbf_mmd_test` in this file converts a sigma into
        this width as ``2 * sigma**2``.

    null_samples : int
        How many times to sample from the null distribution.

    median_samples : int
        Unused by this function; kept so the signature stays unchanged.

    cache_size : int
        First argument handed to ``sg.GaussianKernel``.

    Returns
    -------
    p_val : float
        Fraction of the null samples that are at least as large as the
        test statistic.

    stat : float
        The test statistic.

    null_samples : array of length null_samples
        The samples from the null distribution.
    '''
    import modshogun as sg

    mmd = sg.QuadraticTimeMMD()
    # The inputs are transposed and cast to float64 before being wrapped
    # as Shogun features.
    mmd.set_p(sg.RealFeatures(X.T.astype(np.float64)))
    mmd.set_q(sg.RealFeatures(Y.T.astype(np.float64)))
    mmd.set_kernel(sg.GaussianKernel(cache_size, float(kernel_width)))
    mmd.set_num_null_samples(null_samples)

    samps = mmd.sample_null()
    stat = mmd.compute_statistic()

    p_val = np.mean(stat <= samps)
    return p_val, stat, samps
# Minimal example: quadratic-time MMD test with a Gaussian kernel on two
# random samples.
import numpy as np
import modshogun as sg

# Two 100 x 3 standard-normal samples; Y is additionally shifted by 0.5.
X = np.random.randn(100, 3)
Y = np.random.randn(100, 3) + .5

mmd = sg.QuadraticTimeMMD()
# The arrays are transposed before being wrapped as Shogun features.
mmd.set_p(sg.RealFeatures(X.T))
mmd.set_q(sg.RealFeatures(Y.T))
# The functions in this file pass (cache_size, kernel width) here.
mmd.set_kernel(sg.GaussianKernel(32, 1))
mmd.set_num_null_samples(200)

# Draw the null samples first, then compute the statistic on the data.
samps = mmd.sample_null()
stat = mmd.compute_statistic()
def _median_sq_distance(X, Y, median_samples):
    '''
    Median squared Euclidean distance over all distinct pairs of a random
    subsample pooled from X and Y (at most ``median_samples // 2`` rows are
    drawn, without replacement, from each).
    '''
    from sklearn.metrics.pairwise import euclidean_distances

    def pick_rows(feats, n):
        # Draw min(n, number of rows) distinct rows at random.
        n_rows = feats.shape[0]
        chosen = np.random.choice(n_rows, min(n_rows, n), replace=False)
        return feats[chosen]

    per_sample = median_samples // 2
    pooled = np.r_[pick_rows(X, per_sample), pick_rows(Y, per_sample)]

    sq_dists = euclidean_distances(pooled, squared=True)
    # Strict upper triangle: every unordered pair once, no self-distances.
    pair_sq_dists = sq_dists[np.triu_indices_from(sq_dists, k=1)]
    return np.median(pair_sq_dists, overwrite_input=True)


def rbf_mmd_test(X, Y, bandwidth='median', null_samples=1000,
                 median_samples=1000, cache_size=32):
    '''
    Run an MMD test using a Gaussian kernel.

    Parameters
    ----------
    X : row-instance feature array
    Y : row-instance feature array

    bandwidth : float or 'median'
        The bandwidth of the RBF kernel (sigma); the kernel width handed to
        Shogun is ``2 * sigma**2``.
        If 'median', the median squared pairwise distance of a subsample of
        the aggregate sample is used as the kernel width, which corresponds
        to sigma = (median distance) / sqrt(2).

    null_samples : int
        How many times to sample from the null distribution.

    median_samples : int
        How many points to use for estimating the bandwidth
        (``median_samples // 2`` from each of X and Y).

    cache_size : int
        First argument handed to ``sg.GaussianKernel``.

    Returns
    -------
    p_val : float
        Fraction of the null samples that are at least as large as the
        test statistic.

    stat : float
        The test statistic.

    null_samples : array of length null_samples
        The samples from the null distribution.

    bandwidth : float
        The used kernel bandwidth
    '''
    if bandwidth == 'median':
        kernel_width = _median_sq_distance(X, Y, median_samples)
        # sigma = median / sqrt(2); works better, sometimes at least
        bandwidth = np.sqrt(kernel_width / 2)
    else:
        kernel_width = 2 * bandwidth**2

    test = sg.QuadraticTimeMMD()
    feats_p = sg.RealFeatures(X.T.astype(np.float64))
    test.set_p(feats_p)
    feats_q = sg.RealFeatures(Y.T.astype(np.float64))
    test.set_q(feats_q)
    test.set_kernel(sg.GaussianKernel(cache_size, kernel_width))
    test.set_num_null_samples(null_samples)

    null_draws = test.sample_null()
    statistic = test.compute_statistic()

    p_value = np.mean(statistic <= null_draws)
    return p_value, statistic, null_draws, bandwidth
def get_estimates(gen, sigmas=None, n_reps=100, n_null_samps=1000,
                  cache_size=64, rep_states=False, name=None,
                  save_samps=False, thresh_levels=(.2, .1, .05, .01)):
    '''
    Repeatedly draw a pair of samples and run the quadratic-time MMD test
    on it for a whole grid of Gaussian-kernel bandwidths at once.

    Parameters
    ----------
    gen : callable
        Returns a pair ``(X, Y)`` of row-instance feature arrays with the
        same number of rows. Called as ``gen(rs=seed)`` if `rep_states` is
        true and as ``gen()`` otherwise.

    sigmas : array-like of float, optional
        Bandwidths to evaluate; each one becomes a Gaussian kernel of width
        ``2 * sigma**2``. Defaults to ``np.logspace(-1.7, 1.7, num=30)``.

    n_reps : int
        How many sample pairs to draw.

    n_null_samps : int
        How many times to sample from the null distribution per repetition.

    cache_size : int
        First argument handed to ``sg.GaussianKernel``.

    rep_states : bool
        If true, an integer drawn with ``np.random.randint(0, 2**32)`` is
        passed to `gen` as ``rs`` and recorded as the repetition id instead
        of the loop counter.

    name : str, optional
        If given, a label with this name is put in front of the progress bar.

    save_samps : bool
        If true, the null samples of every kernel are stored in a 'samps'
        column.

    thresh_levels : sequence of float
        For each level, the ``1 - level`` quantile of the null samples is
        stored in a column named ``'thresh_<level>'``.

    Returns
    -------
    info : pandas.DataFrame
        One row per (sigma, rep), indexed by those two, with columns
        'mmd_est' (the statistic divided by n / 2), 'var_est' (from
        ``compute_variance_h1``), 'p', the threshold columns and, if
        requested, 'samps'.
    '''
    if sigmas is None:
        sigmas = np.logspace(-1.7, 1.7, num=30)
    sigmas = np.asarray(sigmas)

    mmd = sg.QuadraticTimeMMD()
    mmd.set_num_null_samples(n_null_samps)
    mmd_mk = mmd.multikernel()
    for sigma in sigmas:
        mmd_mk.add_kernel(sg.GaussianKernel(cache_size, 2 * sigma**2))

    # Column order: fixed columns, then one per threshold level, then the
    # optional raw null samples.
    thresh_names = ['thresh_{}'.format(level) for level in thresh_levels]
    columns = ['sigma', 'rep', 'mmd_est', 'var_est', 'p'] + thresh_names
    if save_samps:
        columns.append('samps')
    info = OrderedDict((col, []) for col in columns)

    thresh_prob = 1 - np.asarray(thresh_levels)

    bar = pb.ProgressBar()
    if name is not None:
        # NOTE(review): start-then-insert order kept exactly as it was;
        # confirm it is what the progress bar library expects.
        bar.start()
        bar.widgets.insert(0, '{} '.format(name))

    for rep in bar(range(n_reps)):
        if rep_states:
            rep_id = np.random.randint(0, 2**32)
            X, Y = gen(rs=rep_id)
        else:
            rep_id = rep
            X, Y = gen()

        n = X.shape[0]
        assert Y.shape[0] == n, 'X and Y must have the same number of rows'

        # Cast to float64 like the other test functions in this file do.
        mmd.set_p(sg.RealFeatures(X.T.astype(np.float64)))
        mmd.set_q(sg.RealFeatures(Y.T.astype(np.float64)))

        info['sigma'].extend(sigmas)
        info['rep'].extend([rep_id] * len(sigmas))

        stat = mmd_mk.compute_statistic()
        # 2.0 forces true division: under Python 2 `n / 2` would truncate
        # for odd n and bias the estimate.
        info['mmd_est'].extend(stat / (n / 2.0))

        samps = mmd_mk.sample_null()
        info['p'].extend(np.mean(samps >= stat, axis=0))
        if save_samps:
            info['samps'].extend(samps.T)

        info['var_est'].extend(mmd_mk.compute_variance_h1())

        # One row of quantiles per threshold level, one entry per kernel.
        threshes = np.asarray(mquantiles(samps, prob=thresh_prob, axis=0))
        for col, per_kernel in zip(thresh_names, threshes):
            info[col].extend(per_kernel)

    info = pd.DataFrame(info)
    info.set_index(['sigma', 'rep'], inplace=True)
    return info