def main(args):
    utils.makedirs(args.save)
    net = networks.SmallMLP(2, 2, n_hid=args.hid)

    if args.dataset == "moons":
        Xf, Y = datasets.make_moons(1000, noise=.1)
        Xfte, Yte = datasets.make_moons(1000, noise=.1)
        Xoh, Xohte = [], []
    elif args.dataset == "circles":
        Xf, Y = datasets.make_circles(1000, noise=.03)
        Xfte, Yte = datasets.make_circles(1000, noise=.03)
        Xoh, Xohte = [], []
    elif args.dataset == "adult":
        with open("data/adult/adult.data", 'r') as f:
            Xf, Xoh, Y = data_utils.load_adult()
        with open("data/adult/adult.test", 'r') as f:
            Xfte, Xohte, Yte = data_utils.load_adult()
    else:
        raise NotImplementedError

    Xf = Xf.astype(np.float32)

    # keep only n_labels_per_class labeled examples per class (all if -1)
    Xfl, Xohl, Yl = [], [], []
    if args.n_labels_per_class != -1:
        Xfl.extend(Xf[Y == 0][:args.n_labels_per_class])
        Xfl.extend(Xf[Y == 1][:args.n_labels_per_class])
        Yl.extend([0] * args.n_labels_per_class)
        Yl.extend([1] * args.n_labels_per_class)
        if Xoh is not None and len(Xoh) > 0:
            Xohl.extend(Xoh[Y == 0][:args.n_labels_per_class])
            Xohl.extend(Xoh[Y == 1][:args.n_labels_per_class])
    else:
        Xfl, Xohl, Yl = Xf, Xoh, Y
    Xfl = np.array(Xfl).astype(np.float32)
    Yl = np.array(Yl)

    def plot_data(fname="data.png"):
        plt.clf()
        decision_boundary(net, Xf)
        plt.scatter(Xf[:, 0], Xf[:, 1], c='grey')
        plt.scatter(Xfl[:args.n_labels_per_class, 0], Xfl[:args.n_labels_per_class, 1], c='r')
        plt.scatter(Xfl[args.n_labels_per_class:, 0], Xfl[args.n_labels_per_class:, 1], c='b')
        plt.savefig("{}/{}".format(args.save, fname))

    optim = torch.optim.Adam(params=net.parameters(), lr=args.lr)

    xl = torch.from_numpy(Xfl).to(device)
    yl = torch.from_numpy(Yl).to(device)
    x_te, y_te = torch.from_numpy(Xfte).float(), torch.from_numpy(Yte)

    inds = list(range(Xf.shape[0]))
    for i in range(args.n_iters):
        batch_inds = np.random.choice(inds, args.batch_size, replace=False)
        x = Xf[batch_inds]
        x = torch.from_numpy(x).to(device).requires_grad_()

        # supervised loss on the labeled subset
        logits = net(xl)
        clf_loss = nn.CrossEntropyLoss(reduction='none')(logits, yl).mean()

        # score-matching loss on unlabeled data: logsumexp of the logits is an
        # unnormalized log-density, and its gradient w.r.t. x is the score
        logits_u = net(x)
        logpx_plus_Z = logits_u.logsumexp(1)
        sp = utils.keep_grad(logpx_plus_Z.sum(), x)

        # Hutchinson estimate of the Hessian trace
        e = torch.randn_like(sp)
        eH = utils.keep_grad(sp, x, grad_outputs=e)
        trH = (eH * e).sum(-1)

        sm_loss = trH + .5 * (sp ** 2).sum(-1)
        sm_loss = sm_loss.mean()

        loss = (1 - args.sm_lam) * clf_loss + args.sm_lam * sm_loss

        optim.zero_grad()
        loss.backward()
        optim.step()

        if i % 100 == 0:
            if args.dataset in ("circles", "moons"):
                plot_data("data_{}.png".format(i))
            te_logits = net(x_te.float())
            te_preds = torch.argmax(te_logits, 1)
            te_acc = (te_preds == y_te).float().mean()
            print("Iter {}: Clf Loss = {}, SM Loss = {} | Test Accuracy = {}".format(
                i, clf_loss.item(), sm_loss.item(), te_acc.item()))
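
# The sm_loss above relies on Hutchinson's estimator: for a probe vector e with
# E[e e^T] = I, we have E[e^T H e] = tr(H), so (eH * e).sum(-1) is an unbiased
# single-sample estimate of the Hessian trace. A quick numerical sanity check of
# that identity (illustrative only; this helper is hypothetical and never called
# during training):
def _check_hutchinson(n_probes=100000, d=5):
    H = torch.randn(d, d)
    H = H + H.t()  # symmetric, like a Hessian
    e = torch.randn(n_probes, d)
    est = ((e @ H) * e).sum(-1).mean()
    print("Hutchinson estimate: {:.3f}, exact trace: {:.3f}".format(
        est.item(), torch.diagonal(H).sum().item()))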
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--test', type=str,
                        choices=['gaussian-laplace', 'laplace-gaussian', 'gaussian-pert',
                                 'rbm-pert', 'rbm-pert1'])
    parser.add_argument('--dim_x', type=int, default=50)
    parser.add_argument('--dim_h', type=int, default=40)
    parser.add_argument('--sigma_pert', type=float, default=.02)
    parser.add_argument('--maximize_power', action="store_true")
    parser.add_argument('--maximize_adj_mean', action="store_true")
    parser.add_argument('--val_power', action="store_true")
    parser.add_argument('--val_adj_mean', action="store_true")
    parser.add_argument('--dropout', action="store_true")
    parser.add_argument('--alpha', type=float, default=.05)
    parser.add_argument('--save', type=str, default='/tmp/test_ksd')
    parser.add_argument('--test_type', type=str, default='mine')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--l2', type=float, default=0.)
    parser.add_argument('--num_const', type=float, default=1e-6)
    parser.add_argument('--log_freq', type=int, default=10)
    parser.add_argument('--val_freq', type=int, default=100)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--seed', type=int, default=100001)
    parser.add_argument('--n_train', type=int, default=1000)
    parser.add_argument('--n_val', type=int, default=1000)
    parser.add_argument('--n_test', type=int, default=1000)
    parser.add_argument('--n_iters', type=int, default=100001)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_batch_size', type=int, default=1000)
    parser.add_argument('--test_burn_in', type=int, default=0)
    parser.add_argument('--mode', type=str, default="fs")
    parser.add_argument('--viz_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=10000)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--base_dist', action="store_true")
    parser.add_argument('--t_iters', type=int, default=5)
    parser.add_argument('--k_dim', type=int, default=1)
    parser.add_argument('--sn', type=float, default=-1.)
    parser.add_argument('--exact_trace', action="store_true")
    parser.add_argument('--quadratic', action="store_true")
    parser.add_argument('--n_steps', type=int, default=100)
    parser.add_argument('--both_scaled', action="store_true")
    args = parser.parse_args()

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # build the data distribution p and the model distribution q
    if args.test == "gaussian-laplace":
        mu = torch.zeros((args.dim_x,))
        std = torch.ones((args.dim_x,))
        p_dist = Gaussian(mu, std)
        q_dist = Laplace(mu, std)
    elif args.test == "laplace-gaussian":
        mu = torch.zeros((args.dim_x,))
        std = torch.ones((args.dim_x,))
        q_dist = Gaussian(mu, std)
        p_dist = Laplace(mu, std / (2 ** .5))
    elif args.test == "gaussian-pert":
        mu = torch.zeros((args.dim_x,))
        std = torch.ones((args.dim_x,))
        p_dist = Gaussian(mu, std)
        q_dist = Gaussian(mu + torch.randn_like(mu) * args.sigma_pert, std)
    elif args.test == "rbm-pert1":
        # perturb a single RBM weight
        B = randb((args.dim_x, args.dim_h)) * 2. - 1.
        c = torch.randn((1, args.dim_h))
        b = torch.randn((1, args.dim_x))
        p_dist = GaussianBernoulliRBM(B, b, c)
        B2 = B.clone()
        B2[0, 0] += torch.randn_like(B2[0, 0]) * args.sigma_pert
        q_dist = GaussianBernoulliRBM(B2, b, c)
    else:  # args.test == "rbm-pert": perturb all RBM weights
        B = randb((args.dim_x, args.dim_h)) * 2. - 1.
        c = torch.randn((1, args.dim_h))
        b = torch.randn((1, args.dim_x))
        p_dist = GaussianBernoulliRBM(B, b, c)
        q_dist = GaussianBernoulliRBM(B + torch.randn_like(B) * args.sigma_pert, b, c)

    # run mah shiiiiit
    if args.test_type == "mine":
        import numpy as np

        data = p_dist.sample(args.n_train + args.n_val + args.n_test).detach()
        data_train = data[:args.n_train]
        data_rest = data[args.n_train:]
        data_val = data_rest[:args.n_val].requires_grad_()
        data_test = data_rest[args.n_val:].requires_grad_()
        assert data_test.size(0) == args.n_test

        critic = networks.SmallMLP(args.dim_x, n_out=args.dim_x, n_hid=300, dropout=args.dropout)
        optimizer = optim.Adam(critic.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        def stein_discrepancy(x, exact=False):
            if "rbm" in args.test:
                sq = q_dist.score_function(x)
            else:
                logq_u = q_dist(x)
                sq = keep_grad(logq_u.sum(), x)
            fx = critic(x)
            if args.dim_x == 1:
                fx = fx[:, None]
            sq_fx = (sq * fx).sum(-1)
            if exact:
                tr_dfdx = exact_jacobian_trace(fx, x)
            else:
                tr_dfdx = approx_jacobian_trace(fx, x)
            norms = (fx * fx).sum(1)
            stats = sq_fx + tr_dfdx
            return stats, norms

        # training phase
        best_val = -np.inf
        validation_metrics = []
        test_statistics = []
        critic.train()
        for itr in range(args.n_iters):
            optimizer.zero_grad()
            x = sample_batch(data_train, args.batch_size)
            x = x.to(device)
            x.requires_grad_()

            stats, norms = stein_discrepancy(x)
            mean, std = stats.mean(), stats.std()
            l2_penalty = norms.mean() * args.l2

            if args.maximize_power:
                loss = -1. * mean / (std + args.num_const) + l2_penalty
            elif args.maximize_adj_mean:
                loss = -1. * mean + std + l2_penalty
            else:
                loss = -1. * mean + l2_penalty

            loss.backward()
            optimizer.step()

            if itr % args.log_freq == 0:
                print("Iter {}, Loss = {}, Mean = {}, STD = {}, L2 {}".format(
                    itr, loss.item(), mean.item(), std.item(), l2_penalty.item()))

            if itr % args.val_freq == 0:
                critic.eval()
                val_stats, _ = stein_discrepancy(data_val, exact=True)
                test_stats, _ = stein_discrepancy(data_test, exact=True)
                print("Val: {} +/- {}".format(val_stats.mean().item(), val_stats.std().item()))
                print("Test: {} +/- {}".format(test_stats.mean().item(), test_stats.std().item()))

                if args.val_power:
                    validation_metric = val_stats.mean() / (val_stats.std() + args.num_const)
                elif args.val_adj_mean:
                    validation_metric = val_stats.mean() - val_stats.std()
                else:
                    validation_metric = val_stats.mean()
                test_statistic = test_stats.mean() / (test_stats.std() + args.num_const)

                if validation_metric > best_val:
                    print("Iter {}, Validation Metric = {} > {}, Test Statistic = {}, Current Best!".format(
                        itr, validation_metric.item(), best_val, test_statistic.item()))
                    best_val = validation_metric.item()
                else:
                    print("Iter {}, Validation Metric = {}, Test Statistic = {}, Not best {}".format(
                        itr, validation_metric.item(), test_statistic.item(), best_val))

                validation_metrics.append(validation_metric.item())
                test_statistics.append(test_statistic.item())
                critic.train()

        # pick the checkpoint with the best validation metric, then run a
        # one-sided z-test on the corresponding normalized test statistic
        best_ind = np.argmax(validation_metrics)
        best_test = test_statistics[best_ind]
        print("Best val is {}, best test is {}".format(best_val, best_test))
        test_stat = best_test * args.n_test ** .5
        threshold = distributions.Normal(0, 1).icdf(torch.ones((1,)) * (1. - args.alpha)).item()
        try_make_dirs(os.path.dirname(args.save))
        with open(args.save, 'w') as f:
            f.write(str(test_stat) + '\n')
            if test_stat > threshold:
                print("{} > {}, reject Null".format(test_stat, threshold))
                f.write("reject")
            else:
                print("{} <= {}, accept Null".format(test_stat, threshold))
                f.write("accept")

    # baselines
    else:
        import autograd.numpy as np
        # import kgof.goftest as gof
        import mygoftest as gof
        import kgof.util as util
        import kgof.kernel as kernel
        import kgof.density as density
        import kgof.data as kdata

        class GaussBernRBM(density.UnnormalizedDensity):
            """
            Gaussian-Bernoulli Restricted Boltzmann Machine.
            The joint density takes the form
                p(x, h) = Z^{-1} exp(0.5*x^T B h + b^T x + c^T h - 0.5||x||^2)
            where h is a vector of {-1, 1}.
            """

            def __init__(self, B, b, c):
                """
                B: a dx x dh matrix
                b: a numpy array of length dx
                c: a numpy array of length dh
                """
                dh = len(c)
                dx = len(b)
                assert B.shape[0] == dx
                assert B.shape[1] == dh
                assert dx > 0
                assert dh > 0
                self.B = B
                self.b = b
                self.c = c

            def log_den(self, X):
                B = self.B
                b = self.b
                c = self.c
                XBC = 0.5 * np.dot(X, B) + c
                unden = np.dot(X, b) - 0.5 * np.sum(X ** 2, 1) \
                    + np.sum(np.log(np.exp(XBC) + np.exp(-XBC)), 1)
                assert len(unden) == X.shape[0]
                return unden

            def grad_log(self, X):
                """
                Evaluate the gradients (with respect to the input) of the log
                density at each of the n points in X. This is the score function.

                X: n x d numpy array.
                Return an n x d numpy array of gradients.
                """
                XB = np.dot(X, self.B)
                Y = 0.5 * XB + self.c
                Phi = np.tanh(Y)  # n x dh
                T = np.dot(Phi, 0.5 * self.B.T)  # n x dx
                S = self.b - X + T
                return S

            def get_datasource(self, burnin=2000):
                return kdata.DSGaussBernRBM(self.B, self.b, self.c, burnin=burnin)

            def dim(self):
                return len(self.b)

        def job_lin_kstein_med(p, data_source, tr, te, r):
            """
            Linear-time version of the kernel Stein discrepancy test of
            Liu et al., 2016 and Chwialkowski et al., 2016. Use full sample.
            """
            # full data
            data = tr + te
            X = data.data()
            with util.ContextTimer() as t:
                # median heuristic
                med = util.meddistance(X, subsample=1000)
                k = kernel.KGauss(med ** 2)
                lin_kstein = gof.LinearKernelSteinTest(p, k, alpha=args.alpha, seed=r)
                lin_kstein_result = lin_kstein.perform_test(data)
            return {'test_result': lin_kstein_result, 'time_secs': t.secs}

        def job_mmd_opt(p, data_source, tr, te, r, model_sample):
            # full data
            data = tr + te
            X = data.data()
            with util.ContextTimer() as t:
                mmd = gof.QuadMMDGofOpt(p, alpha=args.alpha, seed=r)
                mmd_result = mmd.perform_test(data, model_sample)
            return {'test_result': mmd_result, 'time_secs': t.secs}

        def job_kstein_med(p, data_source, tr, te, r):
            """
            Kernel Stein discrepancy test of Liu et al., 2016 and
            Chwialkowski et al., 2016. Use full sample. Use Gaussian kernel.
            """
            # full data
            data = tr + te
            X = data.data()
            with util.ContextTimer() as t:
                # median heuristic
                med = util.meddistance(X, subsample=1000)
                k = kernel.KGauss(med ** 2)
                kstein = gof.KernelSteinTest(p, k, alpha=args.alpha, n_simulate=1000, seed=r)
                kstein_result = kstein.perform_test(data)
            return {'test_result': kstein_result, 'time_secs': t.secs}
""" if null_sim is None: null_sim = gof.FSSDH0SimCovObs(n_simulate=2000, seed=r) Xtr = tr.data() with util.ContextTimer() as t: # Use grid search to initialize the gwidth n_gwidth_cand = 5 gwidth_factors = 2.0 ** np.linspace(-3, 3, n_gwidth_cand) med2 = util.meddistance(Xtr, 1000) ** 2 print(med2) k = kernel.KGauss(med2 * 2) # fit a Gaussian to the data and draw to initialize V0 V0 = util.fit_gaussian_draw(Xtr, J, seed=r + 1, reg=1e-6) list_gwidth = np.hstack(((med2) * gwidth_factors)) besti, objs = gof.GaussFSSD.grid_search_gwidth(p, tr, V0, list_gwidth) gwidth = list_gwidth[besti] assert util.is_real_num(gwidth), 'gwidth not real. Was %s' % str(gwidth) assert gwidth > 0, 'gwidth not positive. Was %.3g' % gwidth print('After grid search, gwidth=%.3g' % gwidth) ops = { 'reg': 1e-2, 'max_iter': 40, 'tol_fun': 1e-4, 'disp': True, 'locs_bounds_frac': 10.0, 'gwidth_lb': 1e-1, 'gwidth_ub': 1e4, } V_opt, gwidth_opt, info = gof.GaussFSSD.optimize_locs_widths(p, tr, gwidth, V0, **ops) # Use the optimized parameters to construct a test k_opt = kernel.KGauss(gwidth_opt) fssd_opt = gof.FSSD(p, k_opt, V_opt, null_sim=null_sim, alpha=args.alpha) fssd_opt_result = fssd_opt.perform_test(te) return {'test_result': fssd_opt_result, 'time_secs': t.secs, 'goftest': fssd_opt, 'opt_info': info, } def job_fssdJ5q_opt(p, data_source, tr, te, r): return job_fssdJ1q_opt(p, data_source, tr, te, r, J=5) if "rbm" in args.test: if args.test_type == "mmd": q = kdata.DSGaussBernRBM(np.array(q_dist.B.detach().numpy()), np.array(q_dist.b.detach().numpy()[0]), np.array(q_dist.c.detach().numpy()[0])) else: q = GaussBernRBM(np.array(q_dist.B.detach().numpy()), np.array(q_dist.b.detach().numpy()[0]), np.array(q_dist.c.detach().numpy()[0])) p = kdata.DSGaussBernRBM(np.array(p_dist.B.detach().numpy()), np.array(p_dist.b.detach().numpy()[0]), np.array(p_dist.c.detach().numpy()[0])) elif args.test == "laplace-gaussian": mu = np.zeros((args.dim_x,)) std = np.eye(args.dim_x) q = density.Normal(mu, std) p = kdata.DSLaplace(args.dim_x, scale=1/(2. ** .5)) elif args.test == "gaussian-pert": mu = np.zeros((args.dim_x,)) std = np.eye(args.dim_x) q = density.Normal(mu, std) p = kdata.DSNormal(mu, std) data_train = p.sample(args.n_train, args.seed) data_test = p.sample(args.n_test, args.seed + 1) if args.test_type == "fssd": result = job_fssdJ5q_opt(q, p, data_train, data_test, r=args.seed) elif args.test_type == "ksd": result = job_kstein_med(q, p, data_train, data_test, r=args.seed) elif args.test_type == "lksd": result = job_lin_kstein_med(q, p, data_train, data_test, r=args.seed) elif args.test_type == "mmd": model_sample = q.sample(args.n_train + args.n_test, args.seed + 2) result = job_mmd_opt(q, p, data_train, data_test, args.seed, model_sample) print(result['test_result']) reject = result['test_result']['h0_rejected'] try_make_dirs(os.path.dirname(args.save)) with open(args.save, 'w') as f: if reject: print("reject") f.write("reject") else: print("accept") f.write("accept")
    dload_train, dload_test = get_data(args)

    # logger
    utils.makedirs(args.save)
    logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                              filepath=os.path.abspath(__file__))
    logger.info(args)

    if args.arch == "mlp":
        if args.quadratic:
            net = networks.QuadraticMLP(args.data_dim, n_hid=args.hidden_dim)
        else:
            net = networks.SmallMLP(args.data_dim, n_hid=args.hidden_dim, dropout=args.dropout)
        critic = networks.SmallMLP(args.data_dim, n_out=args.data_dim,
                                   n_hid=args.hidden_dim, dropout=args.dropout)
    elif args.arch == "mlp-large":
        net = networks.LargeMLP(args.data_dim, n_hid=args.hidden_dim, dropout=args.dropout)
        critic = networks.LargeMLP(args.data_dim, n_out=args.data_dim,
                                   n_hid=args.hidden_dim, dropout=args.dropout)
    else:
    def forward(self, x):
        s = x @ self.B
        return self.base_dist.log_prob(s).sum(1)

    def log_prob(self, x):
        # change of variables: log p(x) = log p_base(Bx) + log|det B|
        logp_plus_Z = self(x)
        log_det = self.B.det().abs().log()
        # log_det = self.B.logdet()
        return logp_plus_Z + log_det

    def sample(self, n):
        s = self.base_dist.sample((n, self.dim)).to(device)
        x = s @ self.A
        return x


kernel_net = networks.SmallMLP(args.dim, n_out=args.dim)
if args.sn > 0:  # --sn is a float flag; the negative default disables spectral norm
    kernel_net.apply(apply_spectral_norm)

np.random.seed(args.seed)
trueICA = ICA(args.dim, reverse=False)
modelICA = ICA(args.dim)
logger.info(trueICA.B)
logger.info(modelICA.B)
logger.info(trueICA.A)
logger.info(modelICA.A)
modelICA.to(device)
trueICA.to(device)
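
# A quick numerical check of the change-of-variables identity used in log_prob
# above: with a standard-normal base and x = s @ A (so B = A^{-1} and s = x @ B),
# log p_base(x @ B) + log|det B| should match the density of N(0, A^T A).
# This sketch assumes a Gaussian base purely so the answer is checkable in closed
# form (the ICA model itself typically uses a heavier-tailed base), and assumes
# torch.distributions is imported as `distributions`, as elsewhere in the repo.
def _check_change_of_variables(dim=3, n=10):
    A = torch.randn(dim, dim)
    B = torch.inverse(A)
    x = torch.randn(n, dim)
    base = distributions.Normal(torch.zeros(dim), torch.ones(dim))
    lp_cov = base.log_prob(x @ B).sum(1) + B.det().abs().log()
    mvn = distributions.MultivariateNormal(torch.zeros(dim), A.t() @ A)
    print((lp_cov - mvn.log_prob(x)).abs().max().item())  # ~0 up to float error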
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles',
                                           'moons', '2spirals', 'checkerboard', 'rings'],
                        type=str, default='moons')
    parser.add_argument('--niters', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_batch_size', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--critic_weight_decay', type=float, default=0)
    parser.add_argument('--save', type=str, default='/tmp/test_lsd')
    parser.add_argument('--mode', type=str, default="lsd", choices=['lsd', 'sm'])
    parser.add_argument('--viz_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=10000)
    parser.add_argument('--log_freq', type=int, default=100)
    parser.add_argument('--base_dist', action="store_true")
    parser.add_argument('--c_iters', type=int, default=5)
    parser.add_argument('--l2', type=float, default=10.)
    parser.add_argument('--exact_trace', action="store_true")
    parser.add_argument('--n_steps', type=int, default=10)
    args = parser.parse_args()

    # logger
    utils.makedirs(args.save)
    logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                              filepath=os.path.abspath(__file__))
    logger.info(args)

    # fit a gaussian to the training data
    init_size = 1000
    init_batch = sample_data(args, init_size).requires_grad_()
    mu, std = init_batch.mean(0), init_batch.std(0)
    base_dist = distributions.Normal(mu, std)

    # neural netz
    critic = networks.SmallMLP(2, n_out=2)
    net = networks.SmallMLP(2)

    ebm = EBM(net, base_dist if args.base_dist else None)
    ebm.to(device)
    critic.to(device)

    # for sampling
    init_fn = lambda: base_dist.sample_n(args.test_batch_size)
    cov = utils.cov(init_batch)
    sampler = HMCSampler(ebm, .3, 5, init_fn, device=device, covariance_matrix=cov)

    logger.info(ebm)
    logger.info(critic)

    # optimizers
    optimizer = optim.Adam(ebm.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay, betas=(.0, .999))
    critic_optimizer = optim.Adam(critic.parameters(), lr=args.lr,
                                  betas=(.0, .999), weight_decay=args.critic_weight_decay)

    time_meter = utils.RunningAverageMeter(0.98)
    loss_meter = utils.RunningAverageMeter(0.98)

    ebm.train()
    end = time.time()
    for itr in range(args.niters):
        optimizer.zero_grad()
        critic_optimizer.zero_grad()

        x = sample_data(args, args.batch_size)
        x.requires_grad_()

        if args.mode == "lsd":
            # our method

            # compute dlogp(x)/dx
            logp_u = ebm(x)
            sq = keep_grad(logp_u.sum(), x)
            fx = critic(x)
            # compute (dlogp(x)/dx)^T * f(x)
            sq_fx = (sq * fx).sum(-1)

            # compute/estimate Tr(df/dx)
            if args.exact_trace:
                tr_dfdx = exact_jacobian_trace(fx, x)
            else:
                tr_dfdx = approx_jacobian_trace(fx, x)

            stats = sq_fx + tr_dfdx
            loss = stats.mean()  # estimate of S(p, q)
            l2_penalty = (fx * fx).sum(1).mean() * args.l2  # penalty to enforce f \in F

            # adversarial! take c_iters critic steps per EBM step
            if args.c_iters > 0 and itr % (args.c_iters + 1) != 0:
                (-1. * loss + l2_penalty).backward()
                critic_optimizer.step()
            else:
                loss.backward()
                optimizer.step()

        elif args.mode == "sm":  # score matching for reference
            fx = ebm(x)
            dfdx = torch.autograd.grad(fx.sum(), x, retain_graph=True, create_graph=True)[0]
            eps = torch.randn_like(dfdx)  # use hutchinson here as well
            epsH = torch.autograd.grad(dfdx, x, grad_outputs=eps,
                                       create_graph=True, retain_graph=True)[0]

            trH = (epsH * eps).sum(1)
            norm_s = (dfdx * dfdx).sum(1)

            loss = (trH + .5 * norm_s).mean()
            loss.backward()
            optimizer.step()
        else:
            assert False

        loss_meter.update(loss.item())
        time_meter.update(time.time() - end)

        if itr % args.log_freq == 0:
            log_message = (
                'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.4f}({:.4f})'.format(
                    itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg))
            logger.info(log_message)

        # note: itr == args.niters - 1 (not args.niters) so the final
        # checkpoint is actually written inside range(args.niters)
        if itr % args.save_freq == 0 or itr == args.niters - 1:
            ebm.cpu()
            utils.makedirs(args.save)
            torch.save({
                'args': args,
                'state_dict': ebm.state_dict(),
            }, os.path.join(args.save, 'checkpt.pth'))
            ebm.to(device)

        if itr % args.viz_freq == 0:
            # plot dat
            plt.clf()
            npts = 100
            p_samples = toy_data.inf_train_gen(args.data, batch_size=npts ** 2)
            q_samples = sampler.sample(args.n_steps)

            ebm.cpu()

            x_enc = critic(x)
            xes = x_enc.detach().cpu().numpy()
            trans = xes.min()
            scale = xes.max() - xes.min()
            xes = (xes - trans) / scale * 8 - 4

            plt.figure(figsize=(4, 4))
            visualize_transform([p_samples, q_samples.detach().cpu().numpy(), xes],
                                ["data", "model", "embed"], [ebm], ["model"], npts=npts)

            fig_filename = os.path.join(args.save, 'figs', '{:04d}.png'.format(itr))
            utils.makedirs(os.path.dirname(fig_filename))
            plt.savefig(fig_filename)
            plt.close()

            ebm.to(device)

        end = time.time()

    logger.info('Training has finished, can I get a yeet?')
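
# `EBM` above is assumed to wrap the scalar-output network and add the base
# distribution's log-density when one is provided (see the
# `EBM(net, base_dist if args.base_dist else None)` call). A minimal sketch of
# such a wrapper, assuming torch.nn is imported as `nn`; this is an assumption
# about the interface, not necessarily the repository's implementation:
class _EBMSketch(nn.Module):
    def __init__(self, net, base_dist=None):
        super().__init__()
        self.net = net
        self.base_dist = base_dist

    def forward(self, x):
        # unnormalized log-density: f(x), plus log p_0(x) if a base dist is given
        logp = self.net(x).squeeze(-1)
        if self.base_dist is not None:
            logp = logp + self.base_dist.log_prob(x).sum(-1)
        return logp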