def do_test_per_param_optim(self, fixed_param, free_param):
    pyro.clear_param_store()

    def model():
        prior_dist = Normal(self.mu0, torch.pow(self.lam0, -0.5))
        mu_latent = pyro.sample("mu_latent", prior_dist)
        x_dist = Normal(mu_latent, torch.pow(self.lam, -0.5))
        pyro.observe("obs", x_dist, self.data)
        return mu_latent

    def guide():
        mu_q = pyro.param("mu_q", Variable(torch.zeros(1), requires_grad=True))
        log_sig_q = pyro.param("log_sig_q", Variable(torch.zeros(1), requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("mu_latent", Normal(mu_q, sig_q))

    def optim_params(module_name, param_name, tags):
        if param_name == fixed_param:
            return {'lr': 0.00}
        elif param_name == free_param:
            return {'lr': 0.01}

    adam = optim.Adam(optim_params)
    adam2 = optim.Adam(optim_params)
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)
    svi2 = SVI(model, guide, adam2, loss="ELBO", trace_graph=True)

    svi.step()
    adam_initial_step_count = list(adam.get_state()['mu_q']['state'].items())[0][1]['step']
    adam.save('adam.unittest.save')
    svi.step()
    adam_final_step_count = list(adam.get_state()['mu_q']['state'].items())[0][1]['step']

    adam2.load('adam.unittest.save')
    svi2.step()
    adam2_step_count_after_load_and_step = list(adam2.get_state()['mu_q']['state'].items())[0][1]['step']

    assert adam_initial_step_count == 1
    assert adam_final_step_count == 2
    assert adam2_step_count_after_load_and_step == 2

    free_param_unchanged = torch.equal(pyro.param(free_param).data, torch.zeros(1))
    fixed_param_unchanged = torch.equal(pyro.param(fixed_param).data, torch.zeros(1))
    assert fixed_param_unchanged and not free_param_unchanged
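# A minimal sketch of the same per-parameter pattern against the current
# Pyro API, where the callable receives (module_name, param_name) and
# returns an optimizer-config dict; PyroOptim.save()/.load() persist the
# per-parameter optimizer states as exercised in the test above. The param
# name "fixed" and the filename are illustrative assumptions.
import pyro.optim as optim

def per_param_args(module_name, param_name):
    # a zero learning rate freezes a parameter without removing it
    return {"lr": 0.0} if param_name == "fixed" else {"lr": 0.01}

adam = optim.Adam(per_param_args)
adam.save("adam_state.pt")  # pickle optimizer state to disk
adam.load("adam_state.pt")  # restore it later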
def main(**kwargs):
    args = argparse.Namespace(**kwargs)
    args.batch_size = 64
    pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)

    def per_param_optim_args(module_name, param_name):
        def isBaselineParam(module_name, param_name):
            return 'bl_' in module_name or 'bl_' in param_name
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name) \
            else args.learning_rate
        return {'lr': lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    t0 = time.time()
    for i in range(1, args.num_steps + 1):
        loss = svi.step(X)

        if args.progress_every > 0 and i % args.progress_every == 0:
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        if args.eval_every > 0 and i % args.eval_every == 0:
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
def test_deterministic(with_plate, event_shape):
    def model(y=None):
        with pyro.util.optional(pyro.plate("plate", 3), with_plate):
            x = pyro.sample("x", dist.Normal(0, 1).expand(event_shape).to_event())
            x2 = pyro.deterministic("x2", x**2, event_dim=len(event_shape))
        pyro.deterministic("x3", x2)
        return pyro.sample("obs", dist.Normal(x2, 0.1).to_event(), obs=y)

    y = torch.tensor(4.0)
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
    for i in range(100):
        svi.step(y)

    actual = Predictive(
        model, guide=guide, return_sites=["x2", "x3"], num_samples=1000
    )()
    x2_batch_shape = (3,) if with_plate else ()
    assert actual["x2"].shape == (1000,) + x2_batch_shape + event_shape
    # x3 shape is prepended 1 to match Pyro shape semantics
    x3_batch_shape = (1, 3) if with_plate else ()
    assert actual["x3"].shape == (1000,) + x3_batch_shape + event_shape
    assert_close(actual["x2"].mean(), y, rtol=0.1)
    assert_close(actual["x3"].mean(), y, rtol=0.1)
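# A small sketch (not part of the test) of what pyro.deterministic records:
# the site shows up in a trace with its computed value, which is why
# Predictive can return it alongside sampled sites.
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def tiny_model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    pyro.deterministic("x2", x ** 2)

tr = poutine.trace(tiny_model).get_trace()
assert torch.equal(tr.nodes["x2"]["value"], tr.nodes["x"]["value"] ** 2)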
def test_onehot_svi_usage():
    def model():
        p = torch.tensor([0.25] * 4)
        pyro.sample("z", OneHotCategorical(probs=p))

    def guide():
        q = pyro.param("q", torch.tensor([0.1, 0.2, 0.3, 0.4]),
                       constraint=constraints.simplex)
        temp = torch.tensor(0.10)
        pyro.sample("z", RelaxedOneHotCategoricalStraightThrough(temperature=temp, probs=q))

    adam = optim.Adam({"lr": 0.001, "betas": (0.95, 0.999)})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())
    for k in range(6000):
        svi.step()

    assert_equal(
        pyro.param("q"),
        torch.tensor([0.25] * 4),
        prec=0.01,
        msg="test svi usage of RelaxedOneHotCategoricalStraightThrough failed",
    )
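# For intuition, a sketch of the straight-through sampler the guide above
# relies on: rsample() returns an exactly one-hot value in the forward pass,
# while gradients flow through the underlying relaxed sample.
import torch
from pyro.distributions import RelaxedOneHotCategoricalStraightThrough

d = RelaxedOneHotCategoricalStraightThrough(
    temperature=torch.tensor(0.1),
    probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
z = d.rsample()
assert torch.allclose(z.sum(), torch.tensor(1.0))
assert ((z == 0) | (z == 1)).all()  # exactly one-hot in the forward pass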
def do_elbo_test(self, reparameterized, n_steps, loss):
    pyro.clear_param_store()
    Beta = dist.Beta if reparameterized else fakes.NonreparameterizedBeta

    def model():
        p_latent = pyro.sample("p_latent", Beta(self.alpha0, self.beta0))
        with pyro.plate("data", self.batch_size):
            pyro.sample("obs", dist.Bernoulli(p_latent), obs=self.data)
        return p_latent

    def guide():
        alpha_q_log = pyro.param("alpha_q_log", self.log_alpha_n + 0.17)
        beta_q_log = pyro.param("beta_q_log", self.log_beta_n - 0.143)
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        pyro.sample("p_latent", Beta(alpha_q, beta_q))

    adam = optim.Adam({"lr": .001, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss=loss)

    for k in range(n_steps):
        svi.step()

    alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n)
    beta_error = param_abs_error("beta_q_log", self.log_beta_n)
    assert_equal(0.0, alpha_error, prec=0.08)
    assert_equal(0.0, beta_error, prec=0.08)
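# For reference, the conjugate Beta-Bernoulli update that self.log_alpha_n /
# self.log_beta_n presumably encode, sketched with illustrative values:
import torch

alpha0, beta0 = torch.tensor(1.0), torch.tensor(1.0)  # assumed prior
data = torch.tensor([1.0, 0.0, 1.0, 1.0])             # assumed 0/1 data
alpha_n = alpha0 + data.sum()              # posterior alpha
beta_n = beta0 + len(data) - data.sum()    # posterior beta
log_alpha_n, log_beta_n = alpha_n.log(), beta_n.log()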
def test_reparam_stable():
    data = dist.Poisson(torch.randn(8).exp()).sample()

    @poutine.reparam(config={"dz": LatentStableReparam(), "y": LatentStableReparam()})
    def model():
        stability = pyro.sample("stability", dist.Uniform(1.0, 2.0))
        trans_skew = pyro.sample("trans_skew", dist.Uniform(-1.0, 1.0))
        obs_skew = pyro.sample("obs_skew", dist.Uniform(-1.0, 1.0))
        scale = pyro.sample("scale", dist.Gamma(3, 1))
        # We use separate plates because the .cumsum() op breaks independence.
        with pyro.plate("time1", len(data)):
            dz = pyro.sample("dz", dist.Stable(stability, trans_skew))
        z = dz.cumsum(-1)
        with pyro.plate("time2", len(data)):
            y = pyro.sample("y", dist.Stable(stability, obs_skew, scale, z))
            pyro.sample("x", dist.Poisson(y.abs()), obs=data)

    guide = AutoDelta(model)
    svi = SVI(model, guide, optim.Adam({"lr": 0.01}), Trace_ELBO())
    for step in range(100):
        loss = svi.step()
        if step % 20 == 0:
            logger.info("step {} loss = {:0.4g}".format(step, loss))
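# A sketch of the reparameterization applied above: LatentStableReparam
# replaces a latent Stable draw with auxiliary samples plus a deterministic
# transform, leaving the original site in the trace.
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LatentStableReparam

def tiny():
    return pyro.sample("z", dist.Stable(1.5, 0.0))

reparam_tiny = poutine.reparam(tiny, config={"z": LatentStableReparam()})
tr = poutine.trace(reparam_tiny).get_trace()
assert "z" in tr.nodes  # plus the auxiliary sites the reparam introduces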
def test_information_criterion():
    # milk dataset: https://github.com/rmcelreath/rethinking/blob/master/data/milk.csv
    kcal = torch.tensor([0.49, 0.47, 0.56, 0.89, 0.92, 0.8, 0.46, 0.71, 0.68,
                         0.97, 0.84, 0.62, 0.54, 0.49, 0.48, 0.55, 0.71])
    kcal_mean = kcal.mean()
    kcal_logstd = kcal.std().log()

    def model():
        mu = pyro.sample("mu", dist.Normal(kcal_mean, 1))
        log_sigma = pyro.sample("log_sigma", dist.Normal(kcal_logstd, 1))
        with pyro.plate("plate"):
            pyro.sample("kcal", dist.Normal(mu, log_sigma.exp()), obs=kcal)

    delta_guide = AutoLaplaceApproximation(model)
    svi = SVI(model, delta_guide, optim.Adam({"lr": 0.05}),
              loss=Trace_ELBO(), num_samples=3000)
    for i in range(100):
        svi.step()

    svi.guide = delta_guide.laplace_approximation()
    posterior = svi.run()

    ic = posterior.information_criterion()
    assert_equal(ic["waic"], torch.tensor(-8.3), prec=0.2)
    assert_equal(ic["p_waic"], torch.tensor(1.8), prec=0.2)
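# For reference, a sketch of WAIC computed directly from a matrix of
# per-datapoint log-likelihoods (S posterior draws x N observations),
# following the standard definition; Pyro's information_criterion may
# differ in implementation details.
import torch

def waic_sketch(log_like):
    """log_like: tensor of shape (S, N)."""
    S = log_like.shape[0]
    lppd = torch.logsumexp(log_like, dim=0) - torch.tensor(float(S)).log()
    p_waic = log_like.var(dim=0)  # effective number of parameters
    return -2 * (lppd - p_waic).sum(), p_waic.sum()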
def test_non_nested_plating_sum():
    """Example from https://github.com/pyro-ppl/pyro/issues/2361"""
    # Generative model: data = x @ weights + eps
    def model(data, weights):
        loc = torch.tensor(1.0)
        scale = torch.tensor(0.1)

        # Sample latents (shares no dimensions with data)
        with pyro.plate('x_plate', weights.shape[0]):
            x = pyro.sample('x', pyro.distributions.Normal(loc, scale))

        # Combine with weights and sample
        with pyro.plate('data_plate_1', data.shape[-1]):
            with pyro.plate('data_plate_2', data.shape[-2]):
                pyro.sample('data', pyro.distributions.Normal(x @ weights, scale), obs=data)

    def guide(data, weights):
        loc = pyro.param('x_loc', torch.tensor(0.5))
        scale = torch.tensor(0.1)
        with pyro.plate('x_plate', weights.shape[0]):
            pyro.sample('x', pyro.distributions.Normal(loc, scale))

    data = torch.randn([5, 3])
    weights = torch.randn([2, 3])
    adam = optim.Adam({"lr": 0.01})
    loss_fn = RenyiELBO(num_particles=30, vectorize_particles=True)
    svi = SVI(model, guide, adam, loss_fn)
    for step in range(1):
        loss = svi.step(data, weights)
        if step % 20 == 0:
            logger.info("step {} loss = {:0.4g}".format(step, loss))
def test_auto_dirichlet(auto_class, Elbo):
    num_steps = 2000
    prior = torch.tensor([0.5, 1.0, 1.5, 3.0])
    data = torch.tensor([0] * 4 + [1] * 2 + [2] * 5).long()
    posterior = torch.tensor([4.5, 3.0, 6.5, 3.0])

    def model(data):
        p = pyro.sample("p", dist.Dirichlet(prior))
        with pyro.plate("data_plate"):
            pyro.sample("data", dist.Categorical(p).expand_by(data.shape), obs=data)

    guide = auto_class(model)
    svi = SVI(model, guide, optim.Adam({"lr": 0.003}), loss=Elbo())

    for _ in range(num_steps):
        loss = svi.step(data)
        assert np.isfinite(loss), loss

    expected_mean = posterior / posterior.sum()
    if isinstance(guide, (AutoIAFNormal, AutoNormalizingFlow)):
        loc = guide.transform(torch.zeros(guide.latent_dim))
    else:
        loc = guide.loc
    actual_mean = biject_to(constraints.simplex)(loc)
    assert_equal(
        actual_mean,
        expected_mean,
        prec=0.2,
        msg="".join(
            [
                "\nexpected {}".format(expected_mean.detach().cpu().numpy()),
                "\n actual {}".format(actual_mean.detach().cpu().numpy()),
            ]
        ),
    )
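# A sketch of the simplex bijection used above: biject_to maps an
# unconstrained vector of length n to a point on the (n+1)-simplex, so a
# 3-dim latent corresponds to the 4-category Dirichlet posterior.
import torch
from torch.distributions import biject_to, constraints

point = biject_to(constraints.simplex)(torch.zeros(3))
assert point.shape == (4,)
assert torch.allclose(point.sum(), torch.tensor(1.0))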
def train_via_opt_svi(model, guide):
    pyro.clear_param_store()
    svi = SVI(model, guide, optim.Adam({"lr": lr}), loss=Trace_ELBO())
    loss_list = []
    mae_list = []
    for step in range(n_steps):
        loss = svi.step(anime_matrix_train.values)
        pred = []
        for i, j in anime_data_test[["user_id", "anime_id"]].itertuples(index=False):
            r = torch.dot(pyro.param("u_mean")[i - 1, :],
                          pyro.param("v_mean")[j - 1, :])
            r = r.item()
            if r > 10:
                r = 10
            pred.append(r)
        test_mae = mae(anime_data_test.rating, pred)
        # early stopping once the test MAE stops improving
        if step > 1500 and test_mae - min(mae_list) > mae_tol:
            print('[stop at iter {}] loss: {:.4f} Test MAE: {:.4f}'.format(
                step, loss, test_mae))
            break
        loss_list.append(loss)
        mae_list.append(test_mae)
        if step % 250 == 0:
            print('[iter {}] loss: {:.4f} Test MAE: {:.4f}'.format(
                step, loss, test_mae))
    return loss_list, mae_list
def main(args):
    # Init Pyro
    pyro.enable_validation(True)
    pyro.clear_param_store()

    # Load meta-data for all models and select model based on command arguments
    models = pyro_models.load()
    # model_dict = select_model(args, models)
    model_dict = models['arm.earnings_latin_square']
    # model_dict = models['arm.election88_ch14']

    # Define model/data/guide
    model = model_dict['model']
    data = pyro_models.data(model_dict)
    guide = AutoDelta(model)

    # Perform variational inference
    ess = ESS(vectorize_particles=True, num_inner=1000, num_outer=10)
    svi = SVI(model, guide, optim.Adam({'lr': 0.1}), loss=Trace_ELBO())
    for i in range(args.num_epochs):
        params = {}
        ess_val = ess.loss(model, guide, data, {})
        loss = svi.step(data, params)
        print('loss', loss, 'ess', ess_val)
    sys.exit()
def test_auto_diagonal_gaussians(auto_class, Elbo):
    n_steps = 3501 if auto_class == AutoDiagonalNormal else 6001

    def model():
        pyro.sample("x", dist.Normal(-0.2, 1.2))
        pyro.sample("y", dist.Normal(0.2, 0.7))

    if auto_class is AutoLowRankMultivariateNormal:
        guide = auto_class(model, rank=1)
    else:
        guide = auto_class(model)
    adam = optim.Adam({"lr": .001, "betas": (0.95, 0.999)})
    svi = SVI(model, guide, adam, loss=Elbo())

    for k in range(n_steps):
        loss = svi.step()
        assert np.isfinite(loss), loss

    if auto_class is AutoLaplaceApproximation:
        guide = guide.laplace_approximation()

    loc, scale = guide._loc_scale()

    expected_loc = torch.tensor([-0.2, 0.2])
    assert_equal(loc.detach(), expected_loc, prec=0.05,
                 msg="\n".join(["Incorrect guide loc. Expected:",
                                str(expected_loc.cpu().numpy()),
                                "Actual:",
                                str(loc.detach().cpu().numpy())]))
    expected_scale = torch.tensor([1.2, 0.7])
    assert_equal(scale.detach(), expected_scale, prec=0.08,
                 msg="\n".join(["Incorrect guide scale. Expected:",
                                str(expected_scale.cpu().numpy()),
                                "Actual:",
                                str(scale.detach().cpu().numpy())]))
def do_test_auto(self, N, reparameterized, n_steps):
    logger.debug("\nGoing to do AutoGaussianChain test...")
    pyro.clear_param_store()
    self.setUp()
    self.setup_chain(N)
    self.compute_target(N)
    self.guide = AutoMultivariateNormal(self.model)
    logger.debug("target auto_loc: {}"
                 .format(self.target_auto_mus[1:].detach().cpu().numpy()))
    logger.debug("target auto_diag_cov: {}"
                 .format(self.target_auto_diag_cov[1:].detach().cpu().numpy()))

    # TODO speed up with parallel num_particles > 1
    adam = optim.Adam({"lr": .001, "betas": (0.95, 0.999)})
    svi = SVI(self.model, self.guide, adam, loss=Trace_ELBO())

    for k in range(n_steps):
        loss = svi.step(reparameterized)
        assert np.isfinite(loss), loss

        if k % 1000 == 0 and k > 0 or k == n_steps - 1:
            logger.debug("[step {}] guide mean parameter: {}"
                         .format(k, self.guide.loc.detach().cpu().numpy()))
            L = self.guide.scale_tril
            diag_cov = torch.mm(L, L.t()).diag()
            logger.debug("[step {}] auto_diag_cov: {}"
                         .format(k, diag_cov.detach().cpu().numpy()))

    assert_equal(self.guide.loc.detach(), self.target_auto_mus[1:], prec=0.05,
                 msg="guide mean off")
    assert_equal(diag_cov, self.target_auto_diag_cov[1:], prec=0.07,
                 msg="guide covariance off")
def do_inference(self):
    pyro.clear_param_store()
    pyro.util.set_rng_seed(0)
    t0 = time.time()
    adam = optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)})
    svi = SVI(self.model, self.guide, adam, loss="ELBO",
              trace_graph=False, analytic_kl=False)
    losses = []

    for k in range(100001):
        loss = svi.step(data)
        losses.append(loss)

        if k % 20 == 0 and k > 20:
            t_k = time.time()
            print("[epoch %05d] mean elbo: %.5f elapsed time: %.4f" %
                  (k, -np.mean(losses[-100:]), t_k - t0))
            print("[W] %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f" % (
                self.get_w_stats("top") + self.get_w_stats("mid") + self.get_w_stats("bottom")))
            print("[Z] %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f" % (
                self.get_z_stats("top") + self.get_z_stats("mid") + self.get_z_stats("bottom")))

    return losses
def __init__(self,
             model: Type[torch.nn.Module],
             optimizer: Type[optim.PyroOptim] = None,
             loss: Type[infer.ELBO] = None,
             enumerate_parallel: bool = False,
             seed: int = 1,
             **kwargs: Union[str, float]) -> None:
    """
    Initializes the trainer's parameters
    """
    pyro.clear_param_store()
    set_deterministic_mode(seed)
    self.device = kwargs.get(
        "device", 'cuda' if torch.cuda.is_available() else 'cpu')
    if optimizer is None:
        lr = kwargs.get("lr", 1e-3)
        optimizer = optim.Adam({"lr": lr})
    if loss is None:
        if enumerate_parallel:
            loss = infer.TraceEnum_ELBO(max_plate_nesting=1,
                                        strict_enumeration_warning=False)
        else:
            loss = infer.Trace_ELBO()
    guide = model.guide
    if enumerate_parallel:
        guide = infer.config_enumerate(guide, "parallel", expand=True)
    self.svi = infer.SVI(model.model, guide, optimizer, loss=loss)
    self.loss_history = {"training_loss": [], "test_loss": []}
    self.current_epoch = 0
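# A minimal sketch of config_enumerate, which the trainer above applies to
# the guide when enumerate_parallel=True: discrete sites are marked for
# parallel enumeration instead of being sampled.
import torch
import pyro
import pyro.distributions as dist
from pyro import infer

def toy_guide():
    pyro.sample("z", dist.Categorical(logits=torch.zeros(3)))

enum_guide = infer.config_enumerate(toy_guide, "parallel", expand=True)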
def test_auto_transform(auto_class):
    n_steps = 3500

    def model():
        pyro.sample("x", dist.LogNormal(0.2, 0.7))

    if auto_class is AutoLowRankMultivariateNormal:
        guide = auto_class(model, rank=1)
    else:
        guide = auto_class(model)
    adam = optim.Adam({"lr": .001, "betas": (0.90, 0.999)})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    for k in range(n_steps):
        loss = svi.step()
        assert np.isfinite(loss), loss

    if auto_class is AutoLaplaceApproximation:
        guide = guide.laplace_approximation()

    loc, scale = guide._loc_scale()
    assert_equal(loc, torch.tensor([0.2]), prec=0.04, msg="guide mean off")
    assert_equal(scale, torch.tensor([0.7]), prec=0.04, msg="guide covariance off")
def test_auto_dirichlet(auto_class, Elbo):
    num_steps = 2000
    prior = torch.tensor([0.5, 1.0, 1.5, 3.0])
    data = torch.tensor([0] * 4 + [1] * 2 + [2] * 5).long()
    posterior = torch.tensor([4.5, 3.0, 6.5, 3.0])

    def model(data):
        p = pyro.sample("p", dist.Dirichlet(prior))
        with pyro.plate("data_plate"):
            pyro.sample("data", dist.Categorical(p).expand_by(data.shape), obs=data)

    guide = auto_class(model)
    svi = SVI(model, guide, optim.Adam({"lr": .003}), loss=Elbo())

    for _ in range(num_steps):
        loss = svi.step(data)
        assert np.isfinite(loss), loss

    expected_mean = posterior / posterior.sum()
    actual_mean = biject_to(constraints.simplex)(pyro.param("auto_loc"))
    assert_equal(actual_mean, expected_mean, prec=0.2, msg=''.join([
        '\nexpected {}'.format(expected_mean.detach().cpu().numpy()),
        '\n actual {}'.format(actual_mean.detach().cpu().numpy())]))
def do_elbo_test(self, reparameterized, n_steps):
    pyro.clear_param_store()
    beta = dist.beta if reparameterized else fakes.nonreparameterized_beta

    def model():
        p_latent = pyro.sample("p_latent", beta, self.alpha0, self.beta0)
        pyro.observe("obs", dist.bernoulli, self.data, p_latent)
        return p_latent

    def guide():
        alpha_q_log = pyro.param(
            "alpha_q_log",
            Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
        beta_q_log = pyro.param(
            "beta_q_log",
            Variable(self.log_beta_n.data - 0.143, requires_grad=True))
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        pyro.sample("p_latent", beta, alpha_q, beta_q)

    adam = optim.Adam({"lr": .001, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

    for k in range(n_steps):
        svi.step()

    alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n)
    beta_error = param_abs_error("beta_q_log", self.log_beta_n)
    assert_equal(0.0, alpha_error, prec=0.08)
    assert_equal(0.0, beta_error, prec=0.08)
def test_energy_distance_univariate(prior_scale):
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0, 100))
        scale = pyro.sample("scale", dist.LogNormal(0, 1))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, scale), obs=data)

    def guide(data):
        loc_loc = pyro.param("loc_loc", torch.tensor(0.))
        loc_scale = pyro.param("loc_scale", torch.tensor(1.),
                               constraint=constraints.positive)
        log_scale_loc = pyro.param("log_scale_loc", torch.tensor(0.))
        log_scale_scale = pyro.param("log_scale_scale", torch.tensor(1.),
                                     constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(loc_loc, loc_scale))
        pyro.sample("scale", dist.LogNormal(log_scale_loc, log_scale_scale))

    data = 10.0 + torch.randn(8)
    adam = optim.Adam({"lr": 0.1})
    loss_fn = EnergyDistance(num_particles=32, prior_scale=prior_scale)
    svi = SVI(model, guide, adam, loss_fn)
    for step in range(2001):
        loss = svi.step(data)
        if step % 20 == 0:
            logger.info("step {} loss = {:0.4g}, loc = {:0.4g}, scale = {:0.4g}"
                        .format(step, loss,
                                pyro.param("loc_loc").item(),
                                pyro.param("log_scale_loc").exp().item()))

    expected_loc = data.mean()
    expected_scale = data.std()
    actual_loc = pyro.param("loc_loc").detach()
    actual_scale = pyro.param("log_scale_loc").exp().detach()
    assert_close(actual_loc, expected_loc, atol=0.05)
    assert_close(actual_scale, expected_scale, rtol=0.1 if prior_scale else 0.05)
def do_elbo_test(self, reparameterized, n_steps):
    pyro.clear_param_store()
    pt_guide = LogNormalNormalGuide(self.log_mu_n.data + 0.17,
                                    self.log_tau_n.data - 0.143)

    def model():
        mu_latent = pyro.sample("mu_latent", dist.normal,
                                self.mu0, torch.pow(self.tau0, -0.5))
        sigma = torch.pow(self.tau, -0.5)
        pyro.observe("obs0", dist.lognormal, self.data[0], mu_latent, sigma)
        pyro.observe("obs1", dist.lognormal, self.data[1], mu_latent, sigma)
        return mu_latent

    def guide():
        pyro.module("mymodule", pt_guide)
        mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
        sigma = torch.pow(tau_q, -0.5)
        normal = dist.normal if reparameterized else fakes.nonreparameterized_normal
        pyro.sample("mu_latent", normal, mu_q, sigma)

    adam = optim.Adam({"lr": .0005, "betas": (0.96, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

    for k in range(n_steps):
        svi.step()

    mu_error = param_abs_error("mymodule$$$mu_q_log", self.log_mu_n)
    tau_error = param_abs_error("mymodule$$$tau_q_log", self.log_tau_n)
    assert_equal(0.0, mu_error, prec=0.07)
    assert_equal(0.0, tau_error, prec=0.07)
def test_energy_distance_multivariate(prior_scale):
    def model(data):
        loc = torch.zeros(2)
        cov = pyro.sample("cov", dist.Normal(0, 100).expand([2, 2]).to_event(2))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.MultivariateNormal(loc, cov), obs=data)

    def guide(data):
        scale_tril = pyro.param("scale_tril", torch.eye(2),
                                constraint=constraints.lower_cholesky)
        pyro.sample("cov", dist.Delta(scale_tril @ scale_tril.t(), event_dim=2))

    cov = torch.tensor([[1, 0.8], [0.8, 1]])
    data = dist.MultivariateNormal(torch.zeros(2), cov).sample([10])
    loss_fn = EnergyDistance(num_particles=32, prior_scale=prior_scale)
    svi = SVI(model, guide, optim.Adam({"lr": 0.1}), loss_fn)
    for step in range(2001):
        loss = svi.step(data)
        if step % 20 == 0:
            logger.info("step {} loss = {:0.4g}".format(step, loss))

    delta = data - data.mean(0)
    expected_cov = (delta.t() @ delta) / len(data)
    scale_tril = pyro.param("scale_tril").detach()
    actual_cov = scale_tril @ scale_tril.t()
    assert_close(actual_cov, expected_cov, atol=0.2)
def test_elbo_with_transformed_distribution(self):
    pyro.clear_param_store()

    def model():
        zero = Variable(torch.zeros(1))
        one = Variable(torch.ones(1))
        mu_latent = pyro.sample("mu_latent", dist.normal,
                                self.mu0, torch.pow(self.tau0, -0.5))
        bijector = AffineExp(torch.pow(self.tau, -0.5), mu_latent)
        x_dist = TransformedDistribution(dist.normal, bijector)
        pyro.observe("obs0", x_dist, self.data[0], zero, one)
        pyro.observe("obs1", x_dist, self.data[1], zero, one)
        return mu_latent

    def guide():
        mu_q_log = pyro.param(
            "mu_q_log",
            Variable(self.log_mu_n.data + 0.17, requires_grad=True))
        tau_q_log = pyro.param(
            "tau_q_log",
            Variable(self.log_tau_n.data - 0.143, requires_grad=True))
        mu_q, tau_q = torch.exp(mu_q_log), torch.exp(tau_q_log)
        pyro.sample("mu_latent", dist.normal, mu_q, torch.pow(tau_q, -0.5))

    adam = optim.Adam({"lr": .0005, "betas": (0.96, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

    for k in range(12001):
        svi.step()

    mu_error = param_abs_error("mu_q_log", self.log_mu_n)
    tau_error = param_abs_error("tau_q_log", self.log_tau_n)
    assert_equal(0.0, mu_error, prec=0.05)
    assert_equal(0.0, tau_error, prec=0.05)
def do_elbo_test(self, reparameterized, n_steps, loss):
    pyro.clear_param_store()
    Gamma = dist.Gamma if reparameterized else fakes.NonreparameterizedGamma

    def model():
        lambda_latent = pyro.sample("lambda_latent", Gamma(self.alpha0, self.beta0))
        with pyro.plate("data", self.n_data):
            pyro.sample("obs", dist.Poisson(lambda_latent), obs=self.data)
        return lambda_latent

    def guide():
        alpha_q = pyro.param("alpha_q", self.alpha_n.detach() + math.exp(0.17),
                             constraint=constraints.positive)
        beta_q = pyro.param("beta_q", self.beta_n.detach() / math.exp(0.143),
                            constraint=constraints.positive)
        pyro.sample("lambda_latent", Gamma(alpha_q, beta_q))

    adam = optim.Adam({"lr": .0002, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss)

    for k in range(n_steps):
        svi.step()

    assert_equal(pyro.param("alpha_q"), self.alpha_n, prec=0.2,
                 msg='{} vs {}'.format(pyro.param("alpha_q").detach().cpu().numpy(),
                                       self.alpha_n.detach().cpu().numpy()))
    assert_equal(pyro.param("beta_q"), self.beta_n, prec=0.15,
                 msg='{} vs {}'.format(pyro.param("beta_q").detach().cpu().numpy(),
                                       self.beta_n.detach().cpu().numpy()))
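# For reference, the Gamma-Poisson conjugate update that self.alpha_n /
# self.beta_n presumably hold, sketched with illustrative values (a unit
# Gamma prior and a toy count dataset are assumptions):
import torch

alpha0, beta0 = torch.tensor(1.0), torch.tensor(1.0)  # assumed prior
data = torch.tensor([1.0, 2.0, 3.0])                  # assumed counts
alpha_n = alpha0 + data.sum()                    # posterior shape
beta_n = beta0 + torch.tensor(float(len(data)))  # posterior rate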
def do_elbo_test(self, reparameterized, n_steps):
    pyro.clear_param_store()

    def model():
        mu_latent = pyro.sample("mu_latent", dist.normal,
                                self.mu0, torch.pow(self.lam0, -0.5))
        pyro.observe("obs", dist.normal, self.data,
                     mu_latent, torch.pow(self.lam, -0.5))
        return mu_latent

    def guide():
        mu_q = pyro.param(
            "mu_q",
            Variable(self.analytic_mu_n.data + 0.134 * torch.ones(2),
                     requires_grad=True))
        log_sig_q = pyro.param(
            "log_sig_q",
            Variable(self.analytic_log_sig_n.data - 0.14 * torch.ones(2),
                     requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        normal = dist.normal if reparameterized else fakes.nonreparameterized_normal
        pyro.sample("mu_latent", normal, mu_q, sig_q)

    adam = optim.Adam({"lr": .001})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

    for k in range(n_steps):
        svi.step()

    mu_error = param_mse("mu_q", self.analytic_mu_n)
    log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
    assert_equal(0.0, mu_error, prec=0.05)
    assert_equal(0.0, log_sig_error, prec=0.05)
def do_elbo_test(self, reparameterized, n_steps, loss):
    pyro.clear_param_store()

    def model():
        loc_latent = pyro.sample(
            "loc_latent",
            dist.Normal(self.loc0, torch.pow(self.lam0, -0.5)).to_event(1))
        with pyro.plate('data', self.batch_size):
            pyro.sample(
                "obs",
                dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1),
                obs=self.data)
        return loc_latent

    def guide():
        loc_q = pyro.param("loc_q", self.analytic_loc_n.detach() + 0.134)
        log_sig_q = pyro.param("log_sig_q",
                               self.analytic_log_sig_n.detach() - 0.14)
        sig_q = torch.exp(log_sig_q)
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        pyro.sample("loc_latent", Normal(loc_q, sig_q).to_event(1))

    adam = optim.Adam({"lr": .001})
    svi = SVI(model, guide, adam, loss=loss)

    for k in range(n_steps):
        svi.step()

    loc_error = param_mse("loc_q", self.analytic_loc_n)
    log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
    assert_equal(0.0, loc_error, prec=0.05)
    assert_equal(0.0, log_sig_error, prec=0.05)
def run_svi(model, guide, iters, data, demand, num_samples=1000, filename=''):
    """
    Runs SVI

    :param model: pyro model
    :param guide: pyro guide
    :param iters: iterations
    :param data: data to be passed to model, guide
    :param demand: demand to be passed to model, guide
    :param num_samples: number of samples for Monte Carlo posterior approximation
    :param filename: file to save pyro param store (.pkl)
    :return: svi object, and elbo loss
    """
    pyro.clear_param_store()
    svi = SVI(model, guide, optim.Adam({"lr": .005}),
              loss=JitTrace_ELBO(), num_samples=num_samples)
    num_iters = iters if not smoke_test else 2
    elbo_losses = []
    for i in range(num_iters):
        elbo = svi.step(data, demand)
        elbo_losses.append(elbo)
        if i % 1000 == 0:
            logging.info("Elbo loss: {}".format(elbo))
    if filename:
        pyro.get_param_store().save(filename)
    return svi, elbo_losses
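# A matching sketch for restoring the parameters saved above, assuming the
# same (hypothetical) filename was used:
import pyro

pyro.clear_param_store()
pyro.get_param_store().load('params.pkl')  # filename is an assumption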
def inference(train_x, train_y, dim_in, dim_out, batch_size,
              eval_fn=None, num_epochs=20000):
    """
    NOTE: there must be a better way to feed dim_in/dim_out;
    perhaps we could infer them from train_x, train_y?
    """
    svi = SVI(model, guide,
              optim.Adam({'lr': 0.005}),
              loss=Trace_ELBO(),
              num_samples=len(train_x))
    for i in range(num_epochs):
        if batch_size > 0:
            # randomly sample `batch_size` data points
            batch_x, batch_y = random_sample((train_x, train_y), batch_size)
        else:
            # feed the whole training set
            batch_x, batch_y = train_x, train_y
        # run a step of SVI
        elbo = svi.step(batch_x, batch_y, dim_in, dim_out)
        if i % 100 == 0:
            print('[{}/{}] Elbo loss : {}'.format(i, num_epochs, elbo))
            if eval_fn:
                print('Evaluation Accuracy : ', eval_fn())
    print("pyro's Param Store")
    for k, v in pyro.get_param_store().items():
        print(k, v)
def test_auto_diagonal_gaussians(auto_class):
    n_steps = 3501 if auto_class == AutoDiagonalNormal else 6001

    def model():
        pyro.sample("x", dist.Normal(-0.2, 1.2))
        pyro.sample("y", dist.Normal(0.2, 0.7))

    if auto_class is AutoLowRankMultivariateNormal:
        guide = auto_class(model, rank=1)
    else:
        guide = auto_class(model)
    adam = optim.Adam({"lr": .001, "betas": (0.95, 0.999)})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    for k in range(n_steps):
        loss = svi.step()
        assert np.isfinite(loss), loss

    loc, scale = guide._loc_scale()
    assert_equal(loc, torch.tensor([-0.2, 0.2]), prec=0.05,
                 msg="guide mean off")
    assert_equal(scale, torch.tensor([1.21, 0.71]), prec=0.08,
                 msg="guide covariance off")
def poisson_gamma_model(reparameterized, Elbo):
    pyro.set_rng_seed(0)
    alpha0 = torch.tensor(1.0)
    beta0 = torch.tensor(1.0)
    data = torch.tensor([1.0, 2.0, 3.0])
    n_data = len(data)
    data_sum = data.sum(0)
    alpha_n = alpha0 + data_sum  # posterior alpha
    beta_n = beta0 + torch.tensor(float(n_data))  # posterior beta
    log_alpha_n = torch.log(alpha_n)
    log_beta_n = torch.log(beta_n)

    pyro.clear_param_store()
    Gamma = dist.Gamma if reparameterized else fakes.NonreparameterizedGamma

    def model():
        lambda_latent = pyro.sample("lambda_latent", Gamma(alpha0, beta0))
        with pyro.plate("data", n_data):
            pyro.sample("obs", dist.Poisson(lambda_latent), obs=data)
        return lambda_latent

    def guide():
        alpha_q_log = pyro.param("alpha_q_log", log_alpha_n + 0.17)
        beta_q_log = pyro.param("beta_q_log", log_beta_n - 0.143)
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        pyro.sample("lambda_latent", Gamma(alpha_q, beta_q))

    adam = optim.Adam({"lr": .0002, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss=Elbo())
    for k in range(3000):
        svi.step()
def do_elbo_test(self, reparameterized, n_steps):
    pyro.clear_param_store()

    def model():
        mu_latent = pyro.sample("mu_latent", dist.normal,
                                self.mu0, torch.pow(self.lam0, -0.5))
        pyro.map_data("aaa", self.data,
                      lambda i, x: pyro.observe(
                          "obs_%d" % i, dist.normal, x,
                          mu_latent, torch.pow(self.lam, -0.5)),
                      batch_size=self.batch_size)
        return mu_latent

    def guide():
        mu_q = pyro.param(
            "mu_q",
            Variable(self.analytic_mu_n.data + 0.134 * torch.ones(2),
                     requires_grad=True))
        log_sig_q = pyro.param(
            "log_sig_q",
            Variable(self.analytic_log_sig_n.data - 0.14 * torch.ones(2),
                     requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("mu_latent",
                    dist.Normal(mu_q, sig_q, reparameterized=reparameterized))
        pyro.map_data("aaa", self.data, lambda i, x: None,
                      batch_size=self.batch_size)

    adam = optim.Adam({"lr": .001})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

    for k in range(n_steps):
        svi.step()

    mu_error = param_mse("mu_q", self.analytic_mu_n)
    log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
    self.assertEqual(0.0, mu_error, prec=0.05)
    self.assertEqual(0.0, log_sig_error, prec=0.05)