Example #1
def main(args):
    utils.makedirs(args.save)
    net = networks.SmallMLP(2, 2, n_hid=args.hid)
    net.to(device)

    if args.dataset == "moons":
        Xf, Y = datasets.make_moons(1000, noise=.1)
        Xfte, Yte = datasets.make_moons(1000, noise=.1)
        Xoh, Xohte = [], []
    elif args.dataset == "circles":
        Xf, Y = datasets.make_circles(1000, noise=.03)
        Xfte, Yte = datasets.make_circles(1000, noise=.03)
        Xoh, Xohte = [], []
    elif args.dataset == "adult":
        with open("data/adult/adult.data", 'r') as f:
            Xf, Xoh, Y = data_utils.load_adult()
        with open("data/adult/adult.test", 'r') as f:
            Xfte, Xohte, Yte = data_utils.load_adult()

    else:
        raise NotImplementedError

    Xf = Xf.astype(np.float32)
    Xfl, Xohl, Yl = [], [], []
    if args.n_labels_per_class != -1:
        Xfl.extend(Xf[Y == 0][:args.n_labels_per_class])
        Xfl.extend(Xf[Y == 1][:args.n_labels_per_class])
        Yl.extend([0] * args.n_labels_per_class)
        Yl.extend([1] * args.n_labels_per_class)
        if len(Xoh) > 0:
            Xohl.extend(Xoh[Y == 0][:args.n_labels_per_class])
            Xohl.extend(Xoh[Y == 1][:args.n_labels_per_class])
    else:
        Xfl, Xohl, Yl = Xf, Xoh, Y
    # the labeled subset may have been built up as Python lists above; convert to arrays
    Xfl = np.array(Xfl, dtype=np.float32)
    Yl = np.array(Yl)

    def plot_data(fname="data.png"):
        plt.clf()
        decision_boundary(net, Xf)
        plt.scatter(Xf[:, 0], Xf[:, 1], c='grey')
        plt.scatter(Xfl[:args.n_labels_per_class, 0],
                    Xfl[:args.n_labels_per_class, 1],
                    c='r')
        plt.scatter(Xfl[args.n_labels_per_class:, 0],
                    Xfl[args.n_labels_per_class:, 1],
                    c='b')
        plt.savefig("{}/{}".format(args.save, fname))

    optim = torch.optim.Adam(params=net.parameters(), lr=args.lr)

    xl = torch.from_numpy(Xfl).to(device)
    yl = torch.from_numpy(Yl).to(device)
    x_te, y_te = torch.from_numpy(Xfte).float().to(device), torch.from_numpy(Yte).to(device)
    inds = list(range(Xf.shape[0]))
    for i in range(args.n_iters):
        batch_inds = np.random.choice(inds, args.batch_size, replace=False)
        x = Xf[batch_inds]
        x = torch.from_numpy(x).to(device).requires_grad_()

        # supervised cross-entropy on the labeled subset
        logits = net(xl)
        clf_loss = nn.CrossEntropyLoss(reduction='none')(logits, yl).mean()

        # unsupervised score matching on the unlabeled batch:
        # logsumexp of the logits is treated as an unnormalized log-density
        logits_u = net(x)
        logpx_plus_Z = logits_u.logsumexp(1)
        sp = utils.keep_grad(logpx_plus_Z.sum(), x)   # score dlogp/dx
        e = torch.randn_like(sp)
        eH = utils.keep_grad(sp, x, grad_outputs=e)   # Hessian-vector product
        trH = (eH * e).sum(-1)                        # Hutchinson estimate of Tr(H)

        sm_loss = trH + .5 * (sp**2).sum(-1)
        sm_loss = sm_loss.mean()

        loss = (1 - args.sm_lam) * clf_loss + args.sm_lam * sm_loss
        optim.zero_grad()
        loss.backward()
        optim.step()

        if i % 100 == 0:
            if args.dataset in ("circles", "moons"):
                plot_data("data_{}.png".format(i))
            te_logits = net(x_te.float())
            te_preds = torch.argmax(te_logits, 1)
            te_acc = (te_preds == y_te).float().mean()
            print("Iter {}: Clf Loss = {}, SM Loss = {} | Test Accuracy = {}".
                  format(i, clf_loss.item(), sm_loss.item(), te_acc.item()))
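
The listings in Examples #1, #2, and #5 call a keep_grad helper (utils.keep_grad here). A minimal sketch, assuming it is simply a thin wrapper around torch.autograd.grad that keeps the graph so higher-order terms stay differentiable:

import torch

# Hypothetical sketch of keep_grad (assumption; the project's utils.keep_grad may differ).
def keep_grad(output, inputs, grad_outputs=None):
    # create_graph=True so quantities built from this gradient (e.g. the
    # score-matching loss) can themselves be backpropagated through.
    return torch.autograd.grad(output, inputs,
                               grad_outputs=grad_outputs,
                               retain_graph=True,
                               create_graph=True)[0]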
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--test', choices=['gaussian-laplace', 'laplace-gaussian',
                                           'gaussian-pert', 'rbm-pert', 'rbm-pert1'], type=str)
    parser.add_argument('--dim_x', type=int, default=50)
    parser.add_argument('--dim_h', type=int, default=40)
    parser.add_argument('--sigma_pert', type=float, default=.02)
    parser.add_argument('--maximize_power', action="store_true")
    parser.add_argument('--maximize_adj_mean', action="store_true")
    parser.add_argument('--val_power', action="store_true")
    parser.add_argument('--val_adj_mean', action="store_true")
    parser.add_argument('--dropout', action="store_true")
    parser.add_argument('--alpha', type=float, default=.05)
    parser.add_argument('--save', type=str, default='/tmp/test_ksd')

    parser.add_argument('--test_type', type=str, default='mine')

    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--l2', type=float, default=0.)

    parser.add_argument('--num_const', type=float, default=1e-6)

    parser.add_argument('--log_freq', type=int, default=10)
    parser.add_argument('--val_freq', type=int, default=100)
    parser.add_argument('--weight_decay', type=float, default=0)

    parser.add_argument('--seed', type=int, default=100001)
    parser.add_argument('--n_train', type=int, default=1000)
    parser.add_argument('--n_val', type=int, default=1000)
    parser.add_argument('--n_test', type=int, default=1000)
    parser.add_argument('--n_iters', type=int, default=100001)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_batch_size', type=int, default=1000)
    parser.add_argument('--test_burn_in', type=int, default=0)
    parser.add_argument('--mode', type=str, default="fs")
    parser.add_argument('--viz_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=10000)

    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--base_dist', action="store_true")
    parser.add_argument('--t_iters', type=int, default=5)
    parser.add_argument('--k_dim', type=int, default=1)
    parser.add_argument('--sn', type=float, default=-1.)

    parser.add_argument('--exact_trace', action="store_true")
    parser.add_argument('--quadratic', action="store_true")
    parser.add_argument('--n_steps', type=int, default=100)
    parser.add_argument('--both_scaled', action="store_true")
    args = parser.parse_args()
    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    if args.test == "gaussian-laplace":
        mu = torch.zeros((args.dim_x,))
        std = torch.ones((args.dim_x,))
        p_dist = Gaussian(mu, std)
        q_dist = Laplace(mu, std)
    elif args.test == "laplace-gaussian":
        mu = torch.zeros((args.dim_x,))
        std = torch.ones((args.dim_x,))
        q_dist = Gaussian(mu, std)
        p_dist = Laplace(mu, std / (2 ** .5))
    elif args.test == "gaussian-pert":
        mu = torch.zeros((args.dim_x,))
        std = torch.ones((args.dim_x,))
        p_dist = Gaussian(mu, std)
        q_dist = Gaussian(mu + torch.randn_like(mu) * args.sigma_pert, std)
    elif args.test == "rbm-pert1":
        B = randb((args.dim_x, args.dim_h)) * 2. - 1.
        c = torch.randn((1, args.dim_h))
        b = torch.randn((1, args.dim_x))

        p_dist = GaussianBernoulliRBM(B, b, c)
        B2 = B.clone()
        B2[0, 0] += torch.randn_like(B2[0, 0]) * args.sigma_pert
        q_dist = GaussianBernoulliRBM(B2, b, c)
    else:  # args.test == "rbm-pert"
        B = randb((args.dim_x, args.dim_h)) * 2. - 1.
        c = torch.randn((1, args.dim_h))
        b = torch.randn((1, args.dim_x))

        p_dist = GaussianBernoulliRBM(B, b, c)
        q_dist = GaussianBernoulliRBM(B + torch.randn_like(B) * args.sigma_pert, b, c)

    # run our test
    if args.test_type == "mine":
        import numpy as np
        data = p_dist.sample(args.n_train + args.n_val + args.n_test).detach()
        data_train = data[:args.n_train]
        data_rest = data[args.n_train:]
        data_val = data_rest[:args.n_val].requires_grad_()
        data_test = data_rest[args.n_val:].requires_grad_()
        assert data_test.size(0) == args.n_test

        critic = networks.SmallMLP(args.dim_x, n_out=args.dim_x, n_hid=300, dropout=args.dropout)
        optimizer = optim.Adam(critic.parameters(), lr=args.lr, weight_decay=args.weight_decay)

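        # stein_discrepency returns, per sample, the Stein statistic
        #   s_q(x)^T f(x) + Tr(df/dx)
        # (its mean estimates the Stein discrepancy between q and the data),
        # along with ||f(x)||^2 for the L2 penalty on the critic.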
        def stein_discrepency(x, exact=False):
            if "rbm" in args.test:
                sq = q_dist.score_function(x)
            else:
                logq_u = q_dist(x)
                sq = keep_grad(logq_u.sum(), x)
            fx = critic(x)
            if args.dim_x == 1:
                fx = fx[:, None]
            sq_fx = (sq * fx).sum(-1)

            if exact:
                tr_dfdx = exact_jacobian_trace(fx, x)
            else:
                tr_dfdx = approx_jacobian_trace(fx, x)

            norms = (fx * fx).sum(1)
            stats = (sq_fx + tr_dfdx)
            return stats, norms

        # training phase
        best_val = -np.inf
        validation_metrics = []
        test_statistics = []
        critic.train()
        for itr in range(args.n_iters):
            optimizer.zero_grad()
            x = sample_batch(data_train, args.batch_size)
            x = x.to(device)
            x.requires_grad_()

            stats, norms = stein_discrepency(x)
            mean, std = stats.mean(), stats.std()
            l2_penalty = norms.mean() * args.l2


            if args.maximize_power:
                loss = -1. * mean / (std + args.num_const) + l2_penalty
            elif args.maximize_adj_mean:
                loss = -1. * mean + std + l2_penalty
            else:
                loss = -1. * mean + l2_penalty

            loss.backward()
            optimizer.step()

            if itr % args.log_freq == 0:
                print("Iter {}, Loss = {}, Mean = {}, STD = {}, L2 {}".format(itr,
                                                                           loss.item(), mean.item(), std.item(),
                                                                           l2_penalty.item()))

            if itr % args.val_freq == 0:
                critic.eval()
                val_stats, _ = stein_discrepency(data_val, exact=True)
                test_stats, _ = stein_discrepency(data_test, exact=True)
                print("Val: {} +/- {}".format(val_stats.mean().item(), val_stats.std().item()))
                print("Test: {} +/- {}".format(test_stats.mean().item(), test_stats.std().item()))

                if args.val_power:
                    validation_metric = val_stats.mean() / (val_stats.std() + args.num_const)
                elif args.val_adj_mean:
                    validation_metric = val_stats.mean() - val_stats.std()
                else:
                    validation_metric = val_stats.mean()

                test_statistic = test_stats.mean() / (test_stats.std() + args.num_const)

                if validation_metric > best_val:
                    print("Iter {}, Validation Metric = {} > {}, Test Statistic = {}, Current Best!".format(itr,
                                                                                                  validation_metric.item(),
                                                                                                  best_val,
                                                                                                  test_statistic.item()))
                    best_val = validation_metric.item()
                else:
                    print("Iter {}, Validation Metric = {}, Test Statistic = {}, Not best {}".format(itr,
                                                                                                validation_metric.item(),
                                                                                                test_statistic.item(),
                                                                                                best_val))
                validation_metrics.append(validation_metric.item())
                test_statistics.append(test_statistic.item())
                critic.train()
        best_ind = np.argmax(validation_metrics)
        best_test = test_statistics[best_ind]

        print("Best val is {}, best test is {}".format(best_val, best_test))
        test_stat = best_test * args.n_test ** .5
        threshold = distributions.Normal(0, 1).icdf(torch.ones((1,)) * (1. - args.alpha)).item()
        try_make_dirs(os.path.dirname(args.save))
        with open(args.save, 'w') as f:
            f.write(str(test_stat) + '\n')
            if test_stat > threshold:
                print("{} > {}, reject Null".format(test_stat, threshold))
                f.write("reject")
            else:
                print("{} <= {}, accept Null".format(test_stat, threshold))
                f.write("accept")

    # baselines
    else:
        import autograd.numpy as np
        #import kgof.goftest as gof
        import mygoftest as gof
        import kgof.util as util
        import kgof.kernel as kernel
        import kgof.density as density
        import kgof.data as kdata

        class GaussBernRBM(density.UnnormalizedDensity):
            """
            Gaussian-Bernoulli Restricted Boltzmann Machine.
            The joint density takes the form
                p(x, h) = Z^{-1} exp(0.5*x^T B h + b^T x + c^T h - 0.5||x||^2)
            where h is a vector of {-1, 1}.
            """

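            # Marginalizing over h in {-1, 1}^dh gives, up to the constant log Z,
            #   log p(x) = b^T x - 0.5 ||x||^2 + sum_j log(exp(y_j) + exp(-y_j)),
            # with y = 0.5 x B + c. log_den below implements exactly this expression,
            # and grad_log returns its gradient, b - x + 0.5 tanh(y) B^T.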
            def __init__(self, B, b, c):
                """
                B: a dx x dh matrix
                b: a numpy array of length dx
                c: a numpy array of length dh
                """
                dh = len(c)
                dx = len(b)
                assert B.shape[0] == dx
                assert B.shape[1] == dh
                assert dx > 0
                assert dh > 0
                self.B = B
                self.b = b
                self.c = c

            def log_den(self, X):
                B = self.B
                b = self.b
                c = self.c

                XBC = 0.5 * np.dot(X, B) + c
                unden = np.dot(X, b) - 0.5 * np.sum(X ** 2, 1) + np.sum(np.log(np.exp(XBC)
                                                                               + np.exp(-XBC)), 1)
                assert len(unden) == X.shape[0]
                return unden

            def grad_log(self, X):
                """
                Evaluate the gradients (with respect to the input) of the log density at
                each of the n points in X. This is the score function.

                X: n x d numpy array.

                Return an n x d numpy array of gradients.
                """
                XB = np.dot(X, self.B)
                Y = 0.5 * XB + self.c
                # n x dh
                Phi = np.tanh(Y)
                # n x dx
                T = np.dot(Phi, 0.5 * self.B.T)
                S = self.b - X + T
                return S

            def get_datasource(self, burnin=2000):
                return kdata.DSGaussBernRBM(self.B, self.b, self.c, burnin=burnin)

            def dim(self):
                return len(self.b)

        def job_lin_kstein_med(p, data_source, tr, te, r):
            """
            Linear-time version of the kernel Stein discrepancy test of Liu et al.,
            2016 and Chwialkowski et al., 2016. Use full sample.
            """
            # full data
            data = tr + te
            X = data.data()
            with util.ContextTimer() as t:
                # median heuristic
                med = util.meddistance(X, subsample=1000)
                k = kernel.KGauss(med ** 2)

                lin_kstein = gof.LinearKernelSteinTest(p, k, alpha=args.alpha, seed=r)
                lin_kstein_result = lin_kstein.perform_test(data)
            return {'test_result': lin_kstein_result, 'time_secs': t.secs}

        def job_mmd_opt(p, data_source, tr, te, r, model_sample):
            # full data
            data = tr + te
            X = data.data()
            with util.ContextTimer() as t:
                mmd = gof.QuadMMDGofOpt(p, alpha=args.alpha, seed=r)
                mmd_result = mmd.perform_test(data, model_sample)
            return {'test_result': mmd_result, 'time_secs': t.secs}


        def job_kstein_med(p, data_source, tr, te, r):
            """
            Kernel Stein discrepancy test of Liu et al., 2016 and Chwialkowski et al.,
            2016. Use full sample. Use Gaussian kernel.
            """
            # full data
            data = tr + te
            X = data.data()
            with util.ContextTimer() as t:
                # median heuristic
                med = util.meddistance(X, subsample=1000)
                k = kernel.KGauss(med ** 2)

                kstein = gof.KernelSteinTest(p, k, alpha=args.alpha, n_simulate=1000, seed=r)
                kstein_result = kstein.perform_test(data)
            return {'test_result': kstein_result, 'time_secs': t.secs}

        def job_fssdJ1q_opt(p, data_source, tr, te, r, J=1, null_sim=None):
            """
            FSSD with optimization on tr. Test on te. Use a Gaussian kernel.
            """
            if null_sim is None:
                null_sim = gof.FSSDH0SimCovObs(n_simulate=2000, seed=r)

            Xtr = tr.data()
            with util.ContextTimer() as t:
                # Use grid search to initialize the gwidth
                n_gwidth_cand = 5
                gwidth_factors = 2.0 ** np.linspace(-3, 3, n_gwidth_cand)
                med2 = util.meddistance(Xtr, 1000) ** 2
                print(med2)
                k = kernel.KGauss(med2 * 2)
                # fit a Gaussian to the data and draw to initialize V0
                V0 = util.fit_gaussian_draw(Xtr, J, seed=r + 1, reg=1e-6)
                list_gwidth = np.hstack(((med2) * gwidth_factors))
                besti, objs = gof.GaussFSSD.grid_search_gwidth(p, tr, V0, list_gwidth)
                gwidth = list_gwidth[besti]
                assert util.is_real_num(gwidth), 'gwidth not real. Was %s' % str(gwidth)
                assert gwidth > 0, 'gwidth not positive. Was %.3g' % gwidth
                print('After grid search, gwidth=%.3g' % gwidth)

                ops = {
                    'reg': 1e-2,
                    'max_iter': 40,
                    'tol_fun': 1e-4,
                    'disp': True,
                    'locs_bounds_frac': 10.0,
                    'gwidth_lb': 1e-1,
                    'gwidth_ub': 1e4,
                }

                V_opt, gwidth_opt, info = gof.GaussFSSD.optimize_locs_widths(p, tr,
                                                                             gwidth, V0, **ops)
                # Use the optimized parameters to construct a test
                k_opt = kernel.KGauss(gwidth_opt)
                fssd_opt = gof.FSSD(p, k_opt, V_opt, null_sim=null_sim, alpha=args.alpha)
                fssd_opt_result = fssd_opt.perform_test(te)
            return {'test_result': fssd_opt_result, 'time_secs': t.secs,
                    'goftest': fssd_opt, 'opt_info': info,
                    }

        def job_fssdJ5q_opt(p, data_source, tr, te, r):
            return job_fssdJ1q_opt(p, data_source, tr, te, r, J=5)


        if "rbm" in args.test:
            if args.test_type == "mmd":
                q = kdata.DSGaussBernRBM(np.array(q_dist.B.detach().numpy()),
                                         np.array(q_dist.b.detach().numpy()[0]),
                                         np.array(q_dist.c.detach().numpy()[0]))
            else:
                q = GaussBernRBM(np.array(q_dist.B.detach().numpy()),
                                         np.array(q_dist.b.detach().numpy()[0]),
                                         np.array(q_dist.c.detach().numpy()[0]))
            p = kdata.DSGaussBernRBM(np.array(p_dist.B.detach().numpy()),
                                     np.array(p_dist.b.detach().numpy()[0]),
                                     np.array(p_dist.c.detach().numpy()[0]))
        elif args.test == "laplace-gaussian":
            mu = np.zeros((args.dim_x,))
            std = np.eye(args.dim_x)
            q = density.Normal(mu, std)
            p = kdata.DSLaplace(args.dim_x, scale=1/(2. ** .5))
        elif args.test == "gaussian-pert":
            mu = np.zeros((args.dim_x,))
            std = np.eye(args.dim_x)
            q = density.Normal(mu, std)
            p = kdata.DSNormal(mu, std)

        data_train = p.sample(args.n_train, args.seed)
        data_test = p.sample(args.n_test, args.seed + 1)


        if args.test_type == "fssd":
            result = job_fssdJ5q_opt(q, p, data_train, data_test, r=args.seed)
        elif args.test_type == "ksd":
            result = job_kstein_med(q, p, data_train, data_test, r=args.seed)
        elif args.test_type == "lksd":
            result = job_lin_kstein_med(q, p, data_train, data_test, r=args.seed)
        elif args.test_type == "mmd":
            model_sample = q.sample(args.n_train + args.n_test, args.seed + 2)
            result = job_mmd_opt(q, p, data_train, data_test, args.seed, model_sample)
        print(result['test_result'])
        reject = result['test_result']['h0_rejected']
        try_make_dirs(os.path.dirname(args.save))
        with open(args.save, 'w') as f:
            if reject:
                print("reject")
                f.write("reject")
            else:
                print("accept")
                f.write("accept")
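
Examples #2 and #5 also rely on approx_jacobian_trace and exact_jacobian_trace to estimate Tr(df/dx). A minimal sketch, assuming the usual Hutchinson estimator and an exact per-dimension loop (the project's own implementations may differ):

import torch

# Hypothetical sketches of the Jacobian-trace helpers (assumptions, not the project's exact code).
def approx_jacobian_trace(fx, x):
    # Hutchinson estimator: eps^T (df/dx) eps with a single Gaussian probe eps
    eps = torch.randn_like(fx)
    jvp = torch.autograd.grad(fx, x, grad_outputs=eps,
                              retain_graph=True, create_graph=True)[0]
    return (jvp * eps).sum(-1)

def exact_jacobian_trace(fx, x):
    # exact trace: one backward pass per input dimension
    diag = []
    for i in range(x.size(-1)):
        dfi_dx = torch.autograd.grad(fx[:, i].sum(), x,
                                     retain_graph=True, create_graph=True)[0]
        diag.append(dfi_dx[:, i])
    return torch.stack(diag, dim=-1).sum(-1)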
Example #3
    dload_train, dload_test = get_data(args)

    # logger
    utils.makedirs(args.save)
    logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                              filepath=os.path.abspath(__file__))

    logger.info(args)

    if args.arch == "mlp":
        if args.quadratic:
            net = networks.QuadraticMLP(args.data_dim, n_hid=args.hidden_dim)
        else:
            net = networks.SmallMLP(args.data_dim,
                                    n_hid=args.hidden_dim,
                                    dropout=args.dropout)

        critic = networks.SmallMLP(args.data_dim,
                                   n_out=args.data_dim,
                                   n_hid=args.hidden_dim,
                                   dropout=args.dropout)
    elif args.arch == "mlp-large":
        net = networks.LargeMLP(args.data_dim,
                                n_hid=args.hidden_dim,
                                dropout=args.dropout)
        critic = networks.LargeMLP(args.data_dim,
                                   n_out=args.data_dim,
                                   n_hid=args.hidden_dim,
                                   dropout=args.dropout)
    else:
        raise NotImplementedError
Example #4
        def forward(self, x):
            s = x @ self.B
            return self.base_dist.log_prob(s).sum(1)

        def log_prob(self, x):
            # change of variables: log p(x) = log p_base(x @ B) + log|det B|
            logp_plus_Z = self(x)
            log_det = self.B.det().abs().log()
            return logp_plus_Z + log_det

        def sample(self, n):
            s = self.base_dist.sample((n, self.dim)).to(device)
            x = s @ self.A
            return x

    kernel_net = networks.SmallMLP(args.dim, n_out=args.dim)
    if args.sn:
        kernel_net.apply(apply_spectral_norm)

    np.random.seed(args.seed)
    trueICA = ICA(args.dim, reverse=False)
    modelICA = ICA(args.dim)

    logger.info(trueICA.B)
    logger.info(modelICA.B)

    logger.info(trueICA.A)
    logger.info(modelICA.A)

    modelICA.to(device)
    trueICA.to(device)
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data',
                        choices=[
                            'swissroll', '8gaussians', 'pinwheel', 'circles',
                            'moons', '2spirals', 'checkerboard', 'rings'
                        ],
                        type=str,
                        default='moons')
    parser.add_argument('--niters', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_batch_size', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--critic_weight_decay', type=float, default=0)
    parser.add_argument('--save', type=str, default='/tmp/test_lsd')
    parser.add_argument('--mode',
                        type=str,
                        default="lsd",
                        choices=['lsd', 'sm'])
    parser.add_argument('--viz_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=10000)
    parser.add_argument('--log_freq', type=int, default=100)
    parser.add_argument('--base_dist', action="store_true")
    parser.add_argument('--c_iters', type=int, default=5)
    parser.add_argument('--l2', type=float, default=10.)
    parser.add_argument('--exact_trace', action="store_true")
    parser.add_argument('--n_steps', type=int, default=10)
    args = parser.parse_args()

    # logger
    utils.makedirs(args.save)
    logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                              filepath=os.path.abspath(__file__))
    logger.info(args)

    # fit a gaussian to the training data
    init_size = 1000
    init_batch = sample_data(args, init_size).requires_grad_()
    mu, std = init_batch.mean(0), init_batch.std(0)
    base_dist = distributions.Normal(mu, std)

    # neural netz
    critic = networks.SmallMLP(2, n_out=2)
    net = networks.SmallMLP(2)

    ebm = EBM(net, base_dist if args.base_dist else None)
    ebm.to(device)
    critic.to(device)

    # for sampling
    init_fn = lambda: base_dist.sample((args.test_batch_size,))
    cov = utils.cov(init_batch)
    sampler = HMCSampler(ebm,
                         .3,
                         5,
                         init_fn,
                         device=device,
                         covariance_matrix=cov)

    logger.info(ebm)
    logger.info(critic)

    # optimizers
    optimizer = optim.Adam(ebm.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay,
                           betas=(.0, .999))
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.lr,
                                  betas=(.0, .999),
                                  weight_decay=args.critic_weight_decay)

    time_meter = utils.RunningAverageMeter(0.98)
    loss_meter = utils.RunningAverageMeter(0.98)

    ebm.train()
    end = time.time()
    for itr in range(args.niters):

        optimizer.zero_grad()
        critic_optimizer.zero_grad()

        x = sample_data(args, args.batch_size)
        x.requires_grad_()

        if args.mode == "lsd":
            # our method

            # compute dlogp(x)/dx
            logp_u = ebm(x)
            sq = keep_grad(logp_u.sum(), x)
            fx = critic(x)
            # compute (dlogp(x)/dx)^T * f(x)
            sq_fx = (sq * fx).sum(-1)

            # compute/estimate Tr(df/dx)
            if args.exact_trace:
                tr_dfdx = exact_jacobian_trace(fx, x)
            else:
                tr_dfdx = approx_jacobian_trace(fx, x)

            stats = (sq_fx + tr_dfdx)
            loss = stats.mean()  # estimate of S(p, q)
            l2_penalty = (
                fx * fx).sum(1).mean() * args.l2  # penalty to enforce f \in F
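            # stats holds per-sample values of (dlogp(x)/dx)^T f(x) + Tr(df/dx);
            # their mean estimates the Stein discrepancy S(p, q), which the
            # critic maximizes and the EBM minimizes in the steps below.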

            # adversarial!
            if args.c_iters > 0 and itr % (args.c_iters + 1) != 0:
                (-1. * loss + l2_penalty).backward()
                critic_optimizer.step()
            else:
                loss.backward()
                optimizer.step()

        elif args.mode == "sm":
            # score matching for reference
            fx = ebm(x)
            dfdx = torch.autograd.grad(fx.sum(),
                                       x,
                                       retain_graph=True,
                                       create_graph=True)[0]
            eps = torch.randn_like(dfdx)  # use hutchinson here as well
            epsH = torch.autograd.grad(dfdx,
                                       x,
                                       grad_outputs=eps,
                                       create_graph=True,
                                       retain_graph=True)[0]

            trH = (epsH * eps).sum(1)
            norm_s = (dfdx * dfdx).sum(1)

            loss = (trH + .5 * norm_s).mean()
            loss.backward()
            optimizer.step()
        else:
            assert False

        loss_meter.update(loss.item())
        time_meter.update(time.time() - end)

        if itr % args.log_freq == 0:
            log_message = (
                'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.4f}({:.4f})'.
                format(itr, time_meter.val, time_meter.avg, loss_meter.val,
                       loss_meter.avg))
            logger.info(log_message)

        if itr % args.save_freq == 0 or itr == args.niters - 1:
            ebm.cpu()
            utils.makedirs(args.save)
            torch.save({
                'args': args,
                'state_dict': ebm.state_dict(),
            }, os.path.join(args.save, 'checkpt.pth'))
            ebm.to(device)

        if itr % args.viz_freq == 0:
            # plot dat
            plt.clf()
            npts = 100
            p_samples = toy_data.inf_train_gen(args.data, batch_size=npts**2)
            q_samples = sampler.sample(args.n_steps)

            ebm.cpu()

            x_enc = critic(x)
            xes = x_enc.detach().cpu().numpy()
            trans = xes.min()
            scale = xes.max() - xes.min()
            xes = (xes - trans) / scale * 8 - 4

            plt.figure(figsize=(4, 4))
            visualize_transform(
                [p_samples, q_samples.detach().cpu().numpy(), xes],
                ["data", "model", "embed"], [ebm], ["model"],
                npts=npts)

            fig_filename = os.path.join(args.save, 'figs',
                                        '{:04d}.png'.format(itr))
            utils.makedirs(os.path.dirname(fig_filename))
            plt.savefig(fig_filename)
            plt.close()

            ebm.to(device)
        end = time.time()

    logger.info('Training has finished.')
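
Example #5 additionally assumes a sample_data(args, batch_size) helper. A minimal sketch, assuming it simply wraps toy_data.inf_train_gen (already used in the visualization code above) and that toy_data and device are module-level names:

import torch

# Hypothetical sketch of sample_data (assumption, not the project's exact code).
def sample_data(args, batch_size):
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    return torch.from_numpy(x).float().to(device)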