def statistics_quadratic_time_mmd(m, dim, difference):
    """Run a quadratic time MMD two-sample test with several null approximations.

    Streams ``m`` examples each from two ``MeanShiftDataGenerator`` instances
    (mean shift 0 and ``difference``, dimension ``dim``), computes p-values
    using bootstrapping, the spectrum method and the gamma approximation, and
    then estimates type I and type II errors over a few trials using a
    precomputed kernel.

    Returns:
        A 5-tuple ``(type_I_errors, type_II_errors, p_value_boot,
        p_value_spectrum, p_value_gamma)``. The first two are the contents
        (``RealVector.get()``) of vectors into which one boolean per trial is
        stored: true/1 if that trial produced an error, false/0 otherwise.
    """
    from modshogun import MeanShiftDataGenerator
    from modshogun import GaussianKernel, CustomKernel
    from modshogun import QuadraticTimeMMD
    from modshogun import BOOTSTRAP, MMD2_SPECTRUM, MMD2_GAMMA, BIASED
    from modshogun import RealVector, Math

    # init seeds (shogun's and numpy's generators) for reproducibility
    Math.init_random(1)
    random.seed(17)

    # streaming data generators for mean shift distributions; the number of
    # examples m should be kept low in order to make things fast
    gen_p = MeanShiftDataGenerator(0, dim)
    gen_q = MeanShiftDataGenerator(difference, dim)

    # stream some data from the generators
    feat_p = gen_p.get_streamed_features(m)
    feat_q = gen_q.get_streamed_features(m)

    # set kernel a-priori. Usually one would do some kernel selection, see
    # other examples for this.
    width = 10
    kernel = GaussianKernel(10, width)

    # create quadratic time mmd instance. Note that this constructor
    # copies p and q and does not reference them
    mmd = QuadraticTimeMMD(kernel, feat_p, feat_q)

    # test level: the null hypothesis p=q is rejected if p-value < alpha
    alpha = 0.05

    # bootstrapping (slow, not the most reliable way; consider precomputing
    # the kernel as done below). In practice, use at least 250 iterations.
    mmd.set_null_approximation_method(BOOTSTRAP)
    mmd.set_bootstrap_iterations(3)
    p_value_boot = mmd.perform_test()

    # spectrum method: consistent, but sometimes breaks, so always monitor
    # the type I error. Use at least 250 samples from the null; see the
    # tutorial for the number of eigenvalues to use. Only works with the
    # BIASED statistic.
    mmd.set_statistic_type(BIASED)
    mmd.set_null_approximation_method(MMD2_SPECTRUM)
    mmd.set_num_eigenvalues_spectrum(3)
    # NOTE(review): "sepctrum" spelling kept as-is; presumably it matches the
    # method name of the shogun version this example targets -- confirm.
    mmd.set_num_samples_sepctrum(250)
    p_value_spectrum = mmd.perform_test()

    # gamma method: a quick hack which works most of the time but is NOT
    # guaranteed to; see the tutorial. Only works with the BIASED statistic.
    mmd.set_statistic_type(BIASED)
    mmd.set_null_approximation_method(MMD2_GAMMA)
    p_value_gamma = mmd.perform_test()

    # estimate type I and II errors (use many more trials in practice).
    # The type I error is not necessary when bootstrapping, but this is an
    # efficient way of computing it. Testing has to happen on different data
    # than kernel selection; here the kernel was fixed a-priori instead.
    mmd.set_null_approximation_method(BOOTSTRAP)
    mmd.set_bootstrap_iterations(5)
    num_trials = 5
    type_I_errors = RealVector(num_trials)
    type_II_errors = RealVector(num_trials)
    # numpy index vector over the joint sample of p and q
    inds = int32(array(list(range(2 * m))))
    p_and_q = mmd.get_p_and_q()

    # use a precomputed kernel to be faster
    kernel.init(p_and_q, p_and_q)
    precomputed = CustomKernel(kernel)
    mmd.set_kernel(precomputed)
    for i in range(num_trials):
        # randomly permuting the joint sample mixes p and q, so p=q holds and
        # rejecting (p-value < alpha) is a type I error
        inds = random.permutation(inds)  # numpy permutation
        precomputed.add_row_subset(inds)
        precomputed.add_col_subset(inds)
        type_I_errors[i] = mmd.perform_test() < alpha
        precomputed.remove_row_subset()
        precomputed.remove_col_subset()

        # on the unpermuted data p!=q, so not rejecting is a type II error
        type_II_errors[i] = mmd.perform_test() > alpha

    return (type_I_errors.get(), type_II_errors.get(), p_value_boot,
            p_value_spectrum, p_value_gamma)
