Code Example #1
File: density.py  Project: sweatyrichard/kernel-gof
 def get_datasource(self):
     return data.DSNormal(self.mean, self.cov)
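A minimal usage sketch of the Density/DataSource pairing (assembled only from calls that appear in the examples on this page; a sketch, not part of the kernel-gof sources):

import numpy as np
import kgof.density as density

# Build a 2-dimensional Gaussian density and ask it for its paired DataSource.
p = density.Normal(np.zeros(2), np.eye(2))
ds = p.get_datasource()

# DataSource.sample returns a Data object; data() unwraps the n x d numpy array.
X = ds.sample(100, seed=7).data()
print(X.shape)  # expected: (100, 2)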
Code Example #2
def get_pqsource(prob_label):
    """
    Return (p, ds), a tuple of
    - p: a Density representing the distribution p
    - ds: a DataSource, each corresponding to one parameter setting.
        The DataSource generates sample from q.
    """
    prob2tuples = {
        # H0 is true. P = Q = N(0, I) with d = 5.
        'sg5':
        (density.IsotropicNormal(np.zeros(5),
                                 1), data.DSIsotropicNormal(np.zeros(5), 1)),

        # P = N(0, I), Q = N( (0.2,..0), I)
        'gmd5': (density.IsotropicNormal(np.zeros(5), 1),
                 data.DSIsotropicNormal(np.hstack((0.2, np.zeros(4))), 1)),
        'gmd1': (density.IsotropicNormal(np.zeros(1), 1),
                 data.DSIsotropicNormal(np.ones(1) * 0.2, 1)),

        # P = N(0, I), Q = N( (1,..0), I)
        'gmd100': (density.IsotropicNormal(np.zeros(100), 1),
                   data.DSIsotropicNormal(np.hstack((1, np.zeros(99))), 1)),

        # Gaussian variance difference problem. Only the variance
        # of the first dimension differs. d varies.
        'gvd5': (density.Normal(np.zeros(5), np.eye(5)),
                 data.DSNormal(np.zeros(5), np.diag(np.hstack(
                     (2, np.ones(4)))))),
        'gvd10': (density.Normal(np.zeros(10), np.eye(10)),
                  data.DSNormal(np.zeros(10),
                                np.diag(np.hstack((2, np.ones(9)))))),

        # Gaussian Bernoulli RBM. dx=50, dh=10. H0 is true
        'gbrbm_dx50_dh10_v0':
        gaussbern_rbm_tuple(0, dx=50, dh=10, n=sample_size),

        # Gaussian Bernoulli RBM. dx=5, dh=3. H0 is true
        'gbrbm_dx5_dh3_v0':
        gaussbern_rbm_tuple(0, dx=5, dh=3, n=sample_size),

        # Gaussian Bernoulli RBM. dx=50, dh=10. Perturb with noise = 1e-3.
        'gbrbm_dx50_dh10_v1em3':
        gaussbern_rbm_tuple(1e-3, dx=50, dh=10, n=sample_size),

        # Gaussian Bernoulli RBM. dx=5, dh=3. Perturb with noise = 5e-3.
        'gbrbm_dx5_dh3_v5em3':
        gaussbern_rbm_tuple(5e-3, dx=5, dh=3, n=sample_size),

        # Gaussian mixture of two components. Uniform mixture weights.
        # p = 0.5*N(0, 1) + 0.5*N(3, 0.01)
        # q = 0.5*N(-3, 0.01) + 0.5*N(0, 1)
        'gmm_d1': (density.IsoGaussianMixture(np.array([[0], [3.0]]),
                                              np.array([1, 0.01])),
                   data.DSIsoGaussianMixture(np.array([[-3.0], [0]]),
                                             np.array([0.01, 1]))),

        # p = N(0, I)
        # q = 0.1*N(0, v*I) + 0.9*N(0, I), with a small v (1e-4 for d=5; 1e-2 for d=1, 2)
        'g_vs_gmm_d5': (density.IsotropicNormal(np.zeros(5), 1),
                        data.DSIsoGaussianMixture(np.vstack((np.hstack(
                            (0.0, np.zeros(4))), np.zeros(5))),
                                                  np.array([0.0001, 1]),
                                                  pmix=[0.1, 0.9])),
        'g_vs_gmm_d2': (density.IsotropicNormal(np.zeros(2), 1),
                        data.DSIsoGaussianMixture(np.vstack((np.hstack(
                            (0.0, np.zeros(1))), np.zeros(2))),
                                                  np.array([0.01, 1]),
                                                  pmix=[0.1, 0.9])),
        'g_vs_gmm_d1': (density.IsotropicNormal(np.zeros(1), 1),
                        data.DSIsoGaussianMixture(np.array([[0.0], [0]]),
                                                  np.array([0.01, 1]),
                                                  pmix=[0.1, 0.9])),
    }
    if prob_label not in prob2tuples:
        raise ValueError('Unknown problem label. Need to be one of %s' %
                         str(prob2tuples.keys()))
    return prob2tuples[prob_label]
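A short driver sketch for get_pqsource (hypothetical; 'gmd5' is one of the labels above, and building the dictionary requires sample_size and gaussbern_rbm_tuple to be defined in the enclosing module):

# Look up a problem, then draw data from q to test against the model p.
p, ds = get_pqsource('gmd5')
dat = ds.sample(500, seed=3)   # q = N((0.2, 0, 0, 0, 0), I)
X = dat.data()                 # 500 x 5 numpy array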
Code Example #3
def get_pqsource_list(prob_label):
    """
    Return [(prob_param, p, ds) for ... ], a list of tuples
    where 
    - prob_param: a problem parameter. Each parameter has to be a
      scalar (so that we can plot it later). Parameters are preferably
      positive integers.
    - p: a Density representing the distribution p
    - ds: a DataSource, each corresponding to one parameter setting.
        The DataSource generates samples from q.
    """
    sg_ds = [1, 5, 10, 15]
    gmd_ds = [5, 20, 40, 60]
    # vary the mean
    gmd_d10_ms = [0, 0.02, 0.04, 0.06]
    gvinc_d1_vs = [1, 1.5, 2, 2.5]
    gvinc_d5_vs = [1, 1.5, 2, 2.5]
    gvsub1_d1_vs = [0.1, 0.3, 0.5, 0.7]
    gvd_ds = [1, 5, 10, 15]

    #gb_rbm_dx50_dh10_stds = [0, 0.01, 0.02, 0.03]
    gb_rbm_dx50_dh10_stds = [0, 0.02, 0.04, 0.06]
    #gb_rbm_dx50_dh10_stds = [0]
    gb_rbm_dx50_dh40_stds = [0, 0.01, 0.02, 0.04, 0.06]
    glaplace_ds = [1, 5, 10, 15]
    prob2tuples = {
        # H0 is true. vary d. P = Q = N(0, I)
        'sg': [(d, density.IsotropicNormal(np.zeros(d), 1),
                data.DSIsotropicNormal(np.zeros(d), 1)) for d in sg_ds],

        # vary d. P = N(0, I), Q = N( (1, 0,..,0), I)
        'gmd': [(d, density.IsotropicNormal(np.zeros(d), 1),
                 data.DSIsotropicNormal(np.hstack((1, np.zeros(d - 1))), 1))
                for d in gmd_ds],
        # P = N(0, I), Q = N( (m, ..0), I). Vary m
        'gmd_d10_ms': [(m, density.IsotropicNormal(np.zeros(10), 1),
                        data.DSIsotropicNormal(np.hstack((m, np.zeros(9))), 1))
                       for m in gmd_d10_ms],
        # d=1. Increase the variance. P = N(0, I). Q = N(0, v*I)
        'gvinc_d1': [(var, density.IsotropicNormal(np.zeros(1), 1),
                      data.DSIsotropicNormal(np.zeros(1), var))
                     for var in gvinc_d1_vs],
        # d=5. Increase the variance. P = N(0, I). Q = N(0, v*I)
        'gvinc_d5': [(var, density.IsotropicNormal(np.zeros(5), 1),
                      data.DSIsotropicNormal(np.zeros(5), var))
                     for var in gvinc_d5_vs],
        # d=1. P = N(0, 1), Q = N(0, v). Consider variances below 1.
        'gvsub1_d1': [(var, density.IsotropicNormal(np.zeros(1), 1),
                       data.DSIsotropicNormal(np.zeros(1), var))
                      for var in gvsub1_d1_vs],
        # Gaussian variance difference problem. Only the variance
        # of the first dimension differs. d varies.
        'gvd': [(d, density.Normal(np.zeros(d), np.eye(d)),
                 data.DSNormal(np.zeros(d),
                               np.diag(np.hstack((2, np.ones(d - 1))))))
                for d in gvd_ds],

        # Gaussian Bernoulli RBM. dx=50, dh=10
        'gbrbm_dx50_dh10':
        gaussbern_rbm_probs(gb_rbm_dx50_dh10_stds, dx=50, dh=10,
                            n=sample_size),

        # Gaussian Bernoulli RBM. dx=50, dh=40
        'gbrbm_dx50_dh40':
        gaussbern_rbm_probs(gb_rbm_dx50_dh40_stds, dx=50, dh=40,
                            n=sample_size),

        # p: N(0, I), q: standard Laplace. Vary d
        'glaplace': [
            (
                d,
                density.IsotropicNormal(np.zeros(d), 1),
                # Scaling of 1/sqrt(2) will make the variance 1.
                data.DSLaplace(d=d, loc=0, scale=1.0 / np.sqrt(2)))
            for d in glaplace_ds
        ],
    }
    if prob_label not in prob2tuples:
        raise ValueError('Unknown problem label. Need to be one of %s' %
                         str(prob2tuples.keys()))
    return prob2tuples[prob_label]
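A hypothetical loop over one problem family returned by get_pqsource_list; each tuple carries the scalar problem parameter (here the dimension d), the model density p, and a DataSource for q:

for d, p, ds in get_pqsource_list('glaplace'):
    X = ds.sample(200, seed=d).data()
    print('d = %d, sample shape %s' % (d, str(X.shape)))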
Code Example #4
File: lsd_test.py  Project: stjordanis/LSD-1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--test', choices=['gaussian-laplace', 'laplace-gaussian',
                                           'gaussian-pert', 'rbm-pert', 'rbm-pert1'], type=str)
    parser.add_argument('--dim_x', type=int, default=50)
    parser.add_argument('--dim_h', type=int, default=40)
    parser.add_argument('--sigma_pert', type=float, default=.02)
    parser.add_argument('--maximize_power', action="store_true")
    parser.add_argument('--maximize_adj_mean', action="store_true")
    parser.add_argument('--val_power', action="store_true")
    parser.add_argument('--val_adj_mean', action="store_true")
    parser.add_argument('--dropout', action="store_true")
    parser.add_argument('--alpha', type=float, default=.05)
    parser.add_argument('--save', type=str, default='/tmp/test_ksd')

    parser.add_argument('--test_type', type=str, default='mine')



    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--l2', type=float, default=0.)

    parser.add_argument('--num_const', type=float, default=1e-6)

    parser.add_argument('--log_freq', type=int, default=10)
    parser.add_argument('--val_freq', type=int, default=100)
    parser.add_argument('--weight_decay', type=float, default=0)



    parser.add_argument('--seed', type=int, default=100001)
    parser.add_argument('--n_train', type=int, default=1000)
    parser.add_argument('--n_val', type=int, default=1000)
    parser.add_argument('--n_test', type=int, default=1000)
    parser.add_argument('--n_iters', type=int, default=100001)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_batch_size', type=int, default=1000)
    parser.add_argument('--test_burn_in', type=int, default=0)
    parser.add_argument('--mode', type=str, default="fs")
    parser.add_argument('--viz_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=10000)

    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--base_dist', action="store_true")
    parser.add_argument('--t_iters', type=int, default=5)
    parser.add_argument('--k_dim', type=int, default=1)
    parser.add_argument('--sn', type=float, default=-1.)

    parser.add_argument('--exact_trace', action="store_true")
    parser.add_argument('--quadratic', action="store_true")
    parser.add_argument('--n_steps', type=int, default=100)
    parser.add_argument('--both_scaled', action="store_true")
    args = parser.parse_args()
    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if args.test == "gaussian-laplace":
        mu = torch.zeros((args.dim_x,))
        std = torch.ones((args.dim_x,))
        p_dist = Gaussian(mu, std)
        q_dist = Laplace(mu, std)
    elif args.test == "laplace-gaussian":
        mu = torch.zeros((args.dim_x,))
        std = torch.ones((args.dim_x,))
        q_dist = Gaussian(mu, std)
        p_dist = Laplace(mu, std / (2 ** .5))
    elif args.test == "gaussian-pert":
        mu = torch.zeros((args.dim_x,))
        std = torch.ones((args.dim_x,))
        p_dist = Gaussian(mu, std)
        q_dist = Gaussian(mu + torch.randn_like(mu) * args.sigma_pert, std)
    elif args.test == "rbm-pert1":
        B = randb((args.dim_x, args.dim_h)) * 2. - 1.
        c = torch.randn((1, args.dim_h))
        b = torch.randn((1, args.dim_x))

        p_dist = GaussianBernoulliRBM(B, b, c)
        B2 = B.clone()
        B2[0, 0] += torch.randn_like(B2[0, 0]) * args.sigma_pert
        q_dist = GaussianBernoulliRBM(B2, b, c)
    else:  # args.test == "rbm-pert"
        B = randb((args.dim_x, args.dim_h)) * 2. - 1.
        c = torch.randn((1, args.dim_h))
        b = torch.randn((1, args.dim_x))

        p_dist = GaussianBernoulliRBM(B, b, c)
        q_dist = GaussianBernoulliRBM(B + torch.randn_like(B) * args.sigma_pert, b, c)

    # run the learned Stein discrepancy test (args.test_type == 'mine')
    if args.test_type == "mine":
        import numpy as np
        data = p_dist.sample(args.n_train + args.n_val + args.n_test).detach()
        data_train = data[:args.n_train]
        data_rest = data[args.n_train:]
        data_val = data_rest[:args.n_val].requires_grad_()
        data_test = data_rest[args.n_val:].requires_grad_()
        assert data_test.size(0) == args.n_test

        critic = networks.SmallMLP(args.dim_x, n_out=args.dim_x, n_hid=300, dropout=args.dropout)
        optimizer = optim.Adam(critic.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        def stein_discrepency(x, exact=False):
            if "rbm" in args.test:
                sq = q_dist.score_function(x)
            else:
                logq_u = q_dist(x)
                sq = keep_grad(logq_u.sum(), x)
            fx = critic(x)
            if args.dim_x == 1:
                fx = fx[:, None]
            sq_fx = (sq * fx).sum(-1)

            if exact:
                tr_dfdx = exact_jacobian_trace(fx, x)
            else:
                tr_dfdx = approx_jacobian_trace(fx, x)

            norms = (fx * fx).sum(1)
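            # Per-sample Stein statistic s_q(x)^T f(x) + tr(df/dx); its mean over the
            # batch estimates the (learned) Stein discrepancy between the data and q.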
            stats = (sq_fx + tr_dfdx)
            return stats, norms

        # training phase
        best_val = -np.inf
        validation_metrics = []
        test_statistics = []
        critic.train()
        for itr in range(args.n_iters):
            optimizer.zero_grad()
            x = sample_batch(data_train, args.batch_size)
            x = x.to(device)
            x.requires_grad_()

            stats, norms = stein_discrepency(x)
            mean, std = stats.mean(), stats.std()
            l2_penalty = norms.mean() * args.l2


            if args.maximize_power:
                loss = -1. * mean / (std + args.num_const) + l2_penalty
            elif args.maximize_adj_mean:
                loss = -1. * mean + std + l2_penalty
            else:
                loss = -1. * mean + l2_penalty

            loss.backward()
            optimizer.step()

            if itr % args.log_freq == 0:
                print("Iter {}, Loss = {}, Mean = {}, STD = {}, L2 {}".format(itr,
                                                                           loss.item(), mean.item(), std.item(),
                                                                           l2_penalty.item()))

            if itr % args.val_freq == 0:
                critic.eval()
                val_stats, _ = stein_discrepency(data_val, exact=True)
                test_stats, _ = stein_discrepency(data_test, exact=True)
                print("Val: {} +/- {}".format(val_stats.mean().item(), val_stats.std().item()))
                print("Test: {} +/- {}".format(test_stats.mean().item(), test_stats.std().item()))

                if args.val_power:
                    validation_metric = val_stats.mean() / (val_stats.std() + args.num_const)
                elif args.val_adj_mean:
                    validation_metric = val_stats.mean() - val_stats.std()
                else:
                    validation_metric = val_stats.mean()

                test_statistic = test_stats.mean() / (test_stats.std() + args.num_const)



                if validation_metric > best_val:
                    print("Iter {}, Validation Metric = {} > {}, Test Statistic = {}, Current Best!".format(itr,
                                                                                                  validation_metric.item(),
                                                                                                  best_val,
                                                                                                  test_statistic.item()))
                    best_val = validation_metric.item()
                else:
                    print("Iter {}, Validation Metric = {}, Test Statistic = {}, Not best {}".format(itr,
                                                                                                validation_metric.item(),
                                                                                                test_statistic.item(),
                                                                                                best_val))
                validation_metrics.append(validation_metric.item())
                test_statistics.append(test_statistic.item())
                critic.train()
        best_ind = np.argmax(validation_metrics)
        best_test = test_statistics[best_ind]

        print("Best val is {}, best test is {}".format(best_val, best_test))
        test_stat = best_test * args.n_test ** .5
        threshold = distributions.Normal(0, 1).icdf(torch.ones((1,)) * (1. - args.alpha)).item()
        try_make_dirs(os.path.dirname(args.save))
        with open(args.save, 'w') as f:
            f.write(str(test_stat) + '\n')
            if test_stat > threshold:
                print("{} > {}, rejct Null".format(test_stat, threshold))
                f.write("reject")
            else:
                print("{} <= {}, accept Null".format(test_stat, threshold))
                f.write("accept")

    # baselines
    else:
        import autograd.numpy as np
        #import kgof.goftest as gof
        import mygoftest as gof
        import kgof.util as util
        import kgof.kernel as kernel
        import kgof.density as density
        import kgof.data as kdata

        class GaussBernRBM(density.UnnormalizedDensity):
            """
            Gaussian-Bernoulli Restricted Boltzmann Machine.
            The joint density takes the form
                p(x, h) = Z^{-1} exp(0.5*x^T B h + b^T x + c^T h - 0.5||x||^2)
            where h is a vector of {-1, 1}.
            """

            def __init__(self, B, b, c):
                """
                B: a dx x dh matrix
                b: a numpy array of length dx
                c: a numpy array of length dh
                """
                dh = len(c)
                dx = len(b)
                assert B.shape[0] == dx
                assert B.shape[1] == dh
                assert dx > 0
                assert dh > 0
                self.B = B
                self.b = b
                self.c = c

            def log_den(self, X):
                B = self.B
                b = self.b
                c = self.c

                XBC = 0.5 * np.dot(X, B) + c
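                # Marginalizing h over {-1, 1}^dh gives a factor prod_j 2*cosh(XBC_j);
                # log(exp(z) + exp(-z)) = log(2*cosh(z)), hence the term below.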
                unden = np.dot(X, b) - 0.5 * np.sum(X ** 2, 1) + np.sum(np.log(np.exp(XBC)
                                                                               + np.exp(-XBC)), 1)
                assert len(unden) == X.shape[0]
                return unden

            def grad_log(self, X):
                #    """
                #    Evaluate the gradients (with respect to the input) of the log density at
                #    each of the n points in X. This is the score function.

                #    X: n x d numpy array.
                """
                Evaluate the gradients (with respect to the input) of the log density at
                each of the n points in X. This is the score function.

                X: n x d numpy array.

                Return an n x d numpy array of gradients.
                """
                XB = np.dot(X, self.B)
                Y = 0.5 * XB + self.c
                # E2y = np.exp(2*Y)
                # n x dh
                # Phi = old_div((E2y-1.0),(E2y+1))
                Phi = np.tanh(Y)
                # n x dx
                T = np.dot(Phi, 0.5 * self.B.T)
                S = self.b - X + T
                return S

            def get_datasource(self, burnin=2000):
                # kgof.data is imported as kdata in this branch
                return kdata.DSGaussBernRBM(self.B, self.b, self.c, burnin=burnin)

            def dim(self):
                return len(self.b)

        def job_lin_kstein_med(p, data_source, tr, te, r):
            """
            Linear-time version of the kernel Stein discrepancy test of Liu et al.,
            2016 and Chwialkowski et al., 2016. Use full sample.
            """
            # full data
            data = tr + te
            X = data.data()
            with util.ContextTimer() as t:
                # median heuristic
                med = util.meddistance(X, subsample=1000)
                k = kernel.KGauss(med ** 2)

                lin_kstein = gof.LinearKernelSteinTest(p, k, alpha=args.alpha, seed=r)
                lin_kstein_result = lin_kstein.perform_test(data)
            return {'test_result': lin_kstein_result, 'time_secs': t.secs}

        def job_mmd_opt(p, data_source, tr, te, r, model_sample):
            # full data
            data = tr + te
            X = data.data()
            with util.ContextTimer() as t:
                mmd = gof.QuadMMDGofOpt(p, alpha=args.alpha, seed=r)
                mmd_result = mmd.perform_test(data, model_sample)
            return {'test_result': mmd_result, 'time_secs': t.secs}


        def job_kstein_med(p, data_source, tr, te, r):
            """
            Kernel Stein discrepancy test of Liu et al., 2016 and Chwialkowski et al.,
            2016. Use full sample. Use Gaussian kernel.
            """
            # full data
            data = tr + te
            X = data.data()
            with util.ContextTimer() as t:
                # median heuristic
                med = util.meddistance(X, subsample=1000)
                k = kernel.KGauss(med ** 2)

                kstein = gof.KernelSteinTest(p, k, alpha=args.alpha, n_simulate=1000, seed=r)
                kstein_result = kstein.perform_test(data)
            return {'test_result': kstein_result, 'time_secs': t.secs}

        def job_fssdJ1q_opt(p, data_source, tr, te, r, J=1, null_sim=None):
            """
            FSSD with optimization on tr. Test on te. Use a Gaussian kernel.
            """
            if null_sim is None:
                null_sim = gof.FSSDH0SimCovObs(n_simulate=2000, seed=r)

            Xtr = tr.data()
            with util.ContextTimer() as t:
                # Use grid search to initialize the gwidth
                n_gwidth_cand = 5
                gwidth_factors = 2.0 ** np.linspace(-3, 3, n_gwidth_cand)
                med2 = util.meddistance(Xtr, 1000) ** 2
                print(med2)
                k = kernel.KGauss(med2 * 2)
                # fit a Gaussian to the data and draw to initialize V0
                V0 = util.fit_gaussian_draw(Xtr, J, seed=r + 1, reg=1e-6)
                list_gwidth = np.hstack(((med2) * gwidth_factors))
                besti, objs = gof.GaussFSSD.grid_search_gwidth(p, tr, V0, list_gwidth)
                gwidth = list_gwidth[besti]
                assert util.is_real_num(gwidth), 'gwidth not real. Was %s' % str(gwidth)
                assert gwidth > 0, 'gwidth not positive. Was %.3g' % gwidth
                print('After grid search, gwidth=%.3g' % gwidth)

                ops = {
                    'reg': 1e-2,
                    'max_iter': 40,
                    'tol_fun': 1e-4,
                    'disp': True,
                    'locs_bounds_frac': 10.0,
                    'gwidth_lb': 1e-1,
                    'gwidth_ub': 1e4,
                }

                V_opt, gwidth_opt, info = gof.GaussFSSD.optimize_locs_widths(p, tr,
                                                                             gwidth, V0, **ops)
                # Use the optimized parameters to construct a test
                k_opt = kernel.KGauss(gwidth_opt)
                fssd_opt = gof.FSSD(p, k_opt, V_opt, null_sim=null_sim, alpha=args.alpha)
                fssd_opt_result = fssd_opt.perform_test(te)
            return {'test_result': fssd_opt_result, 'time_secs': t.secs,
                    'goftest': fssd_opt, 'opt_info': info,
                    }

        def job_fssdJ5q_opt(p, data_source, tr, te, r):
            return job_fssdJ1q_opt(p, data_source, tr, te, r, J=5)


        if "rbm" in args.test:
            if args.test_type == "mmd":
                q = kdata.DSGaussBernRBM(np.array(q_dist.B.detach().numpy()),
                                         np.array(q_dist.b.detach().numpy()[0]),
                                         np.array(q_dist.c.detach().numpy()[0]))
            else:
                q = GaussBernRBM(np.array(q_dist.B.detach().numpy()),
                                 np.array(q_dist.b.detach().numpy()[0]),
                                 np.array(q_dist.c.detach().numpy()[0]))
            p = kdata.DSGaussBernRBM(np.array(p_dist.B.detach().numpy()),
                                     np.array(p_dist.b.detach().numpy()[0]),
                                     np.array(p_dist.c.detach().numpy()[0]))
        elif args.test == "laplace-gaussian":
            mu = np.zeros((args.dim_x,))
            std = np.eye(args.dim_x)
            q = density.Normal(mu, std)
            p = kdata.DSLaplace(args.dim_x, scale=1/(2. ** .5))
        elif args.test == "gaussian-pert":
            mu = np.zeros((args.dim_x,))
            std = np.eye(args.dim_x)
            q = density.Normal(mu, std)
            p = kdata.DSNormal(mu, std)

        data_train = p.sample(args.n_train, args.seed)
        data_test = p.sample(args.n_test, args.seed + 1)


        if args.test_type == "fssd":
            result = job_fssdJ5q_opt(q, p, data_train, data_test, r=args.seed)
        elif args.test_type == "ksd":
            result = job_kstein_med(q, p, data_train, data_test, r=args.seed)
        elif args.test_type == "lksd":
            result = job_lin_kstein_med(q, p, data_train, data_test, r=args.seed)
        elif args.test_type == "mmd":
            model_sample = q.sample(args.n_train + args.n_test, args.seed + 2)
            result = job_mmd_opt(q, p, data_train, data_test, args.seed, model_sample)
        print(result['test_result'])
        reject = result['test_result']['h0_rejected']
        try_make_dirs(os.path.dirname(args.save))
        with open(args.save, 'w') as f:
            if reject:
                print("reject")
                f.write("reject")
            else:
                print("accept")
                f.write("accept")