def do_test_per_param_optim(self, fixed_param, free_param):
    pyro.clear_param_store()

    def model():
        prior_dist = Normal(self.loc0, torch.pow(self.lam0, -0.5))
        loc_latent = pyro.sample("loc_latent", prior_dist)
        x_dist = Normal(loc_latent, torch.pow(self.lam, -0.5))
        pyro.sample("obs", x_dist, obs=self.data)
        return loc_latent

    def guide():
        loc_q = pyro.param("loc_q", torch.zeros(1, requires_grad=True))
        log_sig_q = pyro.param("log_sig_q", torch.zeros(1, requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("loc_latent", Normal(loc_q, sig_q))

    def optim_params(param_name):
        if param_name == fixed_param:
            return {"lr": 0.00}
        elif param_name == free_param:
            return {"lr": 0.01}

    adam = optim.Adam(optim_params)
    adam2 = optim.Adam(optim_params)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi2 = SVI(model, guide, adam2, loss=TraceGraph_ELBO())

    svi.step()
    adam_initial_step_count = list(
        adam.get_state()["loc_q"]["state"].items())[0][1]["step"]
    with TemporaryDirectory() as tempdir:
        filename = os.path.join(tempdir, "optimizer_state.pt")
        adam.save(filename)
        svi.step()
        adam_final_step_count = list(
            adam.get_state()["loc_q"]["state"].items())[0][1]["step"]
        adam2.load(filename)
    svi2.step()
    adam2_step_count_after_load_and_step = list(
        adam2.get_state()["loc_q"]["state"].items())[0][1]["step"]

    assert adam_initial_step_count == 1
    assert adam_final_step_count == 2
    assert adam2_step_count_after_load_and_step == 2

    free_param_unchanged = torch.equal(
        pyro.param(free_param).data, torch.zeros(1))
    fixed_param_unchanged = torch.equal(
        pyro.param(fixed_param).data, torch.zeros(1))
    assert fixed_param_unchanged and not free_param_unchanged
def do_test_per_param_optim(self, fixed_param, free_param):
    pyro.clear_param_store()

    def model():
        prior_dist = Normal(self.loc0, torch.pow(self.lam0, -0.5))
        loc_latent = pyro.sample("loc_latent", prior_dist)
        x_dist = Normal(loc_latent, torch.pow(self.lam, -0.5))
        pyro.sample("obs", x_dist, obs=self.data)
        return loc_latent

    def guide():
        loc_q = pyro.param("loc_q", torch.zeros(1, requires_grad=True))
        log_sig_q = pyro.param("log_sig_q", torch.zeros(1, requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("loc_latent", Normal(loc_q, sig_q))

    def optim_params(module_name, param_name):
        if param_name == fixed_param:
            return {'lr': 0.00}
        elif param_name == free_param:
            return {'lr': 0.01}

    adam = optim.Adam(optim_params)
    adam2 = optim.Adam(optim_params)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())
    svi2 = SVI(model, guide, adam2, loss=TraceGraph_ELBO())

    svi.step()
    adam_initial_step_count = list(
        adam.get_state()['loc_q']['state'].items())[0][1]['step']
    adam.save('adam.unittest.save')
    svi.step()
    adam_final_step_count = list(
        adam.get_state()['loc_q']['state'].items())[0][1]['step']
    adam2.load('adam.unittest.save')
    svi2.step()
    adam2_step_count_after_load_and_step = list(
        adam2.get_state()['loc_q']['state'].items())[0][1]['step']

    assert adam_initial_step_count == 1
    assert adam_final_step_count == 2
    assert adam2_step_count_after_load_and_step == 2

    free_param_unchanged = torch.equal(
        pyro.param(free_param).data, torch.zeros(1))
    fixed_param_unchanged = torch.equal(
        pyro.param(fixed_param).data, torch.zeros(1))
    assert fixed_param_unchanged and not free_param_unchanged
def do_elbo_test(self, reparameterized, n_steps, beta1, lr): logger.info(" - - - - - DO BETA-BERNOULLI ELBO TEST [repa = %s] - - - - - " % reparameterized) pyro.clear_param_store() Beta = dist.Beta if reparameterized else fakes.NonreparameterizedBeta def model(): p_latent = pyro.sample("p_latent", Beta(self.alpha0, self.beta0)) with pyro.plate("data", len(self.data)): pyro.sample("obs", dist.Bernoulli(p_latent), obs=self.data) return p_latent def guide(): alpha_q_log = pyro.param("alpha_q_log", self.log_alpha_n + 0.17) beta_q_log = pyro.param("beta_q_log", self.log_beta_n - 0.143) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) p_latent = pyro.sample("p_latent", Beta(alpha_q, beta_q), infer=dict(baseline=dict(use_decaying_avg_baseline=True))) with pyro.plate("data", len(self.data)): pass return p_latent adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)}) svi = SVI(model, guide, adam, loss=TraceGraph_ELBO()) for k in range(n_steps): svi.step() alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n) beta_error = param_abs_error("beta_q_log", self.log_beta_n) if k % 500 == 0: logger.debug("alpha_error, beta_error: %.4f, %.4f" % (alpha_error, beta_error)) assert_equal(0.0, alpha_error, prec=0.03) assert_equal(0.0, beta_error, prec=0.04)
def test_dynamic_lr(scheduler, num_steps):
    pyro.clear_param_store()

    def model():
        sample = pyro.sample('latent', Normal(torch.tensor(0.), torch.tensor(0.3)))
        return pyro.sample('obs', Normal(sample, torch.tensor(0.2)), obs=torch.tensor(0.1))

    def guide():
        loc = pyro.param('loc', torch.tensor(0.))
        scale = pyro.param('scale', torch.tensor(0.5), constraint=constraints.positive)
        pyro.sample('latent', Normal(loc, scale))

    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for epoch in range(2):
        scheduler.set_epoch(epoch)
        for _ in range(num_steps):
            svi.step()
        if epoch == 1:
            loc = pyro.param('loc').unconstrained()
            scale = pyro.param('scale').unconstrained()
            opt = scheduler.optim_objs[loc].optimizer
            assert opt.state_dict()['param_groups'][0]['lr'] == 0.02
            assert opt.state_dict()['param_groups'][0]['initial_lr'] == 0.01
            opt = scheduler.optim_objs[scale].optimizer
            assert opt.state_dict()['param_groups'][0]['lr'] == 0.02
            assert opt.state_dict()['param_groups'][0]['initial_lr'] == 0.01
    assert abs(pyro.param('loc').item()) > 1e-5
    assert abs(pyro.param('scale').item() - 0.5) > 1e-5
def test_three_indep_iarange_at_different_depths_ok():
    """
      /\
     /\ ia
    ia ia
    """
    def model():
        p = torch.tensor(0.5)
        inner_iarange = pyro.iarange("iarange1", 10, 5)
        for i in pyro.irange("irange0", 2):
            pyro.sample("x_%d" % i, dist.Bernoulli(p))
            if i == 0:
                for j in pyro.irange("irange1", 2):
                    with inner_iarange as ind:
                        pyro.sample("y_%d" % j, dist.Bernoulli(p).expand_by([len(ind)]))
            elif i == 1:
                with inner_iarange as ind:
                    pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        inner_iarange = pyro.iarange("iarange1", 10, 5)
        for i in pyro.irange("irange0", 2):
            pyro.sample("x_%d" % i, dist.Bernoulli(p))
            if i == 0:
                for j in pyro.irange("irange1", 2):
                    with inner_iarange as ind:
                        pyro.sample("y_%d" % j, dist.Bernoulli(p).expand_by([len(ind)]))
            elif i == 1:
                with inner_iarange as ind:
                    pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind)]))

    assert_ok(model, guide, TraceGraph_ELBO())
def do_inference(self, use_decaying_avg_baseline, tolerance=0.8):
    pyro.clear_param_store()
    optimizer_params = {"lr": 0.0005, "betas": (0.93, 0.999)}
    # optimizer = optim.Adam(optimizer_params)
    optimizer = optim.Adam(self.per_param_args)
    svi = SVI(self.model, self.nnguide, optimizer, loss=TraceGraph_ELBO())
    print("Doing inference with use_decaying_avg_baseline=%s" % use_decaying_avg_baseline)

    ae, be = [], []
    # do up to this many steps of inference
    for k in range(self.max_steps):
        svi.step(use_decaying_avg_baseline, torch.tensor([1.0]))
        if k % 100 == 0:
            print('.', end='')
            sys.stdout.flush()

        # compute the distance to the parameters of the true posterior
        alpha_error = param_abs_error("alpha_q", self.alpha_n)
        beta_error = param_abs_error("beta_q", self.beta_n)
        ae.append(alpha_error)
        be.append(beta_error)

        # stop inference early if we're close to the true posterior
        if alpha_error < tolerance and beta_error < tolerance:
            print("Stopped early at step: {}".format(k))
            break

    return (ae, be)
def main(model, guide, args):
    # init
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)
    logger = get_logger(args.log, __name__)
    logger.info(args)

    # load data
    X, true_counts = load_data()
    X_size = X.size(0)

    # setup svi
    def per_param_optim_args(module_name, param_name):
        def isBaselineParam(module_name, param_name):
            return 'bl_' in module_name or 'bl_' in param_name
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name)\
            else args.learning_rate
        return {'lr': lr}

    opt = optim.Adam(per_param_optim_args)
    svi = SVI(model.main, guide.main, opt, loss=TraceGraph_ELBO())

    # train
    times = [time.time()]
    logger.info(f"\nstep\t" + "epoch\t" + "elbo\t" + "time(sec)")
    for i in range(1, args.num_steps + 1):
        loss = svi.step(X)

        if (args.eval_frequency > 0 and i % args.eval_frequency == 0) or (i == 1):
            times.append(time.time())
            logger.info(f"{i:06d}\t"
                        f"{(i * args.batch_size) / X_size:.3f}\t"
                        f"{-loss / X_size:.4f}\t"
                        f"{times[-1]-times[-2]:.3f}")
def main(**kwargs):
    args = argparse.Namespace(**kwargs)
    args.batch_size = 64
    pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)

    def per_param_optim_args(module_name, param_name):
        def isBaselineParam(module_name, param_name):
            return 'bl_' in module_name or 'bl_' in param_name
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name)\
            else args.learning_rate
        return {'lr': lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)  # wy

    t0 = time.time()
    for i in range(1, args.num_steps + 1):
        loss = svi.step(X)  # wy

        if args.progress_every > 0 and i % args.progress_every == 0:
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        if args.eval_every > 0 and i % args.eval_every == 0:
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
def _valid_epoch(self, epoch): """ Validate after training an epoch :param epoch: Integer, current training epoch. :return: A log that contains information about validation """ elbo = TraceGraph_ELBO(vectorize_particles=False, num_particles=4) svi = SVI(self.model.model, self.model.guide, self.optimizer, loss=elbo) imps = ImportanceSampler(self.model.model, self.model.guide, num_samples=4) self.model.eval() self.valid_metrics.reset() with torch.no_grad(): for batch_idx, (data, target) in enumerate(self.valid_data_loader): data, target = data.to(self.device), target.to(self.device) loss = svi.evaluate_loss(observations=data) / data.shape[0] imps.sample(observations=data) log_likelihood = imps.get_log_likelihood().item() / data.shape[0] log_marginal = imps.get_log_normalizer().item() / data.shape[0] self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') self.valid_metrics.update('loss', loss) self.valid_metrics.update('log_likelihood', log_likelihood) self.valid_metrics.update('log_marginal', log_marginal) for met in self.metric_ftns: metric_val = met(self.model.model, self.model.guide, data, target, 4) self.valid_metrics.update(met.__name__, metric_val) if self.log_images: self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) return self.valid_metrics.result()
def test_nested_irange_in_elbo(self, n_steps=4000):
    pyro.clear_param_store()

    def model():
        loc_latent = pyro.sample(
            "loc_latent",
            fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5)).independent(1))
        for i in pyro.irange("outer", self.n_outer):
            for j in pyro.irange("inner_%d" % i, self.n_inner):
                pyro.sample("obs_%d_%d" % (i, j),
                            dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).independent(1),
                            obs=self.data[i][j])

    def guide():
        loc_q = pyro.param(
            "loc_q",
            torch.tensor(self.analytic_loc_n.expand(2) + 0.234, requires_grad=True))
        log_sig_q = pyro.param(
            "log_sig_q",
            torch.tensor(self.analytic_log_sig_n.expand(2) - 0.27, requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample(
            "loc_latent",
            fakes.NonreparameterizedNormal(loc_q, sig_q).independent(1),
            infer=dict(baseline=dict(use_decaying_avg_baseline=True)))
        for i in pyro.irange("outer", self.n_outer):
            for j in pyro.irange("inner_%d" % i, self.n_inner):
                pass

    guide_trace = pyro.poutine.trace(guide, graph_type="dense").get_trace()
    model_trace = pyro.poutine.trace(pyro.poutine.replay(model, trace=guide_trace),
                                     graph_type="dense").get_trace()
    assert len(model_trace.edges()) == 27
    assert len(model_trace.nodes()) == 16
    assert len(guide_trace.edges()) == 0
    assert len(guide_trace.nodes()) == 9

    adam = optim.Adam({"lr": 0.0008, "betas": (0.96, 0.999)})
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

    for k in range(n_steps):
        svi.step()

        loc_error = param_mse("loc_q", self.analytic_loc_n)
        log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
        if k % 500 == 0:
            logger.debug("loc error, log(scale) error: %.4f, %.4f" % (loc_error, log_sig_error))

    assert_equal(0.0, loc_error, prec=0.04)
    assert_equal(0.0, log_sig_error, prec=0.04)
def main(**kwargs):
    args = argparse.Namespace(**kwargs)
    args.batch_size = 64

    # WL: edited to fix a bug. =====
    # pyro.set_rng_seed(args.seed)
    if args.seed is not None:
        pyro.set_rng_seed(args.seed)
    # ==============================

    X, true_counts = load_data()
    X_size = X.size(0)

    def per_param_optim_args(module_name, param_name):
        def isBaselineParam(module_name, param_name):
            return 'bl_' in module_name or 'bl_' in param_name
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name)\
            else args.learning_rate
        return {'lr': lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)  # wy

    # WL: added for test. =====
    print(args)
    times = [time.time()]
    print("\nstep\t" + "epoch\t" + "elbo\t" + "time(sec)", flush=True)
    # ==========================

    t0 = time.time()
    for i in range(1, args.num_steps + 1):
        loss = svi.step(X)  # wy

        if args.progress_every > 0 and i % args.progress_every == 0:
            # # WL: edited for test. =====
            # print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
            #     i,
            #     (i * args.batch_size) / X_size,
            #     (time.time() - t0) / 3600,
            #     loss / X_size))
            times.append(time.time())
            print(
                f"{i:06d}\t"
                f"{(i * args.batch_size) / X_size:.3f}\t"
                f"{-loss / X_size:.4f}\t"
                f"{times[-1]-times[-2]:.3f}",
                flush=True)
            # ===========================

        if args.eval_every > 0 and i % args.eval_every == 0:
            acc, counts, error_z, error_ix = count_accuracy(
                X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(
                i, acc, counts.numpy().tolist()))
def test_iarange_wrong_size_error():
    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([1 + len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([1 + len(ind)]))

    assert_error(model, guide, TraceGraph_ELBO())
def test_dynamic_lr(scheduler):
    pyro.clear_param_store()

    def model():
        sample = pyro.sample("latent", Normal(torch.tensor(0.0), torch.tensor(0.3)))
        return pyro.sample("obs", Normal(sample, torch.tensor(0.2)), obs=torch.tensor(0.1))

    def guide():
        loc = pyro.param("loc", torch.tensor(0.0))
        scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
        pyro.sample("latent", Normal(loc, scale))

    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for epoch in range(4):
        svi.step()
        svi.step()
        loc = pyro.param("loc").unconstrained()
        scale = pyro.param("scale").unconstrained()
        opt_loc = scheduler.optim_objs[loc].optimizer
        # look up the scale parameter's own optimizer (not loc's) so the
        # assertions below actually check both schedules
        opt_scale = scheduler.optim_objs[scale].optimizer
        if issubclass(
            scheduler.pt_scheduler_constructor,
            torch.optim.lr_scheduler.ReduceLROnPlateau,
        ):
            scheduler.step(1.0)
            if epoch == 2:
                assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.1
                assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.1
            if epoch == 4:
                assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.01
                assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.01
            continue
        assert opt_loc.state_dict()["param_groups"][0]["initial_lr"] == 0.01
        assert opt_scale.state_dict()["param_groups"][0]["initial_lr"] == 0.01
        if epoch == 0:
            scheduler.step()
            assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.02
            assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.02
            assert abs(pyro.param("loc").item()) > 1e-5
            assert abs(pyro.param("scale").item() - 0.5) > 1e-5
        if epoch == 2:
            scheduler.step()
            assert opt_loc.state_dict()["param_groups"][0]["lr"] == 0.04
            assert opt_scale.state_dict()["param_groups"][0]["lr"] == 0.04
def _train_epoch(self, epoch): """ Training logic for an epoch :param epoch: Integer, current training epoch. :return: A log that contains average loss and metric in this epoch. """ if self.jit: elbo = JitTraceGraph_ELBO(vectorize_particles=False, num_particles=self.num_particles) else: elbo = TraceGraph_ELBO(vectorize_particles=False, num_particles=self.num_particles) svi = SVI(self.model.model, self.model.guide, self.optimizer, loss=elbo) self.model.train() self.train_metrics.reset() for batch_idx, (data, target) in enumerate(self.data_loader): data, target = data.to(self.device), target.to(self.device) self.optimizer.zero_grad() loss = svi.step(observations=data) self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) self.train_metrics.update('loss', loss.item()) for met in self.metric_ftns: self.train_metrics.update(met.__name__, met(target)) if batch_idx % self.log_step == 0: self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format( epoch, self._progress(batch_idx), loss.item())) self.writer.add_image( 'input', make_grid(data.cpu(), nrow=8, normalize=True)) if batch_idx == self.len_epoch: break log = self.train_metrics.result() if self.do_validation: val_log = self._valid_epoch(epoch) log.update(**{'val_' + k: v for k, v in val_log.items()}) if self.lr_scheduler is not None: self.lr_scheduler.step() return log
def do_elbo_test(self, reparameterized, n_steps, prec): logger.info( " - - - - - DO NORMALNORMAL ELBO TEST [reparameterized = %s] - - - - - " % reparameterized) pyro.clear_param_store() Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal def model(): with pyro.iarange("iarange", 2): loc_latent = pyro.sample( "loc_latent", Normal(self.loc0, torch.pow(self.lam0, -0.5))) for i, x in enumerate(self.data): pyro.sample("obs_%d" % i, dist.Normal(loc_latent, torch.pow(self.lam, -0.5)), obs=x) return loc_latent def guide(): loc_q = pyro.param( "loc_q", torch.tensor(self.analytic_loc_n.expand(2) + 0.334, requires_grad=True)) log_sig_q = pyro.param( "log_sig_q", torch.tensor(self.analytic_log_sig_n.expand(2) - 0.29, requires_grad=True)) sig_q = torch.exp(log_sig_q) with pyro.iarange("iarange", 2): loc_latent = pyro.sample("loc_latent", Normal(loc_q, sig_q)) return loc_latent adam = optim.Adam({"lr": .0015, "betas": (0.97, 0.999)}) svi = SVI(model, guide, adam, loss=TraceGraph_ELBO()) for k in range(n_steps): svi.step() loc_error = param_mse("loc_q", self.analytic_loc_n) log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n) if k % 250 == 0: logger.debug("loc error, log(scale) error: %.4f, %.4f" % (loc_error, log_sig_error)) assert_equal(0.0, loc_error, prec=prec) assert_equal(0.0, log_sig_error, prec=prec)
def on_epoch_start(self):
    total_epochs = self._n_epochs
    current_epoch = self._current_epoch
    if current_epoch > total_epochs * self.emb_train_epochs:
        if not self.start_finetune:
            pyro.clear_param_store()
            print('Switching to pyro')
            self.start_finetune = True
            self.optimizer = Adam({"lr": 1e-4})
            self.svi = SVI(self.model_wrapper.pyro_model,
                           self.model_wrapper.pyro_guide,
                           self.optimizer,
                           loss=TraceGraph_ELBO())
    else:
        if self.optimizer is None:
            self.optimizer = torch.optim.Adam(
                self.model_wrapper.model.parameters())
def do_elbo_test(self, reparameterized, n_steps, beta1, lr): logger.info( " - - - - - DO EXPONENTIAL-GAMMA ELBO TEST [repa = %s] - - - - - " % reparameterized) pyro.clear_param_store() Gamma = dist.Gamma if reparameterized else fakes.NonreparameterizedGamma def model(): lambda_latent = pyro.sample("lambda_latent", Gamma(self.alpha0, self.beta0)) with pyro.iarange("data", len(self.data)): pyro.sample("obs", dist.Exponential(lambda_latent), obs=self.data) return lambda_latent def guide(): alpha_q_log = pyro.param( "alpha_q_log", torch.tensor(self.log_alpha_n + 0.17, requires_grad=True)) beta_q_log = pyro.param( "beta_q_log", torch.tensor(self.log_beta_n - 0.143, requires_grad=True)) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample( "lambda_latent", Gamma(alpha_q, beta_q), infer=dict(baseline=dict(use_decaying_avg_baseline=True))) with pyro.iarange("data", len(self.data)): pass adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)}) svi = SVI(model, guide, adam, loss=TraceGraph_ELBO()) for k in range(n_steps): svi.step() alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n) beta_error = param_abs_error("beta_q_log", self.log_beta_n) if k % 500 == 0: logger.debug("alpha_error, beta_error: %.4f, %.4f" % (alpha_error, beta_error)) assert_equal(0.0, alpha_error, prec=0.04) assert_equal(0.0, beta_error, prec=0.04)
def _valid_epoch(self, epoch): """ Validate after training an epoch :param epoch: Integer, current training epoch. :return: A log that contains information about validation """ if self.jit: elbo = JitTraceGraph_ELBO(vectorize_particles=False, num_particles=self.num_particles) else: elbo = TraceGraph_ELBO(vectorize_particles=False, num_particles=self.num_particles) svi = SVI(self.model.model, self.model.guide, self.optimizer, loss=elbo) self.model.eval() self.valid_metrics.reset() with torch.no_grad(): for batch_idx, (data, target) in enumerate(self.valid_data_loader): data, target = data.to(self.device), target.to(self.device) loss = svi.evaluate_loss(observations=data) self.writer.set_step( (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') self.valid_metrics.update('loss', loss.item()) for met in self.metric_ftns: self.valid_metrics.update(met.__name__, met(target)) self.writer.add_image( 'input', make_grid(data.cpu(), nrow=8, normalize=True)) # add histogram of model parameters to the tensorboard for name, p in self.model.named_parameters(): self.writer.add_histogram(name, p, bins='auto') return self.valid_metrics.result()
def do_elbo_test(
    self,
    repa1,
    repa2,
    n_steps,
    prec,
    lr,
    use_nn_baseline,
    use_decaying_avg_baseline,
):
    logger.info(" - - - - - DO NORMALNORMALNORMAL ELBO TEST - - - - - -")
    logger.info(
        "[reparameterized = %s, %s; nn_baseline = %s, decaying_baseline = %s]"
        % (repa1, repa2, use_nn_baseline, use_decaying_avg_baseline))
    pyro.clear_param_store()
    Normal1 = dist.Normal if repa1 else fakes.NonreparameterizedNormal
    Normal2 = dist.Normal if repa2 else fakes.NonreparameterizedNormal

    if use_nn_baseline:

        class VanillaBaselineNN(nn.Module):
            def __init__(self, dim_input, dim_h):
                super().__init__()
                self.lin1 = nn.Linear(dim_input, dim_h)
                self.lin2 = nn.Linear(dim_h, 2)
                self.sigmoid = nn.Sigmoid()

            def forward(self, x):
                h = self.sigmoid(self.lin1(x))
                return self.lin2(h)

        loc_prime_baseline = pyro.module("loc_prime_baseline", VanillaBaselineNN(2, 5))
    else:
        loc_prime_baseline = None

    def model():
        with pyro.plate("plate", 2):
            loc_latent_prime = pyro.sample(
                "loc_latent_prime",
                Normal1(self.loc0, torch.pow(self.lam0, -0.5)))
            loc_latent = pyro.sample(
                "loc_latent",
                Normal2(loc_latent_prime, torch.pow(self.lam0, -0.5)))
        with pyro.plate("data", len(self.data)):
            pyro.sample(
                "obs",
                dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).expand_by(self.data.shape[:1]),
                obs=self.data,
            )
        return loc_latent

    # note that the exact posterior is not mean field!
    def guide():
        loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.334)
        log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.expand(2) - 0.29)
        loc_q_prime = pyro.param("loc_q_prime", torch.tensor([-0.34, 0.52]))
        kappa_q = pyro.param("kappa_q", torch.tensor([0.74]))
        log_sig_q_prime = pyro.param("log_sig_q_prime", -0.5 * torch.log(1.2 * self.lam0))
        sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(log_sig_q_prime)
        with pyro.plate("plate", 2):
            loc_latent = pyro.sample(
                "loc_latent",
                Normal2(loc_q, sig_q),
                infer=dict(baseline=dict(
                    use_decaying_avg_baseline=use_decaying_avg_baseline)),
            )
            pyro.sample(
                "loc_latent_prime",
                Normal1(
                    kappa_q.expand_as(loc_latent) * loc_latent + loc_q_prime,
                    sig_q_prime,
                ),
                infer=dict(baseline=dict(
                    nn_baseline=loc_prime_baseline,
                    nn_baseline_input=loc_latent,
                    use_decaying_avg_baseline=use_decaying_avg_baseline,
                )),
            )
        with pyro.plate("data", len(self.data)):
            pass
        return loc_latent

    adam = optim.Adam({"lr": 0.0015, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

    for k in range(n_steps):
        svi.step()

        loc_error = param_mse("loc_q", self.analytic_loc_n)
        log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
        loc_prime_error = param_mse("loc_q_prime", 0.5 * self.loc0)
        kappa_error = param_mse("kappa_q", 0.5 * torch.ones(1))
        log_sig_prime_error = param_mse("log_sig_q_prime", -0.5 * torch.log(2.0 * self.lam0))
        if k % 500 == 0:
            logger.debug("errors: %.4f, %.4f" % (loc_error, log_sig_error))
            logger.debug(", %.4f, %.4f" % (loc_prime_error, log_sig_prime_error))
            logger.debug(", %.4f" % kappa_error)

    assert_equal(0.0, loc_error, prec=prec)
    assert_equal(0.0, log_sig_error, prec=prec)
    assert_equal(0.0, loc_prime_error, prec=prec)
    assert_equal(0.0, log_sig_prime_error, prec=prec)
    assert_equal(0.0, kappa_error, prec=prec)
        global init_state
        init_state = reset_init_state()
        survive = guide(500)
        results.append(survive)
        self.echo = False
        if np.mean(results) > imme_time * 0.9 and imme_time < MAXTIME:
            imme_time = imme_time * 2
            print("update training max_time to", imme_time)


agent = AgentModel()
guide = agent.guide
model = agent.model

learning_rate = 2e-5
optimizer = optim.Adam({"lr": learning_rate})
svi = SVI(model, guide, optimizer, loss=TraceGraph_ELBO())


def optimize():
    global imme_time
    loss = 0
    print("Optimizing...")
    for t in range(num_steps):
        global init_state
        init_state = reset_init_state()
        loss += svi.step(imme_time)
        if (t % 1000 == 0) and (t > 0):
            print("at {} step loss is {}".format(t, loss / t))


def train(epoch=2, batch_size=10):
def do_elbo_test(self, reparameterized, n_steps, lr, prec, beta1,
                 difficulty=1.0, model_permutation=False):
    n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized \
        else len(self.q_topo_sort)
    logger.info((
        " - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " +
        "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -"
    ) % (self.N, (2**self.N) - 1, reparameterized, n_repa_nodes,
         len(self.q_topo_sort), model_permutation))
    pyro.clear_param_store()

    # check graph structure is as expected but only for N=2
    if self.N == 2:
        guide_trace = pyro.poutine.trace(
            self.guide, graph_type="dense").get_trace(
                reparameterized=reparameterized,
                model_permutation=model_permutation,
                difficulty=difficulty)

        expected_nodes = set([
            'log_sig_1R', 'kappa_1_1L', '_INPUT', 'constant_term_loc_latent_1R',
            '_RETURN', 'loc_latent_1R', 'loc_latent_1', 'constant_term_loc_latent_1',
            'loc_latent_1L', 'constant_term_loc_latent_1L', 'log_sig_1L',
            'kappa_1_1R', 'kappa_1R_1L', 'log_sig_1'
        ])
        expected_edges = set([('loc_latent_1R', 'loc_latent_1'),
                              ('loc_latent_1L', 'loc_latent_1R'),
                              ('loc_latent_1L', 'loc_latent_1')])
        assert expected_nodes == set(guide_trace.nodes)
        assert expected_edges == set(guide_trace.edges)

    adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
    svi = SVI(self.model, self.guide, adam, loss=TraceGraph_ELBO())

    for step in range(n_steps):
        t0 = time.time()
        svi.step(reparameterized=reparameterized,
                 model_permutation=model_permutation,
                 difficulty=difficulty)

        if step % 50 == 0 or step == n_steps - 1:
            log_sig_errors = []
            for node in self.target_lambdas:
                target_log_sig = -0.5 * torch.log(self.target_lambdas[node])
                log_sig_error = param_mse('log_sig_' + node, target_log_sig)
                log_sig_errors.append(log_sig_error)
            max_log_sig_error = np.max(log_sig_errors)
            min_log_sig_error = np.min(log_sig_errors)
            mean_log_sig_error = np.mean(log_sig_errors)
            leftmost_node = self.q_topo_sort[0]
            leftmost_constant_error = param_mse(
                'constant_term_' + leftmost_node, self.target_leftmost_constant)
            almost_leftmost_constant_error = param_mse(
                'constant_term_' + leftmost_node[:-1] + 'R',
                self.target_almost_leftmost_constant)

            logger.debug(
                "[mean function constant errors (partial)] %.4f %.4f" %
                (leftmost_constant_error, almost_leftmost_constant_error))
            logger.debug(
                "[min/mean/max log(scale) errors] %.4f %.4f %.4f" %
                (min_log_sig_error, mean_log_sig_error, max_log_sig_error))
            logger.debug("[step time = %.3f; N = %d; step = %d]\n" %
                         (time.time() - t0, self.N, step))

    assert_equal(0.0, max_log_sig_error, prec=prec)
    assert_equal(0.0, leftmost_constant_error, prec=prec)
    assert_equal(0.0, almost_leftmost_constant_error, prec=prec)
def do_elbo_test(self, reparameterized, n_steps, lr, prec, difficulty=1.0):
    n_repa_nodes = torch.sum(
        self.which_nodes_reparam) if not reparameterized else self.N
    logger.info(
        " - - - - - DO GAUSSIAN %d-CHAIN ELBO TEST [reparameterized = %s; %d/%d] - - - - - " %
        (self.N, reparameterized, n_repa_nodes, self.N))
    if self.N < 0:
        def array_to_string(y):
            return str(map(lambda x: "%.3f" % x.detach().cpu().numpy()[0], y))

        logger.debug("lambdas: " + array_to_string(self.lambdas))
        logger.debug("target_mus: " + array_to_string(self.target_mus[1:]))
        logger.debug("target_kappas: " + array_to_string(self.target_kappas[1:]))
        logger.debug("lambda_posts: " + array_to_string(self.lambda_posts[1:]))
        logger.debug("lambda_tilde_posts: " + array_to_string(self.lambda_tilde_posts))

    pyro.clear_param_store()
    adam = optim.Adam({"lr": lr, "betas": (0.95, 0.999)})
    elbo = TraceGraph_ELBO()
    loss_and_grads = elbo.loss_and_grads
    # loss_and_grads = elbo.jit_loss_and_grads  # This fails.
    svi = SVI(self.model, self.guide, adam, loss=elbo.loss, loss_and_grads=loss_and_grads)

    for step in range(n_steps):
        t0 = time.time()
        svi.step(reparameterized=reparameterized, difficulty=difficulty)

        if step % 5000 == 0 or step == n_steps - 1:
            kappa_errors, log_sig_errors, loc_errors = [], [], []
            for k in range(1, self.N + 1):
                if k != self.N:
                    kappa_error = param_mse("kappa_q_%d" % k, self.target_kappas[k])
                    kappa_errors.append(kappa_error)
                loc_errors.append(param_mse("loc_q_%d" % k, self.target_mus[k]))
                log_sig_error = param_mse(
                    "log_sig_q_%d" % k, -0.5 * torch.log(self.lambda_posts[k]))
                log_sig_errors.append(log_sig_error)

            max_errors = (np.max(loc_errors), np.max(log_sig_errors), np.max(kappa_errors))
            min_errors = (np.min(loc_errors), np.min(log_sig_errors), np.min(kappa_errors))
            mean_errors = (np.mean(loc_errors), np.mean(log_sig_errors), np.mean(kappa_errors))
            logger.debug(
                "[max errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" % max_errors)
            logger.debug(
                "[min errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" % min_errors)
            logger.debug(
                "[mean errors] (loc, log_scale, kappa) = (%.4f, %.4f, %.4f)" % mean_errors)
            logger.debug("[step time = %.3f; N = %d; step = %d]\n" %
                         (time.time() - t0, self.N, step))

    assert_equal(0.0, max_errors[0], prec=prec)
    assert_equal(0.0, max_errors[1], prec=prec)
    assert_equal(0.0, max_errors[2], prec=prec)
def main(**kwargs):
    args = argparse.Namespace(**kwargs)

    if 'save' in args:
        if os.path.exists(args.save):
            raise RuntimeError('Output file "{}" already exists.'.format(args.save))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:
        def base_z_pres_prior_p(t):
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing.
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)

    model_arg_keys = [
        'window_size', 'rnn_hidden_size', 'decoder_output_bias',
        'decoder_output_use_sigmoid', 'baseline_scalar', 'encoder_net',
        'decoder_net', 'predict_net', 'embed_net', 'bl_predict_net',
        'non_linearity', 'pos_prior_mean', 'pos_prior_sd', 'scale_prior_mean',
        'scale_prior_sd'
    ]
    model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
    air = AIR(num_steps=args.model_steps,
              x_size=50,
              use_masking=not args.no_masking,
              use_baselines=not args.no_baselines,
              z_what_size=args.encoder_latent_size,
              use_cuda=args.cuda,
              **model_args)

    if args.verbose:
        print(air)
        print(args)

    if 'load' in args:
        print('Loading parameters...')
        air.load_state_dict(torch.load(args.load))

    # Viz sample from prior.
    if args.viz:
        vis = visdom.Visdom(env=args.visdom_env)
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def isBaselineParam(module_name, param_name):
        return 'bl_' in module_name or 'bl_' in param_name

    def per_param_optim_args(module_name, param_name):
        lr = args.baseline_learning_rate if isBaselineParam(
            module_name, param_name) else args.learning_rate
        return {'lr': lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        loss = svi.step(X, batch_size=args.batch_size,
                        z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        if args.viz and i % args.viz_every == 0:
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(
                X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(
                i, acc, counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        if 'save' in args and i % args.save_every == 0:
            print('Saving parameters...')
            torch.save(air.state_dict(), args.save)
def embed(data, embed_args_dict, sample_args_dict):
    r"""Fits a combined matrix factorization and linear regression model on
    the observed network data.

    Args:
        data: This is expected to be the data object returned from
            bitcoin.BitcoinOTC(), and this script probably won't work
            otherwise.
        embed_args_dict (dict): A dictionary which is expected to contain the
            following keys:

            embed_dim (int): An integer specifying the embedding dimension.
            omega_model_scale (float): A positive float which specifies the
                prior variance on the embedding vectors.
            obs_scale (float): A positive float which specifies the variance
                for the observed logit scaled ratings.
            krein (bool): Specifies whether to use a Krein style inner product
                (a difference of two inner products; krein = True) or a
                regular inner product (krein = False) between embedding
                vectors for the matrix factorization part of the model. This
                implicitly assumes that the embedding dimension is even; if
                not, then an error will likely pop up somewhere.
            learning_rate (float): Learning rate for the ADAM optimizer.
            num_iters (int): Number of iterations to perform SVI.
            logging (bool): If True, then ELBO updates and the time elapsed
                are output every 500 iterations, unless num_iters <= 1000 in
                which case it is set to 100.
        sample_args_dict (dict): A dictionary which handles the subsampling
            procedure used on the data object. Check the docstring for the
            data object for the relevant details.

    Returns:
        logits (): The estimated logit probability of belonging to the
            truthful class
    """
    # Extract keys from embed_args_dict
    embed_dim = embed_args_dict['embed_dim']
    omega_model_scale = embed_args_dict['omega_model_scale']
    obs_scale = embed_args_dict['obs_scale']
    learning_rate = embed_args_dict['learning_rate']
    num_iters = embed_args_dict['num_iters']
    logging_ind = embed_args_dict['logging']

    # Set logging printing rate
    if (num_iters > 1000):
        log_update = 500
    else:
        log_update = 100

    # Make sure logging actually works
    logging.basicConfig(format='%(message)s', level=logging.INFO)

    # Define guide object for embedding model
    def guide(data, node_ind, edge_ind, edge_list):
        r"""Defines a variational family to use to fit an approximate
        posterior distribution for the probability model defined in model."""
        # Deleting arguments not used in the guide for linting purposes
        del edge_ind, edge_list

        # Parameters governing the priors on the embedding vectors
        # omega_loc should have shape [embed_dim, data.num_nodes]
        omega_loc = pyro.param(
            'omega_loc',
            lambda: torch.randn(embed_dim, data.num_nodes) / np.sqrt(embed_dim))
        # omega_scale should be a single positive tensor
        omega_scale = pyro.param('omega_scale',
                                 torch.tensor(1.0),
                                 constraint=constraints.positive)

        # Parameters governing the prior for the linear regression
        # beta_loc should be of shape [embed_dim]
        beta_loc = pyro.param('beta_loc', 0.5 * torch.randn(embed_dim))
        # beta_scale should be a single positive tensor
        beta_scale = pyro.param('beta_scale',
                                torch.tensor(1.0),
                                constraint=constraints.positive)
        # mu_loc should be a single tensor
        mu_loc = pyro.param('mu_loc', torch.tensor([0.0]))
        # mu_scale should be a single positive tensor
        mu_scale = pyro.param('mu_scale',
                              torch.tensor(1.0),
                              constraint=constraints.positive)

        # Sample the coefficient vector and intercept for linear regression
        beta = pyro.sample(
            'beta',
            dist.Normal(loc=beta_loc,
                        scale=beta_scale * torch.ones(embed_dim)).to_event(1))
        mu = pyro.sample('mu', dist.Normal(mu_loc, mu_scale).to_event(1))

        # Handle the subsampling of the embedding vectors
        with poutine.scale(scale=data.num_nodes / len(node_ind)):
            omega = pyro.sample(
                'omega',
                dist.Normal(loc=omega_loc[:, node_ind],
                            scale=omega_scale).to_event(2))

        return beta, mu, omega

    # Defines the model to use for SVI when using the usual inner product
    def model_ip(data, node_ind, edge_ind, edge_list):
        r"""Defines a probabilistic model for the observed network data."""
        # Define priors on the regression coefficients
        mu = pyro.sample(
            'mu',
            dist.Normal(torch.tensor([0.0]), torch.tensor([2.0])).to_event(1))
        beta = pyro.sample(
            'beta',
            dist.Normal(loc=torch.zeros(embed_dim),
                        scale=torch.tensor(2.0)).to_event(1))

        # Define prior on the embedding vectors, with subsampling
        with poutine.scale(scale=data.num_nodes / len(node_ind)):
            omega = pyro.sample(
                'omega',
                dist.Normal(loc=torch.zeros(embed_dim, len(node_ind)),
                            scale=omega_model_scale).to_event(2))

        # Before proceeding further, define a list t which acts as the
        # inverse function of node_ind - i.e it takes a number in node_ind
        # to its index location
        t = torch.zeros(node_ind.max() + 1, dtype=torch.long)
        t[node_ind] = torch.arange(len(node_ind))

        # Create mask corresponding to entries of ind which lie within the
        # training set (i.e data.train_nodes)
        gt_data = data.gt[node_ind]
        obs_mask = np.isin(node_ind, data.nodes_train).tolist()
        gt_data[gt_data != gt_data] = 0.0
        obs_mask = torch.tensor(obs_mask, dtype=torch.bool)

        # Compute logits, compute relevant parts of sample
        if sum(obs_mask) != 0:
            logit_prob = mu + torch.mv(omega.t(), beta)
            with poutine.scale(scale=len(data.nodes_train) / sum(obs_mask)):
                pyro.sample(
                    'trust',
                    dist.Bernoulli(logits=logit_prob[obs_mask]).independent(1),
                    obs=gt_data[obs_mask])

        # Begin extracting the relevant components of the gram matrix
        # formed by omega. Note that to extract the relevant indices,
        # we need to account for the change in indexing induced by
        # subsampling omega
        gram = torch.mm(omega.t(), omega)
        gram_sample = gram[t[edge_list[0, :]], t[edge_list[0, :]]]

        # Finally draw terms corresponding to the edges
        with poutine.scale(scale=data.num_edges / len(edge_ind)):
            pyro.sample('a',
                        dist.Normal(loc=gram_sample,
                                    scale=obs_scale).to_event(1),
                        obs=data.edge_weight_logit[edge_ind])

    # Defines the model to use for SVI when using the Krein inner product
    def model_krein(data, node_ind, edge_ind, edge_list):
        r"""Defines a probabilistic model for the observed network data."""
        # Define priors on the regression coefficients
        mu = pyro.sample(
            'mu',
            dist.Normal(torch.tensor([0.0]), torch.tensor([2.0])).to_event(1))
        beta = pyro.sample(
            'beta',
            dist.Normal(loc=torch.zeros(embed_dim),
                        scale=torch.tensor(2.0)).to_event(1))

        # Define prior on the embedding vectors, with subsampling
        with poutine.scale(scale=data.num_nodes / len(node_ind)):
            omega = pyro.sample(
                'omega',
                dist.Normal(loc=torch.zeros(embed_dim, len(node_ind)),
                            scale=omega_model_scale).to_event(2))

        # Before proceeding further, define a list t which acts as the
        # inverse function of node_ind - i.e it takes a number in node_ind
        # to its index location
        t = torch.zeros(node_ind.max() + 1, dtype=torch.long)
        t[node_ind] = torch.arange(len(node_ind))

        # Create mask corresponding to entries of ind which lie within the
        # training set (i.e data.train_nodes)
        gt_data = data.gt[node_ind]
        obs_mask = np.isin(node_ind, data.nodes_train).tolist()
        gt_data[gt_data != gt_data] = 0.0
        obs_mask = torch.tensor(obs_mask, dtype=torch.bool)

        # Compute logits, compute relevant parts of sample
        if sum(obs_mask) != 0:
            logit_prob = mu + torch.mv(omega.t(), beta)
            with poutine.scale(scale=len(data.nodes_train) / sum(obs_mask)):
                pyro.sample(
                    'trust',
                    dist.Bernoulli(logits=logit_prob[obs_mask]).independent(1),
                    obs=gt_data[obs_mask])

        # Begin extracting the relevant components of the gram matrix
        # formed by omega. Note that to extract the relevant indices,
        # we need to account for the change in indexing induced by
        # subsampling omega
        gram_pos = torch.mm(omega[:int(embed_dim / 2), :].t(),
                            omega[:int(embed_dim / 2), :])
        gram_neg = torch.mm(omega[int(embed_dim / 2):, :].t(),
                            omega[int(embed_dim / 2):, :])
        gram = gram_pos - gram_neg
        gram_sample = gram[t[edge_list[0, :]], t[edge_list[0, :]]]

        # Finally draw terms corresponding to the edges
        with poutine.scale(scale=data.num_edges / len(edge_ind)):
            pyro.sample('a',
                        dist.Normal(loc=gram_sample,
                                    scale=obs_scale).to_event(1),
                        obs=data.edge_weight_logit[edge_ind])

    # Define SVI object depending on if we're using a positive definite
    # bilinear form on embedding vectors or the Krein inner product
    if embed_args_dict['krein']:
        svi = SVI(model_krein,
                  guide,
                  optim.Adam({"lr": learning_rate}),
                  loss=TraceGraph_ELBO())
    else:
        svi = SVI(model_ip,
                  guide,
                  optim.Adam({"lr": learning_rate}),
                  loss=TraceGraph_ELBO())

    # Begin optimization
    # Keep track of time/optimizing if desired
    if logging_ind:
        time_store = []
        t0 = time.time()
        elbo = []

    pyro.clear_param_store()
    for i in range(num_iters):
        # Really bad error handling for when the subsampling code for the
        # random walk decides to break
        count = 0
        while (count < 20):
            try:
                subsample_dict = data.subsample(**sample_args_dict)
                count = 30
            except IndexError:
                count += 1

        elbo_val = svi.step(data, **subsample_dict)

        if logging_ind & (i % log_update == 0) & (i > 0):
            elbo.append(elbo_val)
            t1 = time.time()
            time_store.append(t1 - t0)
            logging.info('Elbo loss: {}'.format(elbo_val))
            logging.info('Expected completion time: {}s'.format(
                int(np.average(time_store) * (num_iters - i) / log_update)))
            t0 = time.time()

    # Extract the variational parameters and return them
    vp_dict = {}
    vp_dict['mu_loc'] = pyro.param('mu_loc')
    vp_dict['beta_loc'] = pyro.param('beta_loc')
    vp_dict['omega_loc'] = pyro.param('omega_loc')
    vp_dict['mu_scale'] = pyro.param('mu_scale')
    vp_dict['beta_scale'] = pyro.param('beta_scale')
    vp_dict['omega_scale'] = pyro.param('omega_scale')

    if 'elbo' in locals():
        return vp_dict, elbo
    else:
        return vp_dict
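# Hedged usage sketch (not part of the original source): the docstring of
# embed() above documents the keys expected in embed_args_dict, so a call
# might look like the commented example below. The numeric values and the
# contents of sample_args_dict are hypothetical placeholders; the real
# subsampling arguments come from the data object's subsample() docstring.
#
#     data = bitcoin.BitcoinOTC()
#     embed_args = {
#         'embed_dim': 32,           # must be even when krein=True
#         'omega_model_scale': 1.0,
#         'obs_scale': 1.0,
#         'krein': True,
#         'learning_rate': 0.01,
#         'num_iters': 5000,
#         'logging': True,
#     }
#     sample_args = {...}            # see data.subsample() for its arguments
#     # with logging=True, embed() returns the variational parameters plus
#     # the recorded ELBO trace
#     vp_dict, elbo_trace = embed(data, embed_args, sample_args)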
def do_inference(use_baseline=True):
    def guide(data):
        a = pyro.param("alpha_q", torch.tensor(15.0), constraint=constraints.positive)
        b = pyro.param("beta_q", torch.tensor(15.0), constraint=constraints.positive)
        baseline = dict(baseline={
            'use_decaying_avg_baseline': use_baseline,
            'baseline_beta': 0.9
        })
        pyro.sample("z_fairness", NonreparameterizedBeta(a, b), infer=baseline)

    # Inference
    pyro.clear_param_store()
    adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
    optimizer = Adam(adam_params)
    alpha0, beta0 = 10.0, 10.0
    svi = SVI(model(alpha0, beta0), guide, optimizer, loss=TraceGraph_ELBO())

    NPos = 10
    NNeg = 5
    data = torch.tensor(NPos * [1.0] + NNeg * [0.0])

    def param_abs_error(name, target):
        return torch.sum(torch.abs(target - pyro.param(name))).item()

    # True parameters
    true_alpha = data.sum() + alpha0
    true_beta = len(data) - data.sum() + beta0

    # Run
    n_steps = 10000
    for step in tqdm(range(n_steps)):
        svi.step(data)
        # compute the distance to the parameters of the true posterior
        alpha_error = param_abs_error("alpha_q", true_alpha)
        beta_error = param_abs_error("beta_q", true_beta)
        # stop inference early if we're close to the true posterior
        if alpha_error < 0.8 and beta_error < 0.8:
            break

    alpha_q = pyro.param("alpha_q").item()
    beta_q = pyro.param("beta_q").item()

    inferred_mean = alpha_q / (alpha_q + beta_q)
    factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
    inferred_std = inferred_mean * math.sqrt(factor)

    print("Parameters after {} steps: {} {} {} {}".format(
        step, true_alpha, true_beta, alpha_q, beta_q))
    print("\nbased on the data and our prior belief, the fairness " +
          "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))
    print("True posterior based on counting real and pseudoflips: {}".format(
        (NPos + 10.0) / (NPos + NNeg + 20.0)))
def _test_plate_in_elbo(self, n_superfluous_top, n_superfluous_bottom, n_steps, lr=0.0012):
    pyro.clear_param_store()
    self.data_tensor = torch.zeros(9, 2)
    for _out in range(self.n_outer):
        for _in in range(self.n_inner):
            self.data_tensor[3 * _out + _in, :] = self.data[_out][_in]
    self.data_as_list = [
        self.data_tensor[0:4, :],
        self.data_tensor[4:7, :],
        self.data_tensor[7:9, :],
    ]

    def model():
        loc_latent = pyro.sample(
            "loc_latent",
            fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5)).to_event(1),
        )
        for i in pyro.plate("outer", 3):
            x_i = self.data_as_list[i]
            with pyro.plate("inner_%d" % i, x_i.size(0)):
                for k in range(n_superfluous_top):
                    z_i_k = pyro.sample(
                        "z_%d_%d" % (i, k),
                        fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]),
                    )
                    assert z_i_k.shape == (4 - i,)
                obs_i = pyro.sample(
                    "obs_%d" % i,
                    dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1),
                    obs=x_i,
                )
                assert obs_i.shape == (4 - i, 2)
                for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom):
                    z_i_k = pyro.sample(
                        "z_%d_%d" % (i, k),
                        fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]),
                    )
                    assert z_i_k.shape == (4 - i,)

    pt_loc_baseline = torch.nn.Linear(1, 1)
    pt_superfluous_baselines = []
    for k in range(n_superfluous_top + n_superfluous_bottom):
        pt_superfluous_baselines.extend([
            torch.nn.Linear(2, 4),
            torch.nn.Linear(2, 3),
            torch.nn.Linear(2, 2)
        ])

    def guide():
        loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094)
        log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.expand(2) - 0.07)
        sig_q = torch.exp(log_sig_q)
        trivial_baseline = pyro.module("loc_baseline", pt_loc_baseline)
        baseline_value = trivial_baseline(torch.ones(1)).squeeze()
        loc_latent = pyro.sample(
            "loc_latent",
            fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1),
            infer=dict(baseline=dict(baseline_value=baseline_value)),
        )

        for i in pyro.plate("outer", 3):
            with pyro.plate("inner_%d" % i, 4 - i):
                for k in range(n_superfluous_top + n_superfluous_bottom):
                    z_baseline = pyro.module(
                        "z_baseline_%d_%d" % (i, k),
                        pt_superfluous_baselines[3 * k + i],
                    )
                    baseline_value = z_baseline(loc_latent.detach())
                    mean_i = pyro.param("mean_%d_%d" % (i, k), 0.5 * torch.ones(4 - i))
                    z_i_k = pyro.sample(
                        "z_%d_%d" % (i, k),
                        fakes.NonreparameterizedNormal(mean_i, 1),
                        infer=dict(baseline=dict(baseline_value=baseline_value)),
                    )
                    assert z_i_k.shape == (4 - i,)

    def per_param_callable(param_name):
        if "baseline" in param_name:
            return {"lr": 0.010, "betas": (0.95, 0.999)}
        else:
            return {"lr": lr, "betas": (0.95, 0.999)}

    adam = optim.Adam(per_param_callable)
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

    for step in range(n_steps):
        svi.step()

        loc_error = param_abs_error("loc_q", self.analytic_loc_n)
        log_sig_error = param_abs_error("log_sig_q", self.analytic_log_sig_n)

        if n_superfluous_top > 0 or n_superfluous_bottom > 0:
            superfluous_errors = []
            for k in range(n_superfluous_top + n_superfluous_bottom):
                mean_0_error = torch.sum(torch.pow(pyro.param("mean_0_%d" % k), 2.0))
                mean_1_error = torch.sum(torch.pow(pyro.param("mean_1_%d" % k), 2.0))
                mean_2_error = torch.sum(torch.pow(pyro.param("mean_2_%d" % k), 2.0))
                superfluous_error = torch.max(
                    torch.max(mean_0_error, mean_1_error), mean_2_error)
                superfluous_errors.append(superfluous_error.detach().cpu().numpy())

        if step % 500 == 0:
            logger.debug("loc error, log(scale) error: %.4f, %.4f" %
                         (loc_error, log_sig_error))
            if n_superfluous_top > 0 or n_superfluous_bottom > 0:
                logger.debug("superfluous error: %.4f" % np.max(superfluous_errors))

    assert_equal(0.0, loc_error, prec=0.04)
    assert_equal(0.0, log_sig_error, prec=0.05)
    if n_superfluous_top > 0 or n_superfluous_bottom > 0:
        assert_equal(0.0, np.max(superfluous_errors), prec=0.04)
def test_elbo_mapdata(map_type, batch_size, n_steps, lr):
    # normal-normal: known covariance
    lam0 = torch.tensor([0.1, 0.1])  # precision of prior
    loc0 = torch.tensor([0.0, 0.5])  # prior mean
    # known precision of observation noise
    lam = torch.tensor([6.0, 4.0])
    data = []
    sum_data = torch.zeros(2)

    def add_data_point(x, y):
        data.append(torch.tensor([x, y]))
        sum_data.data.add_(data[-1].data)

    add_data_point(0.1, 0.21)
    add_data_point(0.16, 0.11)
    add_data_point(0.06, 0.31)
    add_data_point(-0.01, 0.07)
    add_data_point(0.23, 0.25)
    add_data_point(0.19, 0.18)
    add_data_point(0.09, 0.41)
    add_data_point(-0.04, 0.17)

    data = torch.stack(data)
    n_data = torch.tensor([float(len(data))])
    analytic_lam_n = lam0 + n_data.expand_as(lam) * lam
    analytic_log_sig_n = -0.5 * torch.log(analytic_lam_n)
    analytic_loc_n = sum_data * (lam / analytic_lam_n) + loc0 * (lam0 / analytic_lam_n)

    logger.debug("DOING ELBO TEST [bs = {}, map_type = {}]".format(batch_size, map_type))
    pyro.clear_param_store()

    def model():
        loc_latent = pyro.sample(
            "loc_latent", dist.Normal(loc0, torch.pow(lam0, -0.5)).to_event(1))
        if map_type == "iplate":
            for i in pyro.plate("aaa", len(data), batch_size):
                pyro.sample(
                    "obs_%d" % i,
                    dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1),
                    obs=data[i],
                ),
        elif map_type == "plate":
            with pyro.plate("aaa", len(data), batch_size) as ind:
                pyro.sample(
                    "obs",
                    dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1),
                    obs=data[ind],
                ),
        else:
            for i, x in enumerate(data):
                pyro.sample(
                    "obs_%d" % i,
                    dist.Normal(loc_latent, torch.pow(lam, -0.5)).to_event(1),
                    obs=x,
                )
        return loc_latent

    def guide():
        loc_q = pyro.param(
            "loc_q", analytic_loc_n.detach().clone() + torch.tensor([-0.18, 0.23]))
        log_sig_q = pyro.param(
            "log_sig_q",
            analytic_log_sig_n.detach().clone() - torch.tensor([-0.18, 0.23]),
        )
        sig_q = torch.exp(log_sig_q)
        pyro.sample("loc_latent", dist.Normal(loc_q, sig_q).to_event(1))
        if map_type == "iplate" or map_type is None:
            for i in pyro.plate("aaa", len(data), batch_size):
                pass
        elif map_type == "plate":
            # dummy plate to do subsampling for observe
            with pyro.plate("aaa", len(data), batch_size):
                pass
        else:
            pass

    adam = optim.Adam({"lr": lr})
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

    for k in range(n_steps):
        svi.step()

        loc_error = torch.sum(
            torch.pow(analytic_loc_n - pyro.param("loc_q"), 2.0))
        log_sig_error = torch.sum(
            torch.pow(analytic_log_sig_n - pyro.param("log_sig_q"), 2.0))

        if k % 500 == 0:
            logger.debug("errors - {}, {}".format(loc_error, log_sig_error))

    assert_equal(loc_error.item(), 0, prec=0.05)
    assert_equal(log_sig_error.item(), 0, prec=0.06)
    print('15')

    def isBaselineParam(module_name, param_name):
        return 'bl_' in module_name or 'bl_' in param_name

    print('16')

    def per_param_optim_args(module_name, param_name):
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name) else args.learning_rate
        return {'lr': lr}

    print('17')

    adam = optim.Adam(per_param_optim_args)
    elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    print('18')

    t0 = time.time()
    examples_to_viz = X[5:10]

    print('19')
    print('Epochs starting')

    for i in range(1, args.num_steps + 1):
        print("Epoch: {}".format(i))

        loss = svi.step(X, batch_size=args.batch_size,
                        z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(