def statistics_quadratic_time_mmd(m, dim, difference):
    """Run a quadratic time MMD two-sample test with several null approximations.

    Streams ``m`` examples each from two ``MeanShiftDataGenerator`` instances
    (mean shift 0 and ``difference``, dimension ``dim``), computes p-values
    using permutation, the spectrum method and the gamma approximation, and
    then estimates type I and type II errors over a few trials using a
    precomputed kernel.

    Returns:
        A 5-tuple ``(type_I_errors, type_II_errors, p_value_null,
        p_value_spectrum, p_value_gamma)``. The first two are the contents
        (``RealVector.get()``) of vectors into which one boolean per trial is
        stored: true/1 if that trial produced an error, false/0 otherwise.
    """
    from modshogun import MeanShiftDataGenerator
    from modshogun import GaussianKernel, CustomKernel
    from modshogun import QuadraticTimeMMD
    from modshogun import PERMUTATION, MMD2_SPECTRUM, MMD2_GAMMA
    from modshogun import BIASED, BIASED_DEPRECATED
    from modshogun import RealVector, Math

    # init seeds (shogun's and numpy's generators) for reproducibility
    Math.init_random(1)
    random.seed(17)

    # streaming data generators for mean shift distributions; the number of
    # examples m should be kept low in order to make things fast
    gen_p = MeanShiftDataGenerator(0, dim)
    gen_q = MeanShiftDataGenerator(difference, dim)

    # stream some data from the generators
    feat_p = gen_p.get_streamed_features(m)
    feat_q = gen_q.get_streamed_features(m)

    # set kernel a-priori. Usually one would do some kernel selection, see
    # other examples for this.
    width = 10
    kernel = GaussianKernel(10, width)

    # create quadratic time mmd instance. Note that this constructor
    # copies p and q and does not reference them
    mmd = QuadraticTimeMMD(kernel, feat_p, feat_q)

    # test level: the null hypothesis p=q is rejected if p-value < alpha
    alpha = 0.05

    # permutation (slow, not the most reliable way; consider precomputing
    # the kernel as done below). In practice, use at least 250 null samples.
    mmd.set_null_approximation_method(PERMUTATION)
    mmd.set_num_null_samples(3)
    p_value_null = mmd.perform_test()

    # spectrum method: consistent, but sometimes breaks, so always monitor
    # the type I error. Use at least 250 samples from the null; see the
    # tutorial for the number of eigenvalues to use.
    mmd.set_statistic_type(BIASED)
    mmd.set_null_approximation_method(MMD2_SPECTRUM)
    mmd.set_num_eigenvalues_spectrum(3)
    mmd.set_num_samples_spectrum(250)
    p_value_spectrum = mmd.perform_test()

    # gamma method: a quick hack which works most of the time but is NOT
    # guaranteed to; see the tutorial. Only works with the
    # BIASED_DEPRECATED statistic.
    mmd.set_statistic_type(BIASED_DEPRECATED)
    mmd.set_null_approximation_method(MMD2_GAMMA)
    p_value_gamma = mmd.perform_test()

    # estimate type I and II errors (use many more trials in practice).
    # The type I error is not necessary when using permutation, but this is
    # an efficient way of computing it. Testing has to happen on different
    # data than kernel selection; here the kernel was fixed a-priori instead.
    mmd.set_statistic_type(BIASED)
    mmd.set_null_approximation_method(PERMUTATION)
    mmd.set_num_null_samples(5)
    num_trials = 5
    type_I_errors = RealVector(num_trials)
    type_II_errors = RealVector(num_trials)
    # numpy index vector over the joint sample of p and q
    inds = int32(array(list(range(2 * m))))
    p_and_q = mmd.get_p_and_q()

    # use a precomputed kernel to be faster
    kernel.init(p_and_q, p_and_q)
    precomputed = CustomKernel(kernel)
    mmd.set_kernel(precomputed)
    for i in range(num_trials):
        # randomly permuting the joint sample mixes p and q, so p=q holds and
        # rejecting (p-value < alpha) is a type I error
        inds = random.permutation(inds)  # numpy permutation
        precomputed.add_row_subset(inds)
        precomputed.add_col_subset(inds)
        type_I_errors[i] = mmd.perform_test() < alpha
        precomputed.remove_row_subset()
        precomputed.remove_col_subset()

        # on the unpermuted data p!=q, so not rejecting is a type II error
        type_II_errors[i] = mmd.perform_test() > alpha

    return (type_I_errors.get(), type_II_errors.get(), p_value_null,
            p_value_spectrum, p_value_gamma)
def quadratic_time_mmd_graphical():
    """Plot null and alternative distributions of the quadratic time MMD.

    Streams data from two mean shift generators, selects one Gaussian kernel
    out of a grid of widths with ``MMDKernelSelectionMax``, samples the MMD
    statistic under the alternative (fresh data per sample), and approximates
    the null distribution in three ways: by bootstrapping, by the spectrum
    method (only if ``QuadraticTimeMMD`` provides ``sample_null_spectrum``),
    and by a fitted gamma distribution. The results are drawn in a 2x3 grid of
    subplots together with the level-``alpha`` test thresholds and the
    estimated error rates.
    """

    # parameters, change to get different results
    m = 100
    dim = 2

    # setting the difference of the first dimension smaller makes a harder test
    difference = 0.5

    # number of samples taken from null and alternative distribution
    num_null_samples = 500

    # test level
    alpha = 0.05

    # streaming data generator for mean shift distributions
    gen_p = MeanShiftDataGenerator(0, dim)
    gen_q = MeanShiftDataGenerator(difference, dim)

    def stream_joint_sample():
        """Return m fresh examples from p merged with m fresh examples from q."""
        joint = gen_p.get_streamed_features(m)
        return joint.create_merged_copy(gen_q.get_streamed_features(m))

    # Stream examples and merge them in order to compute MMD on the joint
    # sample (alternatively, call a different constructor of QuadraticTimeMMD)
    features = stream_joint_sample()

    # candidate kernels for selection: Gaussian kernels with bandwidths
    # sigma = 2^-3 ... 2^9. Shogun's Gaussian kernel width differs from the
    # usual bandwidth parametrisation, hence width = 2 * sigma^2.
    sigmas = [2**x for x in range(-3, 10)]
    widths = [x * x * 2 for x in sigmas]
    print("kernel widths: %s" % (widths,))
    combined = CombinedKernel()
    for width in widths:
        combined.append_kernel(GaussianKernel(10, width))

    # create MMD instance (the first m examples of the joint sample are from
    # p), use biased statistic
    mmd = QuadraticTimeMMD(combined, features, m)
    mmd.set_statistic_type(BIASED)

    # kernel selection instance (this can easily be replaced by the other
    # methods for selecting single kernels)
    selection = MMDKernelSelectionMax(mmd)

    # perform kernel selection
    kernel = GaussianKernel.obtain_from_generic(selection.select_kernel())
    mmd.set_kernel(kernel)
    print("selected kernel width: %s" % kernel.get_width())

    # sample alternative distribution (new data each trial)
    alt_samples = zeros(num_null_samples)
    for i in range(len(alt_samples)):
        mmd.set_p_and_q(stream_joint_sample())
        alt_samples[i] = mmd.compute_statistic()

    # sample from null distribution: bootstrapping, biased statistic
    mmd.set_null_approximation_method(BOOTSTRAP)
    mmd.set_statistic_type(BIASED)
    mmd.set_bootstrap_iterations(num_null_samples)
    null_samples_boot = mmd.bootstrap_null()

    # sample from null distribution: spectrum, biased statistic. The method
    # may be missing from QuadraticTimeMMD (hence the check); in that case
    # all spectrum results are skipped instead of failing with a NameError.
    have_spectrum = "sample_null_spectrum" in dir(QuadraticTimeMMD)
    if have_spectrum:
        mmd.set_null_approximation_method(MMD2_SPECTRUM)
        mmd.set_statistic_type(BIASED)
        null_samples_spectrum = mmd.sample_null_spectrum(num_null_samples,
                                                         m - 10)

    # fit gamma distribution, biased statistic, and sample from it
    mmd.set_null_approximation_method(MMD2_GAMMA)
    mmd.set_statistic_type(BIASED)
    gamma_params = mmd.fit_null_gamma()
    null_samples_gamma = array([
        gamma(gamma_params[0], gamma_params[1])
        for _ in range(num_null_samples)
    ])

    # to plot data, sample a few examples from stream first; the first m
    # columns are from p, the remaining m from q
    data = stream_joint_sample().get_feature_matrix()

    # plot; a plain title() here, before any subplot exists, would not belong
    # to any of the subplots below, so use a figure-level title instead
    fig = figure()
    fig.suptitle('Quadratic Time MMD')

    def reduce_ticks(nbins):
        """Limit the number of x- and y-ticks of the current axes."""
        gca().xaxis.set_major_locator(MaxNLocator(nbins=nbins))
        gca().yaxis.set_major_locator(MaxNLocator(nbins=nbins))

    # plot data of p (x) and q (y)
    subplot(2, 3, 1)
    grid(True)
    reduce_ticks(4)
    plot(data[0][0:m], data[1][0:m], 'ro', label='$x$')
    plot(data[0][m:2 * m], data[1][m:2 * m], 'bo', label='$y$', alpha=0.5)
    title('Data, shift in $x_1$=' + str(difference) + '\nm=' + str(m))
    xlabel('$x_1, y_1$')
    ylabel('$x_2, y_2$')

    # histogram of the first data dimension, separately for p (x_1) and
    # q (y_1), against the N(0, 1) and N(difference, 1) pdfs
    subplot(2, 3, 2)
    grid(True)
    reduce_ticks(3)
    hist(data[0][0:m], bins=50, alpha=0.5, facecolor='r', normed=True)
    hist(data[0][m:2 * m], bins=50, alpha=0.5, facecolor='b', normed=True)
    xs = linspace(min(data[0]) - 1, max(data[0]) + 1, 50)
    plot(xs, normpdf(xs, 0, 1), 'r', linewidth=3)
    plot(xs, normpdf(xs, difference, 1), 'b', linewidth=3)
    xlabel('$x_1, y_1$')
    ylabel('$p(x_1), p(y_1)$')
    title('Data PDF in $x_1, y_1$')

    def upper_quantile(samples):
        """Sort samples in place and return their empirical (1-alpha) quantile."""
        samples.sort()
        # floor() yields a float, which is not a valid index
        return samples[int(floor(len(samples) * (1 - alpha)))]

    def rejection_rate(samples, threshold):
        """Fraction of samples a test with this threshold would reject."""
        return sum(samples > threshold) / float(len(samples))

    # rejection thresholds of the level-alpha tests
    thresh_boot = upper_quantile(null_samples_boot)
    thresh_gamma = upper_quantile(null_samples_gamma)

    # type I error of each test: the fraction of null samples it rejects.
    # The bootstrapped samples serve as the reference null distribution, so
    # the thresholds of the other approximations are assessed against them.
    type_one_error_boot = rejection_rate(null_samples_boot, thresh_boot)
    type_one_error_gamma = rejection_rate(null_samples_boot, thresh_gamma)
    if have_spectrum:
        thresh_spectrum = upper_quantile(null_samples_spectrum)
        type_one_error_spectrum = rejection_rate(null_samples_boot,
                                                 thresh_spectrum)

    # plot alternative distribution with threshold; the type II error is the
    # fraction of alternative samples below the bootstrap threshold
    subplot(2, 3, 4)
    grid(True)
    reduce_ticks(3)
    hist(alt_samples, 20, normed=True)
    axvline(thresh_boot, 0, 1, linewidth=2, color='red')
    type_two_error = sum(alt_samples < thresh_boot) / float(len(alt_samples))
    title('Alternative Dist.\n' + 'Type II error is ' + str(type_two_error))

    # compute a common range for all null distribution histograms
    null_sample_sets = [null_samples_boot, null_samples_gamma]
    if have_spectrum:
        null_sample_sets.append(null_samples_spectrum)
    hist_range = [min([min(s) for s in null_sample_sets]),
                  max([max(s) for s in null_sample_sets])]

    # plot bootstrapped null distribution with threshold
    subplot(2, 3, 3)
    grid(True)
    reduce_ticks(3)
    hist(null_samples_boot, 20, range=hist_range, normed=True)
    axvline(thresh_boot, 0, 1, linewidth=2, color='red')
    title('Bootstrapped Null Dist.\n' + 'Type I error is ' +
          str(type_one_error_boot))

    # plot null distribution spectrum (if available)
    if have_spectrum:
        subplot(2, 3, 5)
        grid(True)
        reduce_ticks(3)
        hist(null_samples_spectrum, 20, range=hist_range, normed=True)
        axvline(thresh_spectrum, 0, 1, linewidth=2, color='red')
        title('Null Dist. Spectrum\nType I error is ' +
              str(type_one_error_spectrum))

    # plot null distribution gamma
    subplot(2, 3, 6)
    grid(True)
    reduce_ticks(3)
    hist(null_samples_gamma, 20, range=hist_range, normed=True)
    axvline(thresh_gamma, 0, 1, linewidth=2, color='red')
    title('Null Dist. Gamma\nType I error is ' + str(type_one_error_gamma))

    # pull plots a bit apart
    subplots_adjust(hspace=0.5, wspace=0.5)
def quadratic_time_mmd_graphical():
    """Plot null and alternative distributions of the quadratic time MMD.

    Streams data from two mean shift generators, selects one Gaussian kernel
    out of a grid of widths with ``MMDKernelSelectionMax``, samples the MMD
    statistic under the alternative (fresh data per sample), and approximates
    the null distribution in three ways: by bootstrapping, by the spectrum
    method (only if ``QuadraticTimeMMD`` provides ``sample_null_spectrum``),
    and by a fitted gamma distribution. The results are drawn in a 2x3 grid of
    subplots together with the level-``alpha`` test thresholds and the
    estimated error rates.
    """

    # parameters, change to get different results
    m = 100
    dim = 2

    # setting the difference of the first dimension smaller makes a harder test
    difference = 0.5

    # number of samples taken from null and alternative distribution
    num_null_samples = 500

    # test level
    alpha = 0.05

    # streaming data generator for mean shift distributions
    gen_p = MeanShiftDataGenerator(0, dim)
    gen_q = MeanShiftDataGenerator(difference, dim)

    def stream_joint_sample():
        """Return m fresh examples from p merged with m fresh examples from q."""
        joint = gen_p.get_streamed_features(m)
        return joint.create_merged_copy(gen_q.get_streamed_features(m))

    # Stream examples and merge them in order to compute MMD on the joint
    # sample (alternatively, call a different constructor of QuadraticTimeMMD)
    features = stream_joint_sample()

    # candidate kernels for selection: Gaussian kernels with bandwidths
    # sigma = 2^-3 ... 2^9. Shogun's Gaussian kernel width differs from the
    # usual bandwidth parametrisation, hence width = 2 * sigma^2.
    sigmas = [2**x for x in range(-3, 10)]
    widths = [x * x * 2 for x in sigmas]
    print("kernel widths: %s" % (widths,))
    combined = CombinedKernel()
    for width in widths:
        combined.append_kernel(GaussianKernel(10, width))

    # create MMD instance (the first m examples of the joint sample are from
    # p), use biased statistic
    mmd = QuadraticTimeMMD(combined, features, m)
    mmd.set_statistic_type(BIASED)

    # kernel selection instance (this can easily be replaced by the other
    # methods for selecting single kernels)
    selection = MMDKernelSelectionMax(mmd)

    # perform kernel selection
    kernel = GaussianKernel.obtain_from_generic(selection.select_kernel())
    mmd.set_kernel(kernel)
    print("selected kernel width: %s" % kernel.get_width())

    # sample alternative distribution (new data each trial)
    alt_samples = zeros(num_null_samples)
    for i in range(len(alt_samples)):
        mmd.set_p_and_q(stream_joint_sample())
        alt_samples[i] = mmd.compute_statistic()

    # sample from null distribution: bootstrapping, biased statistic
    mmd.set_null_approximation_method(BOOTSTRAP)
    mmd.set_statistic_type(BIASED)
    mmd.set_bootstrap_iterations(num_null_samples)
    null_samples_boot = mmd.bootstrap_null()

    # sample from null distribution: spectrum, biased statistic. The method
    # may be missing from QuadraticTimeMMD (hence the check); in that case
    # all spectrum results are skipped instead of failing with a NameError.
    have_spectrum = "sample_null_spectrum" in dir(QuadraticTimeMMD)
    if have_spectrum:
        mmd.set_null_approximation_method(MMD2_SPECTRUM)
        mmd.set_statistic_type(BIASED)
        null_samples_spectrum = mmd.sample_null_spectrum(num_null_samples,
                                                         m - 10)

    # fit gamma distribution, biased statistic, and sample from it
    mmd.set_null_approximation_method(MMD2_GAMMA)
    mmd.set_statistic_type(BIASED)
    gamma_params = mmd.fit_null_gamma()
    null_samples_gamma = array([
        gamma(gamma_params[0], gamma_params[1])
        for _ in range(num_null_samples)
    ])

    # to plot data, sample a few examples from stream first; the first m
    # columns are from p, the remaining m from q
    data = stream_joint_sample().get_feature_matrix()

    # plot; a plain title() here, before any subplot exists, would not belong
    # to any of the subplots below, so use a figure-level title instead
    fig = figure()
    fig.suptitle('Quadratic Time MMD')

    def reduce_ticks(nbins):
        """Limit the number of x- and y-ticks of the current axes."""
        gca().xaxis.set_major_locator(MaxNLocator(nbins=nbins))
        gca().yaxis.set_major_locator(MaxNLocator(nbins=nbins))

    # plot data of p (x) and q (y)
    subplot(2, 3, 1)
    grid(True)
    reduce_ticks(4)
    plot(data[0][0:m], data[1][0:m], 'ro', label='$x$')
    plot(data[0][m:2 * m], data[1][m:2 * m], 'bo', label='$y$', alpha=0.5)
    title('Data, shift in $x_1$=' + str(difference) + '\nm=' + str(m))
    xlabel('$x_1, y_1$')
    ylabel('$x_2, y_2$')

    # histogram of the first data dimension, separately for p (x_1) and
    # q (y_1), against the N(0, 1) and N(difference, 1) pdfs
    subplot(2, 3, 2)
    grid(True)
    reduce_ticks(3)
    hist(data[0][0:m], bins=50, alpha=0.5, facecolor='r', normed=True)
    hist(data[0][m:2 * m], bins=50, alpha=0.5, facecolor='b', normed=True)
    xs = linspace(min(data[0]) - 1, max(data[0]) + 1, 50)
    plot(xs, normpdf(xs, 0, 1), 'r', linewidth=3)
    plot(xs, normpdf(xs, difference, 1), 'b', linewidth=3)
    xlabel('$x_1, y_1$')
    ylabel('$p(x_1), p(y_1)$')
    title('Data PDF in $x_1, y_1$')

    def upper_quantile(samples):
        """Sort samples in place and return their empirical (1-alpha) quantile."""
        samples.sort()
        # floor() yields a float, which is not a valid index
        return samples[int(floor(len(samples) * (1 - alpha)))]

    def rejection_rate(samples, threshold):
        """Fraction of samples a test with this threshold would reject."""
        return sum(samples > threshold) / float(len(samples))

    # rejection thresholds of the level-alpha tests
    thresh_boot = upper_quantile(null_samples_boot)
    thresh_gamma = upper_quantile(null_samples_gamma)

    # type I error of each test: the fraction of null samples it rejects.
    # The bootstrapped samples serve as the reference null distribution, so
    # the thresholds of the other approximations are assessed against them.
    type_one_error_boot = rejection_rate(null_samples_boot, thresh_boot)
    type_one_error_gamma = rejection_rate(null_samples_boot, thresh_gamma)
    if have_spectrum:
        thresh_spectrum = upper_quantile(null_samples_spectrum)
        type_one_error_spectrum = rejection_rate(null_samples_boot,
                                                 thresh_spectrum)

    # plot alternative distribution with threshold; the type II error is the
    # fraction of alternative samples below the bootstrap threshold
    subplot(2, 3, 4)
    grid(True)
    reduce_ticks(3)
    hist(alt_samples, 20, normed=True)
    axvline(thresh_boot, 0, 1, linewidth=2, color='red')
    type_two_error = sum(alt_samples < thresh_boot) / float(len(alt_samples))
    title('Alternative Dist.\n' + 'Type II error is ' + str(type_two_error))

    # compute a common range for all null distribution histograms
    null_sample_sets = [null_samples_boot, null_samples_gamma]
    if have_spectrum:
        null_sample_sets.append(null_samples_spectrum)
    hist_range = [min([min(s) for s in null_sample_sets]),
                  max([max(s) for s in null_sample_sets])]

    # plot bootstrapped null distribution with threshold
    subplot(2, 3, 3)
    grid(True)
    reduce_ticks(3)
    hist(null_samples_boot, 20, range=hist_range, normed=True)
    axvline(thresh_boot, 0, 1, linewidth=2, color='red')
    title('Bootstrapped Null Dist.\n' + 'Type I error is ' +
          str(type_one_error_boot))

    # plot null distribution spectrum (if available)
    if have_spectrum:
        subplot(2, 3, 5)
        grid(True)
        reduce_ticks(3)
        hist(null_samples_spectrum, 20, range=hist_range, normed=True)
        axvline(thresh_spectrum, 0, 1, linewidth=2, color='red')
        title('Null Dist. Spectrum\nType I error is ' +
              str(type_one_error_spectrum))

    # plot null distribution gamma
    subplot(2, 3, 6)
    grid(True)
    reduce_ticks(3)
    hist(null_samples_gamma, 20, range=hist_range, normed=True)
    axvline(thresh_gamma, 0, 1, linewidth=2, color='red')
    title('Null Dist. Gamma\nType I error is ' + str(type_one_error_gamma))

    # pull plots a bit apart
    subplots_adjust(hspace=0.5, wspace=0.5)
def compare_against_mmd_test():
    data = loadmat("../data/02-solar.mat")
    X = data["X"]
    y = data["y"]

    X_train, y_train, X_test, y_test, N, N_test = prepare_dataset(X, y)

    kernel = RBF(input_dim=1, variance=0.608, lengthscale=0.207)
    m = GPRegression(X_train, y_train, kernel, noise_var=0.283)
    m.optimize()
    pred_mean, pred_std = m.predict(X_test)

    s = GaussianQuadraticTest(None)
    gradients = compute_gp_regression_gradients(y_test, pred_mean, pred_std)
    U_matrix, stat = s.get_statistic_multiple_custom_gradient(y_test[:, 0], gradients[:, 0])
    num_test_samples = 10000
    null_samples = bootstrap_null(U_matrix, num_bootstrap=num_test_samples)
    #     null_samples = sample_null_simulated_gp(s, pred_mean, pred_std, num_test_samples)
    p_value_ours = 1.0 - np.mean(null_samples <= stat)

    y_rep = np.random.randn(len(X_test)) * pred_std.flatten() + pred_mean.flatten()
    y_rep = np.atleast_2d(y_rep).T
    A = np.hstack((X_test, y_test))
    B = np.hstack((X_test, y_rep))
    feats_p = RealFeatures(A.T)
    feats_q = RealFeatures(B.T)
    width = 1
    kernel = GaussianKernel(10, width)
    mmd = QuadraticTimeMMD()
    mmd.set_kernel(kernel)
    mmd.set_p(feats_p)
    mmd.set_q(feats_q)
    mmd_stat = mmd.compute_statistic()

    # sample from null
    num_null_samples = 10000
    mmd_null_samples = np.zeros(num_null_samples)
    for i in range(num_null_samples):
        # fix y_rep from above, and change the other one (that would replace y_test)
        y_rep2 = np.random.randn(len(X_test)) * pred_std.flatten() + pred_mean.flatten()
        y_rep2 = np.atleast_2d(y_rep2).T
        A = np.hstack((X_test, y_rep2))
        feats_p = RealFeatures(A.T)
        width = 1
        kernel = GaussianKernel(10, width)
        mmd = QuadraticTimeMMD()
        mmd.set_kernel(kernel)
        mmd.set_p(feats_p)
        mmd.set_q(feats_q)
        mmd_null_samples[i] = mmd.compute_statistic()

    p_value_mmd = 1.0 - np.mean(mmd_null_samples <= mmd_stat)

    return p_value_ours, p_value_mmd
    # compare to Lloyd & Gharamani
    # sample from GP, and perform MMD two sample test between test data and sampled data
    y_rep = np.random.randn(len(X_test)) * pred_std.flatten() + pred_mean.flatten()
    y_rep = np.atleast_2d(y_rep).T

    # stack together (X_test,y_test) and (X_test, y_pred)
    A = np.hstack((X_test, y_test))
    B = np.hstack((X_test, y_rep))

    # compute MMD between (X_test,y_test) and (X_test, y_pred)
    feats_p = RealFeatures(A.T)
    feats_q = RealFeatures(B.T)
    width = 1
    kernel = GaussianKernel(10, width)
    mmd = QuadraticTimeMMD()
    mmd.set_kernel(kernel)
    mmd.set_p(feats_p)
    mmd.set_q(feats_q)
    mmd_stat = mmd.compute_statistic()

    # sample from null
    num_null_samples = 10000
    mmd_null_samples = np.zeros(num_null_samples)
    for i in range(num_null_samples):
        # fix y_rep from above, and change the other one (that would replace y_test)
        y_rep2 = np.random.randn(len(X_test)) * pred_std.flatten() + pred_mean.flatten()
        y_rep2 = np.atleast_2d(y_rep2).T

        A = np.hstack((X_test, y_rep2))