def test_elbo_nonreparameterized(self):
    pyro.clear_param_store()

    def model():
        p_latent = pyro.sample("p_latent", dist.beta, self.alpha0, self.beta0)
        pyro.map_data("aaa", self.data,
                      lambda i, x: pyro.observe("obs_{}".format(i), dist.bernoulli, x, p_latent),
                      batch_size=self.batch_size)
        return p_latent

    def guide():
        alpha_q_log = pyro.param("alpha_q_log",
                                 Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
        beta_q_log = pyro.param("beta_q_log",
                                Variable(self.log_beta_n.data - 0.143, requires_grad=True))
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        pyro.sample("p_latent", dist.beta, alpha_q, beta_q)
        pyro.map_data("aaa", self.data, lambda i, x: None, batch_size=self.batch_size)

    adam = optim.Adam({"lr": .001, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)
    for k in range(10001):
        svi.step()

    alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n)
    beta_error = param_abs_error("beta_q_log", self.log_beta_n)
    self.assertEqual(0.0, alpha_error, prec=0.08)
    self.assertEqual(0.0, beta_error, prec=0.08)
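# Note: `param_abs_error` (used throughout these tests) is a test utility that is
# not part of this listing. A minimal sketch of what such a helper might look like
# (an assumption, not the actual implementation): fetch a parameter from Pyro's
# param store and report its total absolute error against a target tensor.
import torch
import pyro


def param_abs_error(name, target):
    """Sum of absolute differences between pyro.param(name) and target."""
    return torch.sum(torch.abs(target - pyro.param(name))).item()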
def do_elbo_test(self, reparameterized, n_steps):
    pyro.clear_param_store()
    pt_guide = LogNormalNormalGuide(self.log_mu_n.data + 0.17,
                                    self.log_tau_n.data - 0.143)

    def model():
        mu_latent = pyro.sample("mu_latent", dist.normal, self.mu0, torch.pow(self.tau0, -0.5))
        sigma = torch.pow(self.tau, -0.5)
        pyro.observe("obs0", dist.lognormal, self.data[0], mu_latent, sigma)
        pyro.observe("obs1", dist.lognormal, self.data[1], mu_latent, sigma)
        return mu_latent

    def guide():
        pyro.module("mymodule", pt_guide)
        mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
        sigma = torch.pow(tau_q, -0.5)
        pyro.sample("mu_latent", dist.Normal(mu_q, sigma, reparameterized=reparameterized))

    adam = optim.Adam({"lr": .0005, "betas": (0.96, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)
    for k in range(n_steps):
        svi.step()

    mu_error = param_abs_error("mymodule$$$mu_q_log", self.log_mu_n)
    tau_error = param_abs_error("mymodule$$$tau_q_log", self.log_tau_n)
    self.assertEqual(0.0, mu_error, prec=0.07)
    self.assertEqual(0.0, tau_error, prec=0.07)
def test_elbo_nonreparameterized(self):
    pyro.clear_param_store()

    def model():
        lambda_latent = pyro.sample("lambda_latent", dist.gamma, self.alpha0, self.beta0)
        pyro.observe("obs0", dist.exponential, self.data[0], lambda_latent)
        pyro.observe("obs1", dist.exponential, self.data[1], lambda_latent)
        return lambda_latent

    def guide():
        alpha_q_log = pyro.param(
            "alpha_q_log",
            Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
        beta_q_log = pyro.param(
            "beta_q_log",
            Variable(self.log_beta_n.data - 0.143, requires_grad=True))
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        pyro.sample("lambda_latent", dist.gamma, alpha_q, beta_q)

    adam = optim.Adam({"lr": .0003, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)
    for k in range(10001):
        svi.step()

    alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n)
    beta_error = param_abs_error("beta_q_log", self.log_beta_n)
    self.assertEqual(0.0, alpha_error, prec=0.08)
    self.assertEqual(0.0, beta_error, prec=0.08)
def assert_ok(model, guide, elbo):
    """
    Assert that inference works without warnings or errors.
    """
    pyro.clear_param_store()
    inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo)
    inference.step()
def test_module_nn(nn_module):
    pyro.clear_param_store()
    nn_module = nn_module()
    assert pyro.get_param_store()._params == {}
    pyro.module("module", nn_module)
    for name in pyro.get_param_store().get_all_param_names():
        assert pyro.params.user_param_name(name) in nn_module.state_dict().keys()
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        p = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", Variable(torch.Tensor([0.5]), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(p))

    print("Computing gradients using surrogate loss")
    Elbo = TraceGraph_ELBO if trace_graph else Trace_ELBO
    elbo = Elbo(enum_discrete=enum_discrete,
                num_particles=(1 if enum_discrete else num_particles))
    with xfail_if_not_implemented():
        elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {name: pyro.param(name).grad.clone() for name in params}

    print("Computing gradients using finite difference")
    elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: elbo.loss(model, guide))

    for name in params:
        print("{} {}{}{}".format(name, "-" * 30, actual_grads[name].data,
                                 expected_grads[name].data))
    assert_equal(actual_grads, expected_grads, prec=0.1)
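# `finite_difference` above is another helper that is not shown. A rough sketch of
# the idea (assumption: it perturbs each value in the param store by +/- eps and
# estimates gradients of the given loss with central differences):
import torch
import pyro


def finite_difference(loss_fn, eps=1e-3):
    grads = {}
    for name in sorted(pyro.get_param_store().get_all_param_names()):
        param = pyro.param(name)
        grad = torch.zeros_like(param.data)
        flat_param, flat_grad = param.data.view(-1), grad.view(-1)
        for i in range(flat_param.numel()):
            orig = flat_param[i].item()
            flat_param[i] = orig + eps
            plus = loss_fn()
            flat_param[i] = orig - eps
            minus = loss_fn()
            flat_param[i] = orig  # restore the original value
            flat_grad[i] = (plus - minus) / (2 * eps)
        grads[name] = grad
    return grads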
def test_iter_discrete_traces_vector(graph_type):
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    assert len(traces) == 2 * ps.size(-1)

    for scale, trace in traces:
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        expected_scale = torch.exp(dist.Bernoulli(p).log_pdf(x) *
                                   dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
def test_random_module(self):
    pyro.clear_param_store()
    lifted_tr = poutine.trace(pyro.random_module("name", self.model, prior=self.prior)).get_trace()
    for name in lifted_tr.nodes.keys():
        # random_module lifts every nn.Parameter to a sample site, so no node in
        # the lifted trace should still have type "param".
        if lifted_tr.nodes[name]["type"] == "param":
            assert lifted_tr.nodes[name]["type"] == "sample"
            assert not lifted_tr.nodes[name]["is_observed"]
def test_gmm_batch_iter_discrete_traces(model, data_size, graph_type):
    pyro.clear_param_store()
    data = torch.arange(0, data_size)
    model = config_enumerate(model)
    traces = list(iter_discrete_traces(graph_type, model, data=data))
    # This vectorized version is independent of data_size:
    assert len(traces) == 2
def test_dynamic_lr(scheduler, num_steps):
    pyro.clear_param_store()

    def model():
        sample = pyro.sample('latent', Normal(torch.tensor(0.), torch.tensor(0.3)))
        return pyro.sample('obs', Normal(sample, torch.tensor(0.2)), obs=torch.tensor(0.1))

    def guide():
        loc = pyro.param('loc', torch.tensor(0.))
        scale = pyro.param('scale', torch.tensor(0.5))
        pyro.sample('latent', Normal(loc, scale))

    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for epoch in range(2):
        scheduler.set_epoch(epoch)
        for _ in range(num_steps):
            svi.step()
        if epoch == 1:
            loc = pyro.param('loc')
            scale = pyro.param('scale')
            opt = scheduler.optim_objs[loc].optimizer
            assert opt.state_dict()['param_groups'][0]['lr'] == 0.02
            assert opt.state_dict()['param_groups'][0]['initial_lr'] == 0.01
            opt = scheduler.optim_objs[scale].optimizer
            assert opt.state_dict()['param_groups'][0]['lr'] == 0.02
            assert opt.state_dict()['param_groups'][0]['initial_lr'] == 0.01
def test_dirichlet_bernoulli(Elbo, vectorized):
    pyro.clear_param_store()
    data = torch.tensor([1.0] * 6 + [0.0] * 4)

    def model1(data):
        concentration0 = torch.tensor([10.0, 10.0])
        f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
        for i in pyro.irange("irange", len(data)):
            pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

    def model2(data):
        concentration0 = torch.tensor([10.0, 10.0])
        f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
        pyro.sample("obs",
                    dist.Bernoulli(f).expand_by(data.shape).independent(1),
                    obs=data)

    model = model2 if vectorized else model1

    def guide(data):
        concentration_q = pyro.param("concentration_q", torch.tensor([15.0, 15.0]),
                                     constraint=constraints.positive)
        pyro.sample("latent_fairness", dist.Dirichlet(concentration_q))

    elbo = Elbo(num_particles=7, strict_enumeration_warning=False)
    optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)})
    svi = SVI(model, guide, optim, elbo)
    for step in range(40):
        svi.step(data)
def test_elbo_bern(quantity, enumerate1):
    pyro.clear_param_store()
    num_particles = 1 if enumerate1 else 10000
    prec = 0.001 if enumerate1 else 0.1
    q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
    kl = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(0.25))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(0.25).expand_by([num_particles]))

    @config_enumerate(default=enumerate1)
    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))

    if quantity == "loss":
        actual = elbo.loss(model, guide) / num_particles
        expected = kl.item()
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected),
            "\n actual = {}".format(actual),
        ]))
    else:
        elbo.loss_and_grads(model, guide)
        actual = q.grad / num_particles
        expected = grad(kl, [q])[0]
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected.detach().cpu().numpy()),
            "\n actual = {}".format(actual.detach().cpu().numpy()),
        ]))
def do_elbo_test(self, reparameterized, n_steps):
    pyro.clear_param_store()

    def model():
        mu_latent = pyro.sample("mu_latent", dist.normal, self.mu0, torch.pow(self.lam0, -0.5))
        pyro.map_data("aaa", self.data,
                      lambda i, x: pyro.observe("obs_%d" % i, dist.normal, x, mu_latent,
                                                torch.pow(self.lam, -0.5)),
                      batch_size=self.batch_size)
        return mu_latent

    def guide():
        mu_q = pyro.param("mu_q", Variable(self.analytic_mu_n.data + 0.134 * torch.ones(2),
                                           requires_grad=True))
        log_sig_q = pyro.param("log_sig_q", Variable(
            self.analytic_log_sig_n.data - 0.14 * torch.ones(2), requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("mu_latent", dist.Normal(mu_q, sig_q, reparameterized=reparameterized))
        pyro.map_data("aaa", self.data, lambda i, x: None, batch_size=self.batch_size)

    adam = optim.Adam({"lr": .001})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)
    for k in range(n_steps):
        svi.step()

    mu_error = param_mse("mu_q", self.analytic_mu_n)
    log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
    self.assertEqual(0.0, mu_error, prec=0.05)
    self.assertEqual(0.0, log_sig_error, prec=0.05)
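# `param_mse` is assumed to be the mean-squared analogue of `param_abs_error`
# sketched earlier; again, this is a guess at the helper, not its actual source:
import torch
import pyro


def param_mse(name, target):
    """Mean squared error between pyro.param(name) and target."""
    return torch.mean((target - pyro.param(name)) ** 2).item()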
def test_gmm_iter_discrete_traces(data_size, graph_type, model):
    pyro.clear_param_store()
    data = torch.arange(0, data_size)
    model = config_enumerate(model)
    traces = list(iter_discrete_traces(graph_type, model, data=data, verbose=True))
    # This non-vectorized version is exponential in data_size:
    assert len(traces) == 2 ** data_size
def main(args):
    pyro.clear_param_store()
    data = build_linear_dataset(N, p)
    if args.cuda:
        # make tensors and modules CUDA
        data = data.cuda()
        softplus.cuda()
        regression_model.cuda()
    for j in range(args.num_epochs):
        if args.batch_size == N:
            # use the entire data set
            epoch_loss = svi.step(data)
        else:
            # mini batch
            epoch_loss = 0.0
            perm = torch.randperm(N) if not args.cuda else torch.randperm(N).cuda()
            # shuffle data
            data = data[perm]
            # get indices of each batch
            all_batches = get_batch_indices(N, args.batch_size)
            for ix, batch_start in enumerate(all_batches[:-1]):
                batch_end = all_batches[ix + 1]
                batch_data = data[batch_start:batch_end]
                epoch_loss += svi.step(batch_data)
        if j % 100 == 0:
            print("epoch avg loss {}".format(epoch_loss / float(N)))
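# `get_batch_indices` is assumed to return the start offsets of consecutive
# mini-batches (with N appended as the final boundary), so the loop above can
# slice the shuffled data between adjacent offsets. A plausible sketch:
import numpy as np


def get_batch_indices(N, batch_size):
    all_batches = np.arange(0, N, batch_size)
    if all_batches[-1] != N:
        all_batches = list(all_batches) + [N]
    return all_batches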
def test_elbo_hmm_in_guide(enumerate1, num_steps):
    pyro.clear_param_store()
    data = torch.ones(num_steps)
    init_probs = torch.tensor([0.5, 0.5])

    def model(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        emission_probs = pyro.param("emission_probs",
                                    torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                    constraint=constraints.simplex)
        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
            pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=y)

    @config_enumerate(default=enumerate1)
    def guide(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))

    elbo = TraceEnum_ELBO(max_iarange_nesting=0)
    elbo.loss_and_grads(model, guide, data)

    # These golden values simply test agreement between parallel and sequential.
    expected_grads = {
        2: {
            "transition_probs": [[0.1029949, -0.1029949], [0.1029949, -0.1029949]],
            "emission_probs": [[0.75, -0.75], [0.25, -0.25]],
        },
        3: {
            "transition_probs": [[0.25748726, -0.25748726], [0.25748726, -0.25748726]],
            "emission_probs": [[1.125, -1.125], [0.375, -0.375]],
        },
        10: {
            "transition_probs": [[1.64832076, -1.64832076], [1.64832076, -1.64832076]],
            "emission_probs": [[3.75, -3.75], [1.25, -1.25]],
        },
        20: {
            "transition_probs": [[3.70781687, -3.70781687], [3.70781687, -3.70781687]],
            "emission_probs": [[7.5, -7.5], [2.5, -2.5]],
        },
    }

    for name, value in pyro.get_param_store().named_parameters():
        actual = value.grad
        expected = torch.tensor(expected_grads[num_steps][name])
        assert_equal(actual, expected, msg=''.join([
            '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy()),
            '\n actual {}.grad = {}'.format(name, actual.detach().cpu().numpy()),
        ]))
def test_duplicate_obs_name(self):
    pyro.clear_param_store()
    adam = optim.Adam({"lr": .001})
    svi = SVI(self.duplicate_obs, self.guide, adam, loss="ELBO", trace_graph=False)
    with pytest.raises(RuntimeError):
        svi.step()
def assert_error(model, guide, elbo):
    """
    Assert that inference fails with an error.
    """
    pyro.clear_param_store()
    inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo)
    with pytest.raises((NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError)):
        inference.step()
def test_extra_samples(self):
    pyro.clear_param_store()
    adam = optim.Adam({"lr": .001})
    svi = SVI(self.model, self.guide, adam, loss="ELBO", trace_graph=False)
    with pytest.warns(Warning):
        svi.step()
def test_svi_step_smoke(model, guide, enum_discrete, trace_graph):
    pyro.clear_param_store()
    data = Variable(torch.Tensor([0, 1, 9]))

    optimizer = pyro.optim.Adam({"lr": .001})
    inference = SVI(model, guide, optimizer, loss="ELBO",
                    trace_graph=trace_graph, enum_discrete=enum_discrete)
    with xfail_if_not_implemented():
        inference.step(data)
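# `xfail_if_not_implemented` (used here and in other snippets above) is assumed
# to be a small context manager that turns a NotImplementedError into an expected
# test failure rather than an error. A sketch of that idea:
from contextlib import contextmanager

import pytest


@contextmanager
def xfail_if_not_implemented(msg="Not implemented"):
    try:
        yield
    except NotImplementedError as e:
        pytest.xfail(reason="{}: {}".format(msg, e))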
def test_random_module(nn_module):
    pyro.clear_param_store()
    nn_module = nn_module()
    p = torch.ones(2, 2)
    prior = dist.Bernoulli(p)
    lifted_mod = pyro.random_module("module", nn_module, prior)
    nn_module = lifted_mod()
    for name, parameter in nn_module.named_parameters():
        assert torch.equal(torch.ones(2, 2), parameter.data)
def do_elbo_test(self, reparameterized, n_steps, lr, prec, beta1,
                 difficulty=1.0, model_permutation=False):
    n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized \
        else len(self.q_topo_sort)
    logger.info((" - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " +
                 "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -") %
                (self.N, (2 ** self.N) - 1, reparameterized, n_repa_nodes,
                 len(self.q_topo_sort), model_permutation))
    pyro.clear_param_store()

    # check graph structure is as expected but only for N=2
    if self.N == 2:
        guide_trace = pyro.poutine.trace(self.guide, graph_type="dense").get_trace(
            reparameterized=reparameterized,
            model_permutation=model_permutation,
            difficulty=difficulty)
        expected_nodes = set(['log_sig_1R', 'kappa_1_1L', '_INPUT', 'constant_term_loc_latent_1R',
                              '_RETURN', 'loc_latent_1R', 'loc_latent_1', 'constant_term_loc_latent_1',
                              'loc_latent_1L', 'constant_term_loc_latent_1L', 'log_sig_1L',
                              'kappa_1_1R', 'kappa_1R_1L', 'log_sig_1'])
        expected_edges = set([('loc_latent_1R', 'loc_latent_1'),
                              ('loc_latent_1L', 'loc_latent_1R'),
                              ('loc_latent_1L', 'loc_latent_1')])
        assert expected_nodes == set(guide_trace.nodes)
        assert expected_edges == set(guide_trace.edges)

    adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
    svi = SVI(self.model, self.guide, adam, loss=TraceGraph_ELBO())

    for step in range(n_steps):
        t0 = time.time()
        svi.step(reparameterized=reparameterized,
                 model_permutation=model_permutation, difficulty=difficulty)

        if step % 5000 == 0 or step == n_steps - 1:
            log_sig_errors = []
            for node in self.target_lambdas:
                target_log_sig = -0.5 * torch.log(self.target_lambdas[node])
                log_sig_error = param_mse('log_sig_' + node, target_log_sig)
                log_sig_errors.append(log_sig_error)
            max_log_sig_error = np.max(log_sig_errors)
            min_log_sig_error = np.min(log_sig_errors)
            mean_log_sig_error = np.mean(log_sig_errors)
            leftmost_node = self.q_topo_sort[0]
            leftmost_constant_error = param_mse('constant_term_' + leftmost_node,
                                                self.target_leftmost_constant)
            almost_leftmost_constant_error = param_mse('constant_term_' + leftmost_node[:-1] + 'R',
                                                       self.target_almost_leftmost_constant)

            logger.debug("[mean function constant errors (partial)] %.4f %.4f" %
                         (leftmost_constant_error, almost_leftmost_constant_error))
            logger.debug("[min/mean/max log(scale) errors] %.4f %.4f %.4f" %
                         (min_log_sig_error, mean_log_sig_error, max_log_sig_error))
            logger.debug("[step time = %.3f; N = %d; step = %d]\n" %
                         (time.time() - t0, self.N, step))

    assert_equal(0.0, max_log_sig_error, prec=prec)
    assert_equal(0.0, leftmost_constant_error, prec=prec)
    assert_equal(0.0, almost_leftmost_constant_error, prec=prec)
def pytest_runtest_setup(item):
    pyro.clear_param_store()
    if item.get_marker("disable_validation"):
        pyro.enable_validation(False)
    else:
        pyro.enable_validation(True)
    test_initialize_marker = item.get_marker("init")
    if test_initialize_marker:
        rng_seed = test_initialize_marker.kwargs["rng_seed"]
        pyro.set_rng_seed(rng_seed)
def test_svi_step_smoke(model, guide, enumerate1):
    pyro.clear_param_store()
    data = torch.tensor([0.0, 1.0, 9.0])

    guide = config_enumerate(guide, default=enumerate1)
    optimizer = pyro.optim.Adam({"lr": .001})
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    inference = SVI(model, guide, optimizer, loss=elbo)
    inference.step(data)
def test_non_mean_field_bern_normal_elbo_gradient(enumerate1, pi1, pi2, pi3, include_z=True):
    pyro.clear_param_store()
    num_particles = 10000

    def model():
        with pyro.iarange("particles", num_particles):
            q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
            y = pyro.sample("y", dist.Bernoulli(q3).expand_by([num_particles]))
            if include_z:
                pyro.sample("z", dist.Normal(0.55 * y + q3, 1.0))

    def guide():
        q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
        q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]),
                            infer={"enumerate": enumerate1})
            if include_z:
                pyro.sample("z", dist.Normal(q2 * y + 0.10, 1.0))

    logger.info("Computing gradients using surrogate loss")
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    elbo.loss_and_grads(model, guide)
    actual_grad_q1 = pyro.param('q1').grad / num_particles
    if include_z:
        actual_grad_q2 = pyro.param('q2').grad / num_particles
    actual_grad_q3 = pyro.param('q3').grad / num_particles

    logger.info("Computing analytic gradients")
    q1 = torch.tensor(pi1, requires_grad=True)
    q2 = torch.tensor(pi2, requires_grad=True)
    q3 = torch.tensor(pi3, requires_grad=True)
    elbo = kl_divergence(dist.Bernoulli(q1), dist.Bernoulli(q3))
    if include_z:
        elbo = elbo + q1 * kl_divergence(dist.Normal(q2 + 0.10, 1.0), dist.Normal(q3 + 0.55, 1.0))
        elbo = elbo + (1.0 - q1) * kl_divergence(dist.Normal(0.10, 1.0), dist.Normal(q3, 1.0))
        expected_grad_q1, expected_grad_q2, expected_grad_q3 = grad(elbo, [q1, q2, q3])
    else:
        expected_grad_q1, expected_grad_q3 = grad(elbo, [q1, q3])

    prec = 0.04 if enumerate1 is None else 0.02

    assert_equal(actual_grad_q1, expected_grad_q1, prec=prec, msg="".join([
        "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()),
        "\nq1 actual = {}".format(actual_grad_q1.data.cpu().numpy()),
    ]))
    if include_z:
        assert_equal(actual_grad_q2, expected_grad_q2, prec=prec, msg="".join([
            "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()),
            "\nq2 actual = {}".format(actual_grad_q2.data.cpu().numpy()),
        ]))
    assert_equal(actual_grad_q3, expected_grad_q3, prec=prec, msg="".join([
        "\nq3 expected = {}".format(expected_grad_q3.data.cpu().numpy()),
        "\nq3 actual = {}".format(actual_grad_q3.data.cpu().numpy()),
    ]))
def test_random_module_prior_dict(self):
    pyro.clear_param_store()
    lifted_nn = pyro.random_module("name", self.model, prior=self.nn_prior)
    lifted_tr = poutine.trace(lifted_nn).get_trace()
    for key_name in lifted_tr.nodes.keys():
        name = pyro.params.user_param_name(key_name)
        if name in {'fc.weight', 'fc.bias'}:
            dist_name = name[3:]
            assert dist_name + "_prior" == lifted_tr.nodes[key_name]['fn'].__name__
            assert lifted_tr.nodes[key_name]["type"] == "sample"
            assert not lifted_tr.nodes[key_name]["is_observed"]
def assert_warning(model, guide, elbo):
    """
    Assert that inference works but with a warning.
    """
    pyro.clear_param_store()
    inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        inference.step()
        assert len(w), 'No warnings were raised'
        for warning in w:
            logger.info(warning)
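# Example of how the assert_ok / assert_warning / assert_error helpers might be
# exercised. The toy model and guide below are hypothetical; only standard Pyro
# calls are used:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO


def toy_model():
    pyro.sample("x", dist.Normal(torch.zeros(1), torch.ones(1)))


def toy_guide():
    loc = pyro.param("loc", torch.zeros(1))
    pyro.sample("x", dist.Normal(loc, torch.ones(1)))


# assert_ok(toy_model, toy_guide, Trace_ELBO())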
def setUp(self):
    pyro.clear_param_store()

    def mu1_prior(tensor, *args, **kwargs):
        flat_tensor = tensor.view(-1)
        m = Variable(torch.zeros(flat_tensor.size(0)))
        s = Variable(torch.ones(flat_tensor.size(0)))
        return Normal(m, s).sample().view(tensor.size())

    def sigma1_prior(tensor, *args, **kwargs):
        flat_tensor = tensor.view(-1)
        m = Variable(torch.zeros(flat_tensor.size(0)))
        s = Variable(torch.ones(flat_tensor.size(0)))
        return Normal(m, s).sample().view(tensor.size())

    def mu2_prior(tensor, *args, **kwargs):
        flat_tensor = tensor.view(-1)
        m = Variable(torch.zeros(flat_tensor.size(0)))
        return Bernoulli(m).sample().view(tensor.size())

    def sigma2_prior(tensor, *args, **kwargs):
        return sigma1_prior(tensor)

    def bias_prior(tensor, *args, **kwargs):
        return mu2_prior(tensor)

    def weight_prior(tensor, *args, **kwargs):
        return sigma1_prior(tensor)

    def stoch_fn(tensor, *args, **kwargs):
        mu = Variable(torch.zeros(tensor.size()))
        sigma = Variable(torch.ones(tensor.size()))
        return pyro.sample("sample", Normal(mu, sigma))

    def guide():
        mu1 = pyro.param("mu1", Variable(torch.randn(2), requires_grad=True))
        sigma1 = pyro.param("sigma1", Variable(torch.ones(2), requires_grad=True))
        pyro.sample("latent1", Normal(mu1, sigma1))
        mu2 = pyro.param("mu2", Variable(torch.randn(2), requires_grad=True))
        sigma2 = pyro.param("sigma2", Variable(torch.ones(2), requires_grad=True))
        latent2 = pyro.sample("latent2", Normal(mu2, sigma2))
        return latent2

    self.model = Model()
    self.guide = guide
    self.prior = mu1_prior
    self.prior_dict = {"mu1": mu1_prior, "sigma1": sigma1_prior,
                       "mu2": mu2_prior, "sigma2": sigma2_prior}
    self.partial_dict = {"mu1": mu1_prior, "sigma1": sigma1_prior}
    self.nn_prior = {"fc.bias": bias_prior, "fc.weight": weight_prior}
    self.fn = stoch_fn
    self.data = Variable(torch.randn(2, 2))
def setUp(self):
    pyro.clear_param_store()

    def loc1_prior(tensor, *args, **kwargs):
        flat_tensor = tensor.view(-1)
        m = torch.zeros(flat_tensor.size(0))
        s = torch.ones(flat_tensor.size(0))
        return Normal(m, s).sample().view(tensor.size())

    def scale1_prior(tensor, *args, **kwargs):
        flat_tensor = tensor.view(-1)
        m = torch.zeros(flat_tensor.size(0))
        s = torch.ones(flat_tensor.size(0))
        return Normal(m, s).sample().view(tensor.size()).exp()

    def loc2_prior(tensor, *args, **kwargs):
        flat_tensor = tensor.view(-1)
        m = torch.zeros(flat_tensor.size(0))
        return Bernoulli(m).sample().view(tensor.size())

    def scale2_prior(tensor, *args, **kwargs):
        return scale1_prior(tensor)

    def bias_prior(tensor, *args, **kwargs):
        return loc2_prior(tensor)

    def weight_prior(tensor, *args, **kwargs):
        return scale1_prior(tensor)

    def stoch_fn(tensor, *args, **kwargs):
        loc = torch.zeros(tensor.size())
        scale = torch.ones(tensor.size())
        return pyro.sample("sample", Normal(loc, scale))

    def guide():
        loc1 = pyro.param("loc1", torch.randn(2, requires_grad=True))
        scale1 = pyro.param("scale1", torch.ones(2, requires_grad=True))
        pyro.sample("latent1", Normal(loc1, scale1))
        loc2 = pyro.param("loc2", torch.randn(2, requires_grad=True))
        scale2 = pyro.param("scale2", torch.ones(2, requires_grad=True))
        latent2 = pyro.sample("latent2", Normal(loc2, scale2))
        return latent2

    self.model = Model()
    self.guide = guide
    self.prior = scale1_prior
    self.prior_dict = {"loc1": loc1_prior, "scale1": scale1_prior,
                       "loc2": loc2_prior, "scale2": scale2_prior}
    self.partial_dict = {"loc1": loc1_prior, "scale1": scale1_prior}
    self.nn_prior = {"fc.bias": bias_prior, "fc.weight": weight_prior}
    self.fn = stoch_fn
    self.data = torch.randn(2, 2)
def test_named_dict():
    pyro.clear_param_store()

    def model():
        latent = named.Dict("latent")
        loc = latent["loc"].param_(torch.zeros(1))
        foo = latent["foo"].sample_(dist.Normal(loc, torch.ones(1)))
        latent["bar"].sample_(dist.Normal(loc, torch.ones(1)), obs=foo)
        latent["x"].z.sample_(dist.Normal(loc, torch.ones(1)))

    tr = poutine.trace(model).get_trace()
    assert get_sample_names(tr) == set(["latent['foo']", "latent['x'].z"])
    assert get_observe_names(tr) == set(["latent['bar']"])
    assert get_param_names(tr) == set(["latent['loc']"])
def main(smoke_test=False):
    epochs = 2 if smoke_test else 50
    batch_size = 128
    seed = 0
    x_ch = 1
    z_dim = 32

    # Initialize random seeds for reproducibility.
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    pyro.set_rng_seed(seed)

    date_and_time = datetime.datetime.now().strftime('%Y-%m%d-%H%M')
    save_root = f'./results/pyro/{date_and_time}'
    if not os.path.exists(save_root):
        os.makedirs(save_root)

    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    pyro.clear_param_store()  # reset Pyro's parameter store
    # For debugging: NaN checks, distribution validation, argument and support checks.
    pyro.enable_validation(smoke_test)
    pyro.distributions.enable_validation(False)

    root = '/mnt/hdd/sika/Datasets'
    train_loader, test_loader = make_MNIST_loader(root, batch_size=batch_size)

    # Instantiate the class that provides the model and guide methods.
    vae = VAE(x_ch, z_dim).to(device)
    # The optimizer must be wrapped in PyroOptim before use.
    optimizer = pyro.optim.PyroOptim(torch.optim.Adam, {'lr': 1e-3})
    # Create the SVI (Stochastic Variational Inference) instance.
    svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

    x_fixed, _ = next(iter(test_loader))  # fixed images
    x_fixed = x_fixed[:8].to(device)
    z_fixed = torch.randn([64, z_dim], device=device)  # fixed latent variables
    x_dummy = torch.zeros(64, x_fixed.size(1), x_fixed.size(2), x_fixed.size(3),
                          device=device)  # used for sampling

    train_loss_list, test_loss_list = [], []
    for epoch in range(1, epochs + 1):
        train_loss_list.append(learn(svi, epoch, train_loader, device, train=True))
        test_loss_list.append(learn(svi, epoch, test_loader, device, train=False))
        print(f' [Epoch {epoch}] train loss {train_loss_list[-1]:.4f}')
        print(f' [Epoch {epoch}] test loss {test_loss_list[-1]:.4f}\n')

        # Plot and save the loss curves.
        plt.plot(list(range(1, epoch + 1)), train_loss_list, label='train')
        plt.plot(list(range(1, epoch + 1)), test_loss_list, label='test')
        plt.xlabel('epochs')
        plt.ylabel('loss')
        plt.legend()
        plt.savefig(os.path.join(save_root, 'loss.png'))
        plt.close()

        # Reconstructed images.
        x_reconst = reconstruct_image(vae.encoder, vae.decoder, x_fixed)
        save_image(torch.cat([x_fixed, x_reconst], dim=0),
                   os.path.join(save_root, f'reconst_{epoch}.png'), nrow=8)
        # Interpolated images.
        x_interpol = interpolate_image(vae.encoder, vae.decoder, x_fixed)
        save_image(x_interpol, os.path.join(save_root, f'interpol_{epoch}.png'), nrow=8)
        # Generated images (fixed latent variables).
        x_generate = generate_image(vae.decoder, z_fixed)
        save_image(x_generate, os.path.join(save_root, f'generate_{epoch}.png'), nrow=8)
        # Generated images (random sampling).
        x_sample = sample_image(vae.model, x_dummy)
        save_image(x_sample, os.path.join(save_root, f'sample_{epoch}.png'), nrow=8)
def main(args):
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # load data
    if args.dataset == "dipper":
        capture_history_file = os.path.dirname(os.path.abspath(__file__)) + '/dipper_capture_history.csv'
    elif args.dataset == "vole":
        capture_history_file = os.path.dirname(os.path.abspath(__file__)) + '/meadow_voles_capture_history.csv'
    else:
        raise ValueError("Available datasets are \'dipper\' and \'vole\'.")

    capture_history = torch.tensor(np.genfromtxt(capture_history_file, delimiter=',')).float()[:, 1:]
    N, T = capture_history.shape
    print("Loaded {} capture history for {} individuals collected over {} time periods."
          .format(args.dataset, N, T))

    if args.dataset == "dipper" and args.model in ["4", "5"]:
        sex_file = os.path.dirname(os.path.abspath(__file__)) + '/dipper_sex.csv'
        sex = torch.tensor(np.genfromtxt(sex_file, delimiter=',')).float()[:, 1]
        print("Loaded dipper sex data.")
    elif args.dataset == "vole" and args.model in ["4", "5"]:
        raise ValueError(("Cannot run model_{} on meadow voles data, since we lack sex "
                          "information for these animals.").format(args.model))
    else:
        sex = None

    model = models[args.model]

    # we use poutine.block to only expose the continuous latent variables
    # in the models to AutoDiagonalNormal (all of which begin with 'phi'
    # or 'rho')
    def expose_fn(msg):
        return msg["name"][0:3] in ['phi', 'rho']

    # we use a mean field diagonal normal variational distribution (i.e. guide)
    # for the continuous latent variables.
    guide = AutoDiagonalNormal(poutine.block(model, expose_fn=expose_fn))

    # since we enumerate the discrete random variables,
    # we need to use TraceEnum_ELBO or TraceTMC_ELBO.
    optim = Adam({'lr': args.learning_rate})
    if args.tmc:
        elbo = TraceTMC_ELBO(max_plate_nesting=1)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {"num_samples": args.tmc_num_samples, "expand": False}
            if msg["infer"].get("enumerate", None) == "parallel" else {})  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        elbo = TraceEnum_ELBO(max_plate_nesting=1, num_particles=20, vectorize_particles=True)
        svi = SVI(model, guide, optim, elbo)

    losses = []

    print("Beginning training of model_{} with Stochastic Variational Inference.".format(args.model))

    for step in range(args.num_steps):
        loss = svi.step(capture_history, sex)
        losses.append(loss)
        if step % 20 == 0 and step > 0 or step == args.num_steps - 1:
            print("[iteration %03d] loss: %.3f" % (step, np.mean(losses[-20:])))

    # evaluate final trained model
    elbo_eval = TraceEnum_ELBO(max_plate_nesting=1, num_particles=2000, vectorize_particles=True)
    svi_eval = SVI(model, guide, optim, elbo_eval)
    print("Final loss: %.4f" % svi_eval.evaluate_loss(capture_history, sex))
def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False, lr=0.001):
    pyro.clear_param_store()

    def model():
        with pyro.plate("samples", self.sample_batch_size):
            pyro.sample(
                "loc_latent",
                dist.Normal(
                    torch.stack([self.loc0] * self.sample_batch_size, dim=0),
                    torch.stack([torch.pow(self.lam0, -0.5)] * self.sample_batch_size, dim=0),
                ).to_event(1),
            )

    def guide():
        loc_q = pyro.param("loc_q", self.loc0.detach() + 0.134)
        log_sig_q = pyro.param(
            "log_sig_q", -0.5 * torch.log(self.lam0).data.detach() - 0.14)
        sig_q = torch.exp(log_sig_q)
        Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
        with pyro.plate("samples", self.sample_batch_size):
            pyro.sample(
                "loc_latent",
                Normal(
                    torch.stack([loc_q] * self.sample_batch_size, dim=0),
                    torch.stack([sig_q] * self.sample_batch_size, dim=0),
                ).to_event(1),
            )

    adam = optim.Adam({"lr": lr})
    svi = SVI(model, guide, adam, loss=loss)

    alpha = 0.99
    for k in range(n_steps):
        svi.step()

        if debug:
            loc_error = param_mse("loc_q", self.loc0)
            log_sig_error = param_mse("log_sig_q", -0.5 * torch.log(self.lam0))
            with torch.no_grad():
                if k == 0:
                    (avg_loglikelihood, avg_penalty) = loss._differentiable_loss_parts(model, guide)
                    avg_loglikelihood = torch_item(avg_loglikelihood)
                    avg_penalty = torch_item(avg_penalty)
                loglikelihood, penalty = loss._differentiable_loss_parts(model, guide)
                avg_loglikelihood = alpha * avg_loglikelihood + (1 - alpha) * torch_item(loglikelihood)
                avg_penalty = alpha * avg_penalty + (1 - alpha) * torch_item(penalty)
            if k % 100 == 0:
                print(loc_error, log_sig_error)
                print(avg_loglikelihood, avg_penalty)
                print()

    loc_error = param_mse("loc_q", self.loc0)
    log_sig_error = param_mse("log_sig_q", -0.5 * torch.log(self.lam0))
    assert_equal(0.0, loc_error, prec=0.05)
    assert_equal(0.0, log_sig_error, prec=0.05)
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
        model: cellbender.model.RemoveBackgroundPyroModel that has had
            inference run.
    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')

    encoder_other = EncodeNonZLatents(
        n_genes=count_matrix.shape[1],
        z_dim=args.z_dim,
        hidden_dims=consts.ENC_HIDDEN_DIMS,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'],
        prior_log_cell_counts=np.log1p(dataset_obj.priors['cell_counts']),
        input_transform='normalize')

    encoder = CompositeEncoder({'z': encoder_z, 'other': encoder_other})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(min(300, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=dataset_obj.get_count_matrix_empties(),
                               random_state=dataset_obj.random,
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Set up the optimizer.
    optimizer = pyro.optim.clipped_adam.ClippedAdam
    optimizer_args = {'lr': args.learning_rate, 'clip_norm': 10.}

    # Set up a learning rate scheduler.
    minibatches_per_epoch = int(np.ceil(len(train_loader) / train_loader.batch_size).item())
    scheduler_args = {'optimizer': optimizer,
                      'max_lr': args.learning_rate * 10,
                      'steps_per_epoch': minibatches_per_epoch,
                      'epochs': args.epochs,
                      'optim_args': optimizer_args}
    scheduler = pyro.optim.OneCycleLR(scheduler_args)

    # Determine the loss function.
    if args.use_jit:

        # Call guide() once as a warm-up.
        model.guide(torch.zeros([10, dataset_obj.analyzed_gene_inds.size]).to(model.device))

        if args.model == "simple":
            loss_function = JitTrace_ELBO()
        else:
            loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                              strict_enumeration_warning=False)
    else:

        if args.model == "simple":
            loss_function = Trace_ELBO()
        else:
            loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    # Set up the inference process.
    svi = SVI(model.model, model.guide, scheduler, loss=loss_function)

    # Run training.
    run_training(model, svi, train_loader, test_loader,
                 epochs=args.epochs, test_freq=5)

    return model
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)
def main_sVAE(arr):
    X_DIM = 10000
    Y_DIM = 2
    Z_DIM = 16
    ALPHA_ENCO = int("".join(str(i) for i in arr[0:10]), 2)
    BETA_ENCO = int("".join(str(i) for i in arr[10:18]), 2)
    H_DIM_ENCO_1 = ALPHA_ENCO + BETA_ENCO
    H_DIM_ENCO_2 = ALPHA_ENCO
    H_DIM_DECO_1 = ALPHA_ENCO
    H_DIM_DECO_2 = ALPHA_ENCO + BETA_ENCO

    print(str(H_DIM_ENCO_1))
    print(str(H_DIM_ENCO_2))
    print(str(H_DIM_DECO_1))
    print(str(H_DIM_DECO_2))
    print('-----------')

    # Run options
    LEARNING_RATE = 1.0e-3
    USE_CUDA = True
    NUM_EPOCHS = 501
    TEST_FREQUENCY = 5

    train_loader, test_loader = dataloader_first()

    # clear param store
    pyro.clear_param_store()

    # setup the VAE
    vae = VAE(x_dim=X_DIM, y_dim=Y_DIM,
              h_dim_enco_1=H_DIM_ENCO_1, h_dim_enco_2=H_DIM_ENCO_2,
              h_dim_deco_1=H_DIM_DECO_1, h_dim_deco_2=H_DIM_DECO_2,
              z_dim=Z_DIM, use_cuda=USE_CUDA)

    # setup the optimizer
    adagrad_params = {"lr": 0.00003}
    optimizer = Adagrad(adagrad_params)
    svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

    train_elbo = []
    test_elbo = []
    # training loop
    for epoch in range(NUM_EPOCHS):
        total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
        train_elbo.append(-total_epoch_loss_train)
        print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))

        if epoch == 500:
            # -------------------- Do testing for the final epoch here --------------------
            # initialize loss accumulator
            test_loss = 0.
            # compute the loss over the entire test set
            for x_test, y_test in test_loader:
                x_test = x_test.cuda()
                y_test = y_test.cuda()
                # one-hot encode the labels, then compute ELBO estimate and accumulate loss
                y_test_2 = torch.Tensor.cpu(y_test.reshape(1, y_test.size()[0])[0]).numpy().astype(int)
                labels_y_test = torch.from_numpy(np.eye(2)[y_test_2])
                test_loss += svi.evaluate_loss(
                    x_test.reshape(-1, 10000),
                    labels_y_test.cuda().float())  # data entry point

            normalizer_test = len(test_loader.dataset)
            total_epoch_loss_test = test_loss / normalizer_test
            print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

    return total_epoch_loss_test
def main(args):
    # clear param store
    pyro.clear_param_store()

    # setup MNIST data loaders
    # train_loader, test_loader
    train_loader, test_loader = setup_data_loaders(MNIST, use_cuda=args.cuda, batch_size=256)

    # setup the VAE
    vae = VAE(use_cuda=args.cuda)

    # setup the optimizer
    adam_args = {"lr": args.learning_rate}
    optimizer = Adam(adam_args)

    # setup the inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    # setup visdom for visualization
    if args.visdom_flag:
        vis = visdom.Visdom()

    train_elbo = []
    test_elbo = []
    # training loop
    for epoch in range(args.num_epochs):
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch x returned
        # by the data loader
        for x, _ in train_loader:
            # if on GPU put mini-batch into CUDA memory
            if args.cuda:
                x = x.cuda()
            # do ELBO gradient and accumulate loss
            epoch_loss += svi.step(x)

        # report training diagnostics
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))

        if epoch % args.test_frequency == 0:
            # initialize loss accumulator
            test_loss = 0.
            # compute the loss over the entire test set
            for i, (x, _) in enumerate(test_loader):
                # if on GPU put mini-batch into CUDA memory
                if args.cuda:
                    x = x.cuda()
                # compute ELBO estimate and accumulate loss
                test_loss += svi.evaluate_loss(x)

                # pick three random test images from the first mini-batch and
                # visualize how well we're reconstructing them
                if i == 0:
                    if args.visdom_flag:
                        plot_vae_samples(vae, vis)
                        reco_indices = np.random.randint(0, x.shape[0], 3)
                        for index in reco_indices:
                            test_img = x[index, :]
                            reco_img = vae.reconstruct_img(test_img)
                            vis.image(test_img.reshape(28, 28).detach().cpu().numpy(),
                                      opts={'caption': 'test image'})
                            vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(),
                                      opts={'caption': 'reconstructed image'})

            # report test diagnostics
            normalizer_test = len(test_loader.dataset)
            total_epoch_loss_test = test_loss / normalizer_test
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

        if epoch == args.tsne_iter:
            mnist_test_tsne(vae=vae, test_loader=test_loader)
            plot_llk(np.array(train_elbo), np.array(test_elbo))

    return vae
def train():
    pyro.clear_param_store()
    for j in range(num_iterations):
        # calculate the loss and take a gradient step
        loss = svi.step(x_data, y_data)
        if j % 100 == 0:
            print("Iteration %04d loss: %.4f" % (j + 1, loss / len(data)))
def main():
    pyro.clear_param_store()
    # pyro.get_param_store().load('Pyro_model')
    for j in range(n_epochs):
        loss = 0
        start = time.time()
        for data in train_loader:
            data[0] = Variable(data[0].cuda())  # .view(-1, 28 * 28).cuda())
            data[1] = Variable(data[1].long().cuda())
            loss += svi.step(data)
        print(time.time() - start)
        # if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / float(n_train_batches * batch_size)))

    # for name in pyro.get_param_store().get_all_param_names():
    #     print("[%s]: %.3f" % (name, pyro.param(name).data.numpy()))
    pyro.get_param_store().save('Pyro_model')

    datasets = {'RegularImages_0.0': [test.test_data, test.test_labels]}

    fgsm = glob.glob('fgsm/fgsm_cifar10_examples_x_10000_*')  # glob.glob('fgsm/fgsm_mnist_adv_x_1000_*')
    fgsm_labels = test.test_labels  # torch.from_numpy(np.argmax(np.load('fgsm/fgsm_mnist_adv_y_1000.npy'), axis=1))
    for file in fgsm:
        parts = file.split('_')
        key = parts[0].split('/')[0] + '_' + parts[-1].split('.npy')[0]
        datasets[key] = [torch.from_numpy(np.load(file)), fgsm_labels]

    # jsma = glob.glob('jsma/jsma_cifar10_adv_x_10000*')
    # jsma_labels = torch.from_numpy(np.argmax(np.load('jsma/jsma_cifar10_adv_y_10000.npy'), axis=1))
    # for file in jsma:
    #     parts = file.split('_')
    #     key = parts[0].split('/')[0] + '_' + parts[-1].split('.npy')[0]
    #     datasets[key] = [torch.from_numpy(np.load(file)), jsma_labels]

    gaussian = glob.glob('gaussian/cifar_gaussian_adv_x_*')
    gaussian_labels = torch.from_numpy(
        np.argmax(np.load('gaussian/cifar_gaussian_adv_y.npy')[0:1000], axis=1))
    for file in gaussian:
        parts = file.split('_')
        key = parts[0].split('/')[0] + '_' + parts[-1].split('.npy')[0]
        datasets[key] = [torch.from_numpy(np.load(file)), gaussian_labels]

    print(datasets.keys())
    print('################################################################################')

    accuracies = {}
    for key, value in datasets.items():
        print(key)
        parts = key.split('_')
        adversary_type = parts[0]
        epsilon = parts[1]
        data = value
        X, y = data[0], data[1]  # .view(-1, 28 * 28), data[1]
        x_data, y_data = Variable(X.float().cuda()), Variable(y.cuda())

        T = 100
        accs = []
        samples = np.zeros((y_data.data.size()[0], T, outputs))
        for i in range(T):
            sampled_model = guide(None)
            pred = sampled_model(x_data)
            samples[:, i, :] = pred.data.cpu().numpy()
            _, out = torch.max(pred, 1)
            acc = np.count_nonzero(
                np.squeeze(out.data.cpu().numpy()) ==
                np.int32(y_data.data.cpu().numpy().ravel())) / float(y_data.data.size()[0])
            accs.append(acc)

        variationRatio = []
        mutualInformation = []
        predictiveEntropy = []
        predictions = []
        for i in range(0, len(y_data)):
            entry = samples[i, :, :]
            variationRatio.append(Uncertainty.variation_ratio(entry))
            mutualInformation.append(Uncertainty.mutual_information(entry))
            predictiveEntropy.append(Uncertainty.predictive_entropy(entry))
            predictions.append(np.max(entry.mean(axis=0), axis=0))

        uncertainty = {}
        uncertainty['varation_ratio'] = np.array(variationRatio)
        uncertainty['predictive_entropy'] = np.array(predictiveEntropy)
        uncertainty['mutual_information'] = np.array(mutualInformation)
        predictions = np.array(predictions)

        Uncertainty.plot_uncertainty(uncertainty, predictions,
                                     adversarial_type=adversary_type,
                                     epsilon=float(epsilon),
                                     directory='Results_CIFAR10_PYRO')

        accs = np.array(accs)
        print('Accuracy mean: {}, Accuracy std: {}'.format(accs.mean(), accs.std()))
        accuracies[key] = {'mean': accs.mean(), 'std': accs.std(),
                           'variationratio': [uncertainty['varation_ratio'].mean(),
                                              uncertainty['varation_ratio'].std()],
                           'predictiveEntropy': [uncertainty['predictive_entropy'].mean(),
                                                 uncertainty['predictive_entropy'].std()],
                           'mutualInformation': [uncertainty['mutual_information'].mean(),
                                                 uncertainty['mutual_information'].std()]}

    np.save('PyroBNN_accuracies_CIFAR10', accuracies)
def setup():
    pyro.clear_param_store()
get_ipython().run_cell_magic(
    'time', '',
    '\n### HMC ###\npyro.clear_param_store()\n\n# Set random seed for reproducibility.\npyro.set_rng_seed(2)\n\n# Set up HMC sampler.\nkernel = HMC(gpc, step_size=0.05, trajectory_length=1,\n             adapt_step_size=False, adapt_mass_matrix=False, jit_compile=True)\nhmc = MCMC(kernel, num_samples=500, warmup_steps=500)\nhmc.run(X, y.double())\n\n# Get posterior samples\nhmc_posterior_samples = hmc.get_samples()')


# In[59]:


plot_uq(hmc_posterior_samples, X, Xnew, "HMC")


# ## NUTS

# In[60]:


### NUTS ###
pyro.clear_param_store()
pyro.set_rng_seed(2)

nuts = MCMC(NUTS(gpc, target_accept_prob=0.8, max_tree_depth=10, jit_compile=True),
            num_samples=500, warmup_steps=500)
nuts.run(X, y.double())

nuts_posterior_samples = nuts.get_samples()


# In[61]:


plot_uq(nuts_posterior_samples, X, Xnew, "NUTS")
def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False):
    pyro.clear_param_store()
    Beta = dist.Beta if reparameterized else fakes.NonreparameterizedBeta

    def model():
        with pyro.plate("samples", self.sample_batch_size):
            pyro.sample(
                "p_latent",
                Beta(
                    torch.stack([torch.stack([self.alpha0])] * self.sample_batch_size),
                    torch.stack([torch.stack([self.beta0])] * self.sample_batch_size),
                ).to_event(1),
            )

    def guide():
        alpha_q_log = pyro.param("alpha_q_log", torch.log(self.alpha0) + 0.17)
        beta_q_log = pyro.param("beta_q_log", torch.log(self.beta0) - 0.143)
        alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
        with pyro.plate("samples", self.sample_batch_size):
            pyro.sample(
                "p_latent",
                Beta(
                    torch.stack([torch.stack([alpha_q])] * self.sample_batch_size),
                    torch.stack([torch.stack([beta_q])] * self.sample_batch_size),
                ).to_event(1),
            )

    adam = optim.Adam({"lr": 0.001, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss=loss)

    alpha = 0.99
    for k in range(n_steps):
        svi.step()

        if debug:
            alpha_error = param_abs_error("alpha_q_log", torch.log(self.alpha0))
            beta_error = param_abs_error("beta_q_log", torch.log(self.beta0))
            with torch.no_grad():
                if k == 0:
                    (avg_loglikelihood, avg_penalty) = loss._differentiable_loss_parts(model, guide)
                    avg_loglikelihood = torch_item(avg_loglikelihood)
                    avg_penalty = torch_item(avg_penalty)
                loglikelihood, penalty = loss._differentiable_loss_parts(model, guide)
                avg_loglikelihood = alpha * avg_loglikelihood + (1 - alpha) * torch_item(loglikelihood)
                avg_penalty = alpha * avg_penalty + (1 - alpha) * torch_item(penalty)
            if k % 100 == 0:
                print(alpha_error, beta_error)
                print(avg_loglikelihood, avg_penalty)
                print()

    alpha_error = param_abs_error("alpha_q_log", torch.log(self.alpha0))
    beta_error = param_abs_error("beta_q_log", torch.log(self.beta0))
    assert_equal(0.0, alpha_error, prec=0.08)
    assert_equal(0.0, beta_error, prec=0.08)
def test_exponential_gamma(gamma_dist, n_steps, elbo_impl):
    pyro.clear_param_store()

    # gamma prior hyperparameters
    alpha0 = torch.tensor(1.0)
    beta0 = torch.tensor(1.0)
    n_data = 2
    data = torch.tensor([3.0, 2.0])  # two observations
    alpha_n = alpha0 + torch.tensor(float(n_data))  # posterior alpha
    beta_n = beta0 + torch.sum(data)  # posterior beta
    prec = 0.2 if gamma_dist.has_rsample else 0.25

    def model(alpha0, beta0, alpha_n, beta_n):
        lambda_latent = pyro.sample("lambda_latent", gamma_dist(alpha0, beta0))
        with pyro.plate("data", n_data):
            pyro.sample("obs", dist.Exponential(lambda_latent), obs=data)
        return lambda_latent

    def guide(alpha0, beta0, alpha_n, beta_n):
        alpha_q = pyro.param("alpha_q", alpha_n * math.exp(0.17),
                             constraint=constraints.positive)
        beta_q = pyro.param("beta_q", beta_n / math.exp(0.143),
                            constraint=constraints.positive)
        pyro.sample("lambda_latent", gamma_dist(alpha_q, beta_q))

    adam = optim.Adam({"lr": 0.0003, "betas": (0.97, 0.999)})
    if elbo_impl is RenyiELBO:
        elbo = elbo_impl(alpha=0.2, num_particles=3, max_plate_nesting=1,
                         strict_enumeration_warning=False)
    elif elbo_impl is ReweightedWakeSleep:
        if gamma_dist is ShapeAugmentedGamma:
            pytest.xfail(reason="ShapeAugmentedGamma not supported for ReweightedWakeSleep")
        else:
            elbo = elbo_impl(num_particles=3, max_plate_nesting=1,
                             strict_enumeration_warning=False)
    else:
        elbo = elbo_impl(max_plate_nesting=1, strict_enumeration_warning=False)
    svi = SVI(model, guide, adam, loss=elbo)

    with xfail_if_not_implemented():
        for k in range(n_steps):
            svi.step(alpha0, beta0, alpha_n, beta_n)

    assert_equal(pyro.param("alpha_q"), alpha_n, prec=prec,
                 msg="{} vs {}".format(pyro.param("alpha_q").detach().cpu().numpy(),
                                       alpha_n.detach().cpu().numpy()))
    assert_equal(pyro.param("beta_q"), beta_n, prec=prec,
                 msg="{} vs {}".format(pyro.param("beta_q").detach().cpu().numpy(),
                                       beta_n.detach().cpu().numpy()))
def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False, lr=0.0002):
    pyro.clear_param_store()
    Gamma = dist.Gamma if reparameterized else fakes.NonreparameterizedGamma

    def model():
        with pyro.plate("samples", self.sample_batch_size):
            pyro.sample(
                "lambda_latent",
                Gamma(
                    torch.stack([torch.stack([self.alpha0])] * self.sample_batch_size),
                    torch.stack([torch.stack([self.beta0])] * self.sample_batch_size),
                ).to_event(1),
            )

    def guide():
        alpha_q = pyro.param(
            "alpha_q",
            self.alpha0.detach() + math.exp(0.17),
            constraint=constraints.positive,
        )
        beta_q = pyro.param(
            "beta_q",
            self.beta0.detach() / math.exp(0.143),
            constraint=constraints.positive,
        )
        with pyro.plate("samples", self.sample_batch_size):
            pyro.sample(
                "lambda_latent",
                Gamma(
                    torch.stack([torch.stack([alpha_q])] * self.sample_batch_size),
                    torch.stack([torch.stack([beta_q])] * self.sample_batch_size),
                ).to_event(1),
            )

    adam = optim.Adam({"lr": lr, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss)

    alpha = 0.99
    for k in range(n_steps):
        svi.step()

        if debug:
            alpha_error = param_mse("alpha_q", self.alpha0)
            beta_error = param_mse("beta_q", self.beta0)
            with torch.no_grad():
                if k == 0:
                    (avg_loglikelihood, avg_penalty) = loss._differentiable_loss_parts(model, guide, (), {})
                    avg_loglikelihood = torch_item(avg_loglikelihood)
                    avg_penalty = torch_item(avg_penalty)
                loglikelihood, penalty = loss._differentiable_loss_parts(model, guide, (), {})
                avg_loglikelihood = alpha * avg_loglikelihood + (1 - alpha) * torch_item(loglikelihood)
                avg_penalty = alpha * avg_penalty + (1 - alpha) * torch_item(penalty)
            if k % 100 == 0:
                print(alpha_error, beta_error)
                print(avg_loglikelihood, avg_penalty)
                print()

    assert_equal(pyro.param("alpha_q"), self.alpha0, prec=0.2,
                 msg="{} vs {}".format(pyro.param("alpha_q").detach().cpu().numpy(),
                                       self.alpha0.detach().cpu().numpy()))
    assert_equal(pyro.param("beta_q"), self.beta0, prec=0.15,
                 msg="{} vs {}".format(pyro.param("beta_q").detach().cpu().numpy(),
                                       self.beta0.detach().cpu().numpy()))
def do_elbo_test(self, repa1, repa2, n_steps, prec, lr,
                 use_nn_baseline, use_decaying_avg_baseline):
    logger.info(" - - - - - DO NORMALNORMALNORMAL ELBO TEST - - - - - -")
    logger.info("[reparameterized = %s, %s; nn_baseline = %s, decaying_baseline = %s]" %
                (repa1, repa2, use_nn_baseline, use_decaying_avg_baseline))
    pyro.clear_param_store()
    Normal1 = dist.Normal if repa1 else fakes.NonreparameterizedNormal
    Normal2 = dist.Normal if repa2 else fakes.NonreparameterizedNormal

    if use_nn_baseline:

        class VanillaBaselineNN(nn.Module):
            def __init__(self, dim_input, dim_h):
                super().__init__()
                self.lin1 = nn.Linear(dim_input, dim_h)
                self.lin2 = nn.Linear(dim_h, 2)
                self.sigmoid = nn.Sigmoid()

            def forward(self, x):
                h = self.sigmoid(self.lin1(x))
                return self.lin2(h)

        loc_prime_baseline = pyro.module("loc_prime_baseline", VanillaBaselineNN(2, 5))
    else:
        loc_prime_baseline = None

    def model():
        with pyro.plate("plate", 2):
            loc_latent_prime = pyro.sample(
                "loc_latent_prime",
                Normal1(self.loc0, torch.pow(self.lam0, -0.5)))
            loc_latent = pyro.sample(
                "loc_latent",
                Normal2(loc_latent_prime, torch.pow(self.lam0, -0.5)))
            with pyro.plate("data", len(self.data)):
                pyro.sample(
                    "obs",
                    dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).expand_by(self.data.shape[:1]),
                    obs=self.data,
                )
        return loc_latent

    # note that the exact posterior is not mean field!
    def guide():
        loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.334)
        log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.expand(2) - 0.29)
        loc_q_prime = pyro.param("loc_q_prime", torch.tensor([-0.34, 0.52]))
        kappa_q = pyro.param("kappa_q", torch.tensor([0.74]))
        log_sig_q_prime = pyro.param("log_sig_q_prime", -0.5 * torch.log(1.2 * self.lam0))
        sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(log_sig_q_prime)
        with pyro.plate("plate", 2):
            loc_latent = pyro.sample(
                "loc_latent",
                Normal2(loc_q, sig_q),
                infer=dict(baseline=dict(
                    use_decaying_avg_baseline=use_decaying_avg_baseline)),
            )
            pyro.sample(
                "loc_latent_prime",
                Normal1(kappa_q.expand_as(loc_latent) * loc_latent + loc_q_prime, sig_q_prime),
                infer=dict(baseline=dict(
                    nn_baseline=loc_prime_baseline,
                    nn_baseline_input=loc_latent,
                    use_decaying_avg_baseline=use_decaying_avg_baseline,
                )),
            )
            with pyro.plate("data", len(self.data)):
                pass
        return loc_latent

    adam = optim.Adam({"lr": 0.0015, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

    for k in range(n_steps):
        svi.step()

        loc_error = param_mse("loc_q", self.analytic_loc_n)
        log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
        loc_prime_error = param_mse("loc_q_prime", 0.5 * self.loc0)
        kappa_error = param_mse("kappa_q", 0.5 * torch.ones(1))
        log_sig_prime_error = param_mse("log_sig_q_prime", -0.5 * torch.log(2.0 * self.lam0))

        if k % 500 == 0:
            logger.debug("errors: %.4f, %.4f" % (loc_error, log_sig_error))
            logger.debug(", %.4f, %.4f" % (loc_prime_error, log_sig_prime_error))
            logger.debug(", %.4f" % kappa_error)

    assert_equal(0.0, loc_error, prec=prec)
    assert_equal(0.0, log_sig_error, prec=prec)
    assert_equal(0.0, loc_prime_error, prec=prec)
    assert_equal(0.0, log_sig_prime_error, prec=prec)
    assert_equal(0.0, kappa_error, prec=prec)
def update_noise_svi(self, obs_data, intervened_model=None):
    """
    Use SVI to find the mu, sigma of the noise distributions for the
    condition outlined in obs_data.
    """

    def guide(noise):
        """
        The guide serves as an approximation to the posterior p(z|x).
        The guide provides a valid joint probability density over all the
        latent random variables in the model.
        https://pyro.ai/examples/svi_part_i.html
        """
        # create params with constraints
        mu = {
            'N_X': pyro.param('N_X_mu', 0.5 * torch.ones(self.image_dim),
                              constraint=constraints.interval(0., 1.)),
            'N_Z': pyro.param('N_Z_mu', torch.zeros(self.z_dim),
                              constraint=constraints.interval(-3., 3.)),
            'N_Y_1': pyro.param('N_Y_1_mu', 0.5 * torch.ones(self.label_dims[1]),
                                constraint=constraints.interval(0., 1.)),
            'N_Y_2': pyro.param('N_Y_2_mu', 0.5 * torch.ones(self.label_dims[2]),
                                constraint=constraints.interval(0., 1.)),
            'N_Y_3': pyro.param('N_Y_3_mu', 0.5 * torch.ones(self.label_dims[3]),
                                constraint=constraints.interval(0., 1.)),
            'N_Y_4': pyro.param('N_Y_4_mu', 0.5 * torch.ones(self.label_dims[4]),
                                constraint=constraints.interval(0., 1.)),
            'N_Y_5': pyro.param('N_Y_5_mu', 0.5 * torch.ones(self.label_dims[5]),
                                constraint=constraints.interval(0., 1.)),
        }
        sigma = {
            'N_X': pyro.param('N_X_sigma', 0.1 * torch.ones(self.image_dim),
                              constraint=constraints.interval(0.0001, 0.5)),
            'N_Z': pyro.param('N_Z_sigma', torch.ones(self.z_dim),
                              constraint=constraints.interval(0.0001, 3.)),
            'N_Y_1': pyro.param('N_Y_1_sigma', 0.1 * torch.ones(self.label_dims[1]),
                                constraint=constraints.interval(0.0001, 0.5)),
            'N_Y_2': pyro.param('N_Y_2_sigma', 0.1 * torch.ones(self.label_dims[2]),
                                constraint=constraints.interval(0.0001, 0.5)),
            'N_Y_3': pyro.param('N_Y_3_sigma', 0.1 * torch.ones(self.label_dims[3]),
                                constraint=constraints.interval(0.0001, 0.5)),
            'N_Y_4': pyro.param('N_Y_4_sigma', 0.1 * torch.ones(self.label_dims[4]),
                                constraint=constraints.interval(0.0001, 0.5)),
            'N_Y_5': pyro.param('N_Y_5_sigma', 0.1 * torch.ones(self.label_dims[5]),
                                constraint=constraints.interval(0.0001, 0.5)),
        }
        for noise_term in noise.keys():
            pyro.sample(noise_term, dist.Normal(mu[noise_term], sigma[noise_term]).to_event(1))

    # Condition the model
    if intervened_model is not None:
        obs_model = pyro.condition(intervened_model, obs_data)
    else:
        obs_model = pyro.condition(self.model, obs_data)

    pyro.clear_param_store()

    # Once we've specified a guide, we're ready to proceed to inference.
    # This is an optimization problem where each iteration of training takes
    # a step that moves the guide closer to the exact posterior.
    # https://arxiv.org/pdf/1601.00670.pdf
    svi = SVI(
        model=obs_model,
        guide=guide,
        optim=SGD({"lr": 1e-5, 'momentum': 0.1}),
        loss=Trace_ELBO(retain_graph=True)
    )

    num_steps = 1500
    samples = defaultdict(list)
    for t in range(num_steps):
        loss = svi.step(self.init_noise)
        # if t % 100 == 0:
        #     print("step %d: loss of %.2f" % (t, loss))
        for noise in self.init_noise.keys():
            mu = '{}_mu'.format(noise)
            sigma = '{}_sigma'.format(noise)
            samples[mu].append(pyro.param(mu).detach().numpy())
            samples[sigma].append(pyro.param(sigma).detach().numpy())
    means = {k: torch.tensor(np.array(v).mean(axis=0)) for k, v in samples.items()}

    # update the inferred noise
    updated_noise = {
        'N_X': dist.Normal(means['N_X_mu'], means['N_X_sigma']),
        'N_Z': dist.Normal(means['N_Z_mu'], means['N_Z_sigma']),
        'N_Y_1': dist.Normal(means['N_Y_1_mu'], means['N_Y_1_sigma']),
        'N_Y_2': dist.Normal(means['N_Y_2_mu'], means['N_Y_2_sigma']),
        'N_Y_3': dist.Normal(means['N_Y_3_mu'], means['N_Y_3_sigma']),
        'N_Y_4': dist.Normal(means['N_Y_4_mu'], means['N_Y_4_sigma']),
        'N_Y_5': dist.Normal(means['N_Y_5_mu'], means['N_Y_5_sigma']),
    }

    return updated_noise
def _test_plate_in_elbo(self, n_superfluous_top, n_superfluous_bottom, n_steps, lr=0.0012): pyro.clear_param_store() self.data_tensor = torch.zeros(9, 2) for _out in range(self.n_outer): for _in in range(self.n_inner): self.data_tensor[3 * _out + _in, :] = self.data[_out][_in] self.data_as_list = [ self.data_tensor[0:4, :], self.data_tensor[4:7, :], self.data_tensor[7:9, :], ] def model(): loc_latent = pyro.sample( "loc_latent", fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5)).to_event(1), ) for i in pyro.plate("outer", 3): x_i = self.data_as_list[i] with pyro.plate("inner_%d" % i, x_i.size(0)): for k in range(n_superfluous_top): z_i_k = pyro.sample( "z_%d_%d" % (i, k), fakes.NonreparameterizedNormal(0, 1).expand_by( [4 - i]), ) assert z_i_k.shape == (4 - i, ) obs_i = pyro.sample( "obs_%d" % i, dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1), obs=x_i, ) assert obs_i.shape == (4 - i, 2) for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom): z_i_k = pyro.sample( "z_%d_%d" % (i, k), fakes.NonreparameterizedNormal(0, 1).expand_by( [4 - i]), ) assert z_i_k.shape == (4 - i, ) pt_loc_baseline = torch.nn.Linear(1, 1) pt_superfluous_baselines = [] for k in range(n_superfluous_top + n_superfluous_bottom): pt_superfluous_baselines.extend([ torch.nn.Linear(2, 4), torch.nn.Linear(2, 3), torch.nn.Linear(2, 2) ]) def guide(): loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094) log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.expand(2) - 0.07) sig_q = torch.exp(log_sig_q) trivial_baseline = pyro.module("loc_baseline", pt_loc_baseline) baseline_value = trivial_baseline(torch.ones(1)).squeeze() loc_latent = pyro.sample( "loc_latent", fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1), infer=dict(baseline=dict(baseline_value=baseline_value)), ) for i in pyro.plate("outer", 3): with pyro.plate("inner_%d" % i, 4 - i): for k in range(n_superfluous_top + n_superfluous_bottom): z_baseline = pyro.module( "z_baseline_%d_%d" % (i, k), pt_superfluous_baselines[3 * k + i], ) baseline_value = z_baseline(loc_latent.detach()) mean_i = pyro.param("mean_%d_%d" % (i, k), 0.5 * torch.ones(4 - i)) z_i_k = pyro.sample( "z_%d_%d" % (i, k), fakes.NonreparameterizedNormal(mean_i, 1), infer=dict(baseline=dict( baseline_value=baseline_value)), ) assert z_i_k.shape == (4 - i, ) def per_param_callable(param_name): if "baseline" in param_name: return {"lr": 0.010, "betas": (0.95, 0.999)} else: return {"lr": lr, "betas": (0.95, 0.999)} adam = optim.Adam(per_param_callable) svi = SVI(model, guide, adam, loss=TraceGraph_ELBO()) for step in range(n_steps): svi.step() loc_error = param_abs_error("loc_q", self.analytic_loc_n) log_sig_error = param_abs_error("log_sig_q", self.analytic_log_sig_n) if n_superfluous_top > 0 or n_superfluous_bottom > 0: superfluous_errors = [] for k in range(n_superfluous_top + n_superfluous_bottom): mean_0_error = torch.sum( torch.pow(pyro.param("mean_0_%d" % k), 2.0)) mean_1_error = torch.sum( torch.pow(pyro.param("mean_1_%d" % k), 2.0)) mean_2_error = torch.sum( torch.pow(pyro.param("mean_2_%d" % k), 2.0)) superfluous_error = torch.max( torch.max(mean_0_error, mean_1_error), mean_2_error) superfluous_errors.append( superfluous_error.detach().cpu().numpy()) if step % 500 == 0: logger.debug("loc error, log(scale) error: %.4f, %.4f" % (loc_error, log_sig_error)) if n_superfluous_top > 0 or n_superfluous_bottom > 0: logger.debug("superfluous error: %.4f" % np.max(superfluous_errors)) assert_equal(0.0, loc_error, prec=0.04) 
assert_equal(0.0, log_sig_error, prec=0.05) if n_superfluous_top > 0 or n_superfluous_bottom > 0: assert_equal(0.0, np.max(superfluous_errors), prec=0.04)
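# The test above exercises Pyro's baseline machinery for non-reparameterized sample
# sites under TraceGraph_ELBO, where a baseline reduces the variance of the
# REINFORCE-style gradient estimator. Below is a minimal hedged sketch of attaching a
# baseline to such a site; it uses a toy Bernoulli model and the built-in
# decaying-average baseline rather than the neural baselines in the test, and none of
# its names come from the source.
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, TraceGraph_ELBO
from pyro.optim import Adam

def model():
    z = pyro.sample("z", dist.Bernoulli(0.5))
    pyro.sample("x", dist.Normal(z, 1.), obs=torch.tensor(0.8))

def guide():
    p = pyro.param("p", torch.tensor(0.5), constraint=constraints.unit_interval)
    # the Bernoulli site is discrete (non-reparameterized), so a baseline helps
    pyro.sample("z", dist.Bernoulli(p),
                infer=dict(baseline={"use_decaying_avg_baseline": True,
                                     "baseline_beta": 0.90}))

pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=TraceGraph_ELBO())
for _ in range(100):
    svi.step()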
def test_model(model, guide, loss, x_data, y_data): pyro.clear_param_store() loss.loss(model, guide, x_data, y_data)
""""" # Run options LEARNING_RATE = 1.0e-3 USE_CUDA = False # Run only for a single iteration for testing NUM_EPOCHS = 1 if smoke_test else 100 TEST_FREQUENCY = 5 # Get data train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA) pyro.clear_param_store() # Initialize instance of the VAE class vae = VAE() # Setup Adam optimizer (an algorithm for first-order gradient-based optimization) optimizer = Adam({"lr": 1.0e-3}) # SVI: stochastic variational inference - a scalable algorithm for approximating posterior distributions. # Trace_ELBO: top-level interface for stochastic variational inference via optimization of the evidence lower bound. svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO()) #set_trace() train_elbo = [] test_elbo = [] # Training loop for epoch in range(NUM_EPOCHS): print("1") total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA) #print("2") train_elbo.append(-total_epoch_loss_train) print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train)) if epoch % TEST_FREQUENCY == 0: # report test diagnostics total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA) test_elbo.append(-total_epoch_loss_test) print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)) """ ""
def pytest_runtest_setup(item): # pylint: disable=missing-function-docstring pyro.clear_param_store() reset_state() if item.get_closest_marker("fix_rng") is not None: torch.manual_seed(0)
def main(args): # clear param store pyro.clear_param_store() ### SETUP train_loader, test_loader = get_data() # setup the VAE vae = VAE(use_cuda=args.cuda) # setup the optimizer adam_args = {"lr": args.learning_rate} optimizer = Adam(adam_args) # setup the inference algorithm elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() svi = SVI(vae.model, vae.guide, optimizer, loss=elbo) inputSize = 0 # setup visdom for visualization if args.visdom_flag: vis = visdom.Visdom() train_elbo = [] test_elbo = [] for epoch in range(args.num_epochs): # initialize loss accumulator epoch_loss = 0. # do a training epoch over each mini-batch x returned # by the data loader for step, batch in enumerate(train_loader): x, adj = 0, 0 # if on GPU put mini-batch into CUDA memory if args.cuda: x = batch['x'].cuda() adj = batch['edge_index'].cuda() else: x = batch['x'] adj = batch['edge_index'] print("x_shape", x.shape) print("adj_shape", adj.shape) inputSize = x.shape[0] * x.shape[1] epoch_loss += svi.step(x, adj) # report training diagnostics normalizer_train = len(train_loader.dataset) total_epoch_loss_train = epoch_loss / normalizer_train train_elbo.append(total_epoch_loss_train) print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train)) if True: # if epoch % args.test_frequency == 0: # initialize loss accumulator test_loss = 0. # compute the loss over the entire test set for step, batch in enumerate(test_loader): x, adj = 0, 0 # if on GPU put mini-batch into CUDA memory if args.cuda: x = batch['x'].cuda() adj = batch['edge_index'].cuda() else: x = batch['x'] adj = batch['edge_index'] # compute ELBO estimate and accumulate loss # print('before evaluating test loss') test_loss += svi.evaluate_loss(x, adj) # print('after evaluating test loss') # pick three random test images from the first mini-batch and # visualize how well we're reconstructing them # if i == 0: # if args.visdom_flag: # plot_vae_samples(vae, vis) # reco_indices = np.random.randint(0, x.shape[0], 3) # for index in reco_indices: # test_img = x[index, :] # reco_img = vae.reconstruct_img(test_img) # vis.image(test_img.reshape(28, 28).detach().cpu().numpy(), # opts={'caption': 'test image'}) # vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(), # opts={'caption': 'reconstructed image'}) if args.visdom_flag: plot_vae_samples(vae, vis) reco_indices = np.random.randint(0, x.shape[0], 3) for index in reco_indices: test_img = x[index, :] reco_img = vae.reconstruct_graph(test_img) vis.image(test_img.reshape(28, 28).detach().cpu().numpy(), opts={'caption': 'test image'}) vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(), opts={'caption': 'reconstructed image'}) # report test diagnostics normalizer_test = len(test_loader.dataset) total_epoch_loss_test = test_loss / normalizer_test test_elbo.append(total_epoch_loss_test) print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)) # if epoch == args.tsne_iter: # mnist_test_tsne(vae=vae, test_loader=test_loader) # plot_llk(np.array(train_elbo), np.array(test_elbo)) if args.save: torch.save( { 'epoch': epoch, 'model_state_dict': vae.state_dict(), 'optimzier_state_dict': optimizer.get_state(), 'train_loss': total_epoch_loss_train, 'test_loss': total_epoch_loss_test }, 'vae_' + args.name + str(args.time) + '.pt') return vae
def rejection_sample_feasible_tree(num_attempts=999): ''' Repeatedly samples trees from the grammar until one satisfies some hand-coded constraints. This will be simplified when constraint specification and sampling machinery is generalized. For now, this is hard-coded to work for the kitchen example. ''' for attempt_k in range(num_attempts): start = time.time() pyro.clear_param_store() scene_tree = ParseTree.generate_from_root_type(root_node_type=Kitchen) end = time.time() print("Generated tree in %f seconds." % (end - start)) # Enforce that there is exactly one cabinet num_cabinets = len([node for node in scene_tree.nodes if isinstance(node, Cabinet)]) if num_cabinets != 1: continue # Enforce that there are at least 5 objects on the tables tables = scene_tree.find_nodes_by_type(Table) table_children = sum([scene_tree.get_recursive_children_of_node(node) for node in tables], []) num_objects_on_tables = len([node for node in table_children if isinstance(node, KitchenObject)]) print("Num objs on table: ", num_objects_on_tables) if num_objects_on_tables < 5: continue # Enforce that there are at least 2 objects in cabinets #cabinets = scene_tree.find_nodes_by_type(Cabinet) #table_children = sum([scene_tree.get_recursive_children_of_node(node) for node in cabinets], []) #num_objects_in_cabinets = len([node for node in table_children if isinstance(node, KitchenObject)]) #print("Num objs in cabinets: ", num_objects_in_cabinets) #if num_objects_in_cabinets < 2: # continue # Do collision checking on the clearance geometry, and reject # scenes where the collision geometry is in collision. # (This could be done at subtree level, and eventually I'll do that -- # but for this scene it doesn't matter b/c clearance geometry is all # furniture level anyway. # TODO: What if I did rejection sampling for nonpenetration at the # container level? Is that legit as a sampling strategy?) builder_clearance, mbp_clearance, sg_clearance = \ compile_scene_tree_clearance_geometry_to_mbp_and_sg(scene_tree) mbp_clearance.Finalize() diagram_clearance = builder_clearance.Build() diagram_context = diagram_clearance.CreateDefaultContext() mbp_context = diagram_clearance.GetMutableSubsystemContext(mbp_clearance, diagram_context) constraint = build_clearance_nonpenetration_constraint( mbp_clearance, mbp_context, -0.01) constraint.Eval(mbp_clearance.GetPositions(mbp_context)) q0 = mbp_clearance.GetPositions(mbp_context) print("CONSTRAINT EVAL: %f <= %f <= %f" % ( constraint.lower_bound(), constraint.Eval(mbp_clearance.GetPositions(mbp_context)), constraint.upper_bound())) print(len(get_collisions(mbp_clearance, mbp_context)), " bodies in collision") # We can draw clearance geometry for debugging. #draw_clearance_geometry_meshcat(scene_tree, alpha=0.3) # If we failed the initial clearance check, resample. if not constraint.CheckSatisfied(q0): continue # Good solution! return scene_tree, True # Bad solution :( return scene_tree, False
def train(num_iterations, svi): pyro.clear_param_store() for j in tqdm(range(num_iterations)): loss = svi.step(data) losses.append(loss)
def train(args, DATA_PATH): # clear param store pyro.clear_param_store() #pyro.enable_validation(True) # train_loader, test_loader transform = {} transform["train"] = transforms.Compose([ transforms.Resize((400, 400)), transforms.ToTensor(), ]) transform["test"] = transforms.Compose( [transforms.Resize((400, 400)), transforms.ToTensor()]) train_loader, test_loader = setup_data_loaders( dataset=GameCharacterFullData, root_path=DATA_PATH, batch_size=32, transforms=transform) # setup the VAE vae = VAE(use_cuda=args.cuda, num_labels=17) # setup the exponential learning rate scheduler optimizer = torch.optim.Adam scheduler = pyro.optim.ExponentialLR({ 'optimizer': optimizer, 'optim_args': { 'lr': args.learning_rate }, 'gamma': 0.1 }) # setup the inference algorithm elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() svi = SVI(vae.model, vae.guide, scheduler, loss=elbo) # setup visdom for visualization if args.visdom_flag: vis = visdom.Visdom(port='8097') train_elbo = [] test_elbo = [] # training loop for epoch in range(args.num_epochs): # initialize loss accumulator epoch_loss = 0. # do a training epoch over each mini-batch x returned # by the data loader for x, y, actor, reactor, actor_type, reactor_type, action, reaction in train_loader: # if on GPU put mini-batch into CUDA memory if args.cuda: x = x.cuda() y = y.cuda() actor = actor.cuda() reactor = reactor.cuda() actor_type = actor_type.cuda() reactor_type = reactor_type.cuda() action = action.cuda() reaction = reaction.cuda() # do ELBO gradient and accumulate loss epoch_loss += svi.step(x, y, actor, reactor, actor_type, reactor_type, action, reaction) # report training diagnostics normalizer_train = len(train_loader.dataset) total_epoch_loss_train = epoch_loss / normalizer_train train_elbo.append(total_epoch_loss_train) print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train)) if epoch % args.test_frequency == 0: # initialize loss accumulator test_loss = 0. # compute the loss over the entire test set for i, (x, y, actor, reactor, actor_type, reactor_type, action, reaction) in enumerate(test_loader): # if on GPU put mini-batch into CUDA memory if args.cuda: x = x.cuda() y = y.cuda() actor = actor.cuda() reactor = reactor.cuda() actor_type = actor_type.cuda() reactor_type = reactor_type.cuda() action = action.cuda() reaction = reaction.cuda() # compute ELBO estimate and accumulate loss test_loss += svi.evaluate_loss(x, y, actor, reactor, actor_type, reactor_type, action, reaction) # pick three random test images from the first mini-batch and # visualize how well we're reconstructing them if i == 0: if args.visdom_flag: plot_vae_samples(vae, vis) reco_indices = np.random.randint(0, x.shape[0], 3) for index in reco_indices: test_img = x[index, :] reco_img = vae.reconstruct_img(test_img) vis.image(test_img.reshape( 400, 400).detach().cpu().numpy(), opts={'caption': 'test image'}) vis.image(reco_img.reshape( 400, 400).detach().cpu().numpy(), opts={'caption': 'reconstructed image'}) # report test diagnostics normalizer_test = len(test_loader.dataset) total_epoch_loss_test = test_loss / normalizer_test test_elbo.append(total_epoch_loss_test) print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test)) return vae, optimizer
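# Note on the training loop above: with pyro.optim.ExponentialLR the gamma decay only
# takes effect when the scheduler is explicitly stepped, and the loop never calls
# scheduler.step(). The sketch below shows where such a call would typically go; the
# per-epoch placement (and the simplified batch handling) is an assumption about the
# intended behaviour, not code from the source.
for epoch in range(args.num_epochs):
    epoch_loss = 0.
    for batch in train_loader:
        epoch_loss += svi.step(*batch)
    scheduler.step()  # apply the gamma=0.1 learning-rate decay once per epoch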
def setUp(self): pyro.clear_param_store() def loc1_prior(tensor, *args, **kwargs): flat_tensor = tensor.view(-1) m = torch.zeros(flat_tensor.size(0)) s = torch.ones(flat_tensor.size(0)) return Normal(m, s).sample().view(tensor.size()) def scale1_prior(tensor, *args, **kwargs): flat_tensor = tensor.view(-1) m = torch.zeros(flat_tensor.size(0)) s = torch.ones(flat_tensor.size(0)) return Normal(m, s).sample().view(tensor.size()).exp() def loc2_prior(tensor, *args, **kwargs): flat_tensor = tensor.view(-1) m = torch.zeros(flat_tensor.size(0)) return Bernoulli(m).sample().view(tensor.size()) def scale2_prior(tensor, *args, **kwargs): return scale1_prior(tensor) def bias_prior(tensor, *args, **kwargs): return loc2_prior(tensor) def weight_prior(tensor, *args, **kwargs): return scale1_prior(tensor) def stoch_fn(tensor, *args, **kwargs): loc = torch.zeros(tensor.size()) scale = torch.ones(tensor.size()) return pyro.sample("sample", Normal(loc, scale)) def guide(): loc1 = pyro.param("loc1", torch.randn(2, requires_grad=True)) scale1 = pyro.param("scale1", torch.ones(2, requires_grad=True)) pyro.sample("latent1", Normal(loc1, scale1)) loc2 = pyro.param("loc2", torch.randn(2, requires_grad=True)) scale2 = pyro.param("scale2", torch.ones(2, requires_grad=True)) latent2 = pyro.sample("latent2", Normal(loc2, scale2)) return latent2 def dup_param_guide(): a = pyro.param("loc") b = pyro.param("loc") assert a == b self.model = Model() self.guide = guide self.dup_param_guide = dup_param_guide self.prior = scale1_prior self.prior_dict = { "loc1": loc1_prior, "scale1": scale1_prior, "loc2": loc2_prior, "scale2": scale2_prior } self.partial_dict = {"loc1": loc1_prior, "scale1": scale1_prior} self.nn_prior = {"fc.bias": bias_prior, "fc.weight": weight_prior} self.fn = stoch_fn self.data = torch.randn(2, 2)
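# The prior dictionaries assembled in setUp (e.g. self.nn_prior mapping parameter
# names to prior callables) are the kind of thing consumed by pyro.random_module,
# which lifts a torch.nn module into a sampler over modules. A small hedged sketch of
# that usage follows; the nn.Linear shape, the site name "fc", and the inline priors
# are assumptions for illustration only.
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

def weight_prior(tensor, *args, **kwargs):
    return dist.Normal(torch.zeros_like(tensor), torch.ones_like(tensor)).sample()

def bias_prior(tensor, *args, **kwargs):
    return dist.Normal(torch.zeros_like(tensor), torch.ones_like(tensor)).sample()

lifted_fn = pyro.random_module("fc", nn.Linear(2, 2),
                               prior={"weight": weight_prior, "bias": bias_prior})
lifted_module = lifted_fn()  # each call samples fresh parameter values from the priors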
def svi(data, assets, iter, num_samples, seed, autoguide=None, optim=None, subsample_size=None): assert type(data) == dict assert type(assets) == Assets assert type(iter) == int assert type(num_samples) == int assert seed is None or type(seed) == int assert autoguide is None or callable(autoguide) N = next(data.values().__iter__()).shape[0] assert all(arr.shape[0] == N for arr in data.values()) assert (subsample_size is None or type(subsample_size) == int and 0 < subsample_size < N) # TODO: Fix that this interface doesn't work for # `AutoLaplaceApproximation`, which requires different functions # to be used for optimisation / collecting samples. autoguide = AutoMultivariateNormal if autoguide is None else autoguide optim = Adam({'lr': 1e-3}) if optim is None else optim guide = autoguide(assets.fn) svi = SVI(assets.fn, guide, optim, loss=Trace_ELBO()) pyro.clear_param_store() t0 = time.time() max_iter_str_width = len(str(iter)) max_out_len = 0 with seed_ctx_mgr(seed): for i in range(iter): if subsample_size is None: dfN = None subsample = None data_for_step = data else: dfN = N subsample = torch.randint(0, N, (subsample_size, )).long() data_for_step = { k: get_mini_batch(arr, subsample) for k, arr in data.items() } loss = svi.step(dfN=dfN, subsample=subsample, **data_for_step) t1 = time.time() if t1 - t0 > 0.5 or (i + 1) == iter: iter_str = str(i + 1).rjust(max_iter_str_width) out = 'iter: {} | loss: {:.3f}'.format(iter_str, loss) max_out_len = max(max_out_len, len(out)) # Sending the ANSI code to clear the line doesn't seem to # work in notebooks, so instead we pad the output with # enough spaces to ensure we overwrite all previous input. print('\r{}'.format(out.ljust(max_out_len)), end='', file=stderr) t0 = t1 print(file=stderr) # We run the guide to generate traces from the (approx.) # posterior. We also run the model against those traces in order # to compute transformed parameters, such as `b`, etc. def get_model_trace(): guide_tr = poutine.trace(guide).get_trace() model_tr = poutine.trace(poutine.replay( assets.fn, trace=guide_tr)).get_trace(mode='prior_only', **data) return model_tr # Represent the posterior as a bunch of samples, ignoring the # possibility that we might plausibly be able to figure out e.g. # posterior marginals from the variational parameters. samples = [get_model_trace() for _ in range(num_samples)] # Unlike the NUTS case, we don't eagerly compute `mu` (for the # data set used for inference) when building `Samples#raw_samples`. # (This is because it's possible that N is very large since we # support subsampling.) Therefore `loc` always computes `mu` from # the data and the samples here. def loc(d): return location(assets.fn, samples, d) return Samples(samples, partial(get_param, samples), loc)
def test(): parser = argparse.ArgumentParser(description='Train VAE.') parser.add_argument('-c', '--config', help='Config file.') args = parser.parse_args() print(args) c = json.load(open(args.config)) print(c) # clear param store pyro.clear_param_store() # batch_size = 64 # root_dir = r'D:\projects\trading\mlbootcamp\tickers' # series_length = 60 lookback = 50 # 160 input_dim = 1 test_start_date = datetime.strptime( c['test_start_date'], '%Y/%m/%d') if c['test_start_date'] else None test_end_date = datetime.strptime( c['test_end_date'], '%Y/%m/%d') if c['test_end_date'] else None min_sequence_length_test = 2 * (c['series_length'] + lookback) max_n_files = None out_path = Path(c['out_dir']) out_path.mkdir(exist_ok=True) # load_path = 'out_saved/checkpoint_0035.pt' dataset_test = create_ticker_dataset( c['in_dir'], c['series_length'], lookback, min_sequence_length_test, start_date=test_start_date, end_date=test_end_date, fixed_start_date=True, normalised_returns=c['normalised_returns'], max_n_files=max_n_files) test_loader = DataLoader(dataset_test, batch_size=c['batch_size'], shuffle=False, num_workers=0, drop_last=True) # N_train_data = len(dataset_train) N_test_data = len(dataset_test) # N_mini_batches = N_train_data // c['batch_size'] # N_train_time_slices = c['batch_size'] * N_mini_batches print(f'N_test_data: {N_test_data}') # setup the VAE vae = VAE(c['series_length'], z_dim=c['z_dim'], use_cuda=c['cuda']) # setup the optimizer # adam_args = {"lr": args.learning_rate} # optimizer = Adam(adam_args) # setup the inference algorithm # elbo = JitTrace_ELBO() if args.jit else Trace_ELBO() # svi = SVI(vae.model, vae.guide, optimizer, loss=elbo) if c['checkpoint_load']: checkpoint = torch.load(c['checkpoint_load']) vae.load_state_dict(checkpoint['model_state_dict']) # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) if 1: find_similar(vae, test_loader, c['cuda']) # Visualise first batch. batch = next(iter(test_loader)) x = batch['series'] if c['cuda']: x = x.cuda() x = x.float() x_reconst = vae.reconstruct_img(x) x = x.cpu().numpy() x_reconst = x_reconst.cpu().detach().numpy() n = min(5, x.shape[0]) fig, axes = plt.subplots(n, 1, squeeze=False) for s in range(n): ax = axes[s, 0] ax.plot(x[s]) ax.plot(x_reconst[s]) fig.savefig(out_path / f'test_batch.png')
return site_stats # Prepare training data df = rugged_data[["cont_africa", "rugged", "rgdppc_2000"]] df = df[np.isfinite(df.rgdppc_2000)] df["rgdppc_2000"] = np.log(df["rgdppc_2000"]) train = torch.tensor(df.values, dtype=torch.float) svi = SVI(model, guide, optim.Adam({"lr": .005}), loss=Trace_ELBO(), num_samples=1000) is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2] pyro.clear_param_store() num_iters = 8000 if not smoke_test else 2 for i in range(num_iters): elbo = svi.step(is_cont_africa, ruggedness, log_gdp) if i % 500 == 0: logging.info("Elbo loss: {}".format(elbo)) posterior = svi.run(is_cont_africa, ruggedness, log_gdp) sites = ["a", "bA", "bR", "bAR", "sigma"] for site, values in summary(posterior, sites).items(): print("Site: {}".format(site)) print(values, "\n")
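# The snippet above ends a summary(...) helper whose body is elided (only
# "return site_stats" survives) and then calls summary(posterior, sites). The sketch
# below is a hypothetical reconstruction of such a helper, not the original code: it
# stacks each site's values across the traces returned by svi.run and reports simple
# statistics with pandas.
import pandas as pd
import torch

def summary(posterior, sites):
    site_stats = {}
    for site in sites:
        values = torch.stack([tr.nodes[site]["value"] for tr in posterior.exec_traces])
        flat = values.detach().cpu().numpy().reshape(len(values), -1)
        site_stats[site] = pd.DataFrame(flat).describe(
            percentiles=[0.05, 0.25, 0.75, 0.95]).transpose()
    return site_stats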
def main(args): if args.cuda: torch.set_default_tensor_type("torch.cuda.FloatTensor") logging.info("Loading data") data = poly.load_data(poly.JSB_CHORALES) logging.info("-" * 40) model = models[args.model] logging.info("Training {} on {} sequences".format( model.__name__, len(data["train"]["sequences"]))) sequences = data["train"]["sequences"] lengths = data["train"]["sequence_lengths"] # find all the notes that are present at least once in the training set present_notes = (sequences == 1).sum(0).sum(0) > 0 # remove notes that are never played (we remove 37/88 notes) sequences = sequences[..., present_notes] if args.truncate: lengths = lengths.clamp(max=args.truncate) sequences = sequences[:, :args.truncate] num_observations = float(lengths.sum()) pyro.set_rng_seed(args.seed) pyro.clear_param_store() # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing # out the hidden state x. This is accomplished via an automatic guide that # learns point estimates of all of our conditional probability tables, # named probs_*. guide = AutoDelta( poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))) # To help debug our tensor shapes, let's print the shape of each site's # distribution, value, and log_prob tensor. Note this information is # automatically printed on most errors inside SVI. if args.print_shapes: first_available_dim = -2 if model is model_0 else -3 guide_trace = poutine.trace(guide).get_trace( sequences, lengths, args=args, batch_size=args.batch_size) model_trace = poutine.trace( poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(sequences, lengths, args=args, batch_size=args.batch_size) logging.info(model_trace.format_shapes()) # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting. # All of our models have two plates: "data" and "tones". optim = Adam({"lr": args.learning_rate}) if args.tmc: if args.jit: raise NotImplementedError( "jit support not yet added for TraceTMC_ELBO") elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2) tmc_model = poutine.infer_config( model, lambda msg: { "num_samples": args.tmc_num_samples, "expand": False } if msg["infer"].get("enumerate", None) == "parallel" else {}, ) # noqa: E501 svi = SVI(tmc_model, guide, optim, elbo) else: Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO elbo = Elbo( max_plate_nesting=1 if model is model_0 else 2, strict_enumeration_warning=(model is not model_7), jit_options={"time_compilation": args.time_compilation}, ) svi = SVI(model, guide, optim, elbo) # We'll train on small minibatches. logging.info("Step\tLoss") for step in range(args.num_steps): loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size) logging.info("{: >5d}\t{}".format(step, loss / num_observations)) if args.jit and args.time_compilation: logging.debug("time to compile: {} s.".format( elbo._differentiable_loss.compile_time)) # We evaluate on the entire training dataset, # excluding the prior term so our results are comparable across models. train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False) logging.info("training loss = {}".format(train_loss / num_observations)) # Finally we evaluate on the test dataset. 
logging.info("-" * 40) logging.info("Evaluating on {} test sequences".format( len(data["test"]["sequences"]))) sequences = data["test"]["sequences"][..., present_notes] lengths = data["test"]["sequence_lengths"] if args.truncate: lengths = lengths.clamp(max=args.truncate) num_observations = float(lengths.sum()) # note that since we removed unseen notes above (to make the problem a bit easier and for # numerical stability) this test loss may not be directly comparable to numbers # reported on this dataset elsewhere. test_loss = elbo.loss(model, guide, sequences, lengths, args=args, include_prior=False) logging.info("test loss = {}".format(test_loss / num_observations)) # We expect models with higher capacity to perform better, # but eventually overfit to the training set. capacity = sum( value.reshape(-1).size(0) for value in pyro.get_param_store().values()) logging.info("{} capacity = {} parameters".format(model.__name__, capacity))
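# main() above relies on enumerating the discrete hidden states with TraceEnum_ELBO
# and an explicit max_plate_nesting. The toy sketch below shows that mechanism in
# isolation on a two-component mixture; the model, guide, and data are illustrative
# assumptions and are not one of the HMM variants used above.
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

data = torch.tensor([0., 1., 10., 11.])

@config_enumerate
def model(data):
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(2)))
    locs = pyro.sample("locs", dist.Normal(torch.zeros(2), 10.).to_event(1))
    with pyro.plate("data", len(data)):
        # discrete assignment; enumerated in parallel by TraceEnum_ELBO
        z = pyro.sample("z", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[z], 1.), obs=data)

def guide(data):
    alpha = pyro.param("alpha", torch.ones(2), constraint=constraints.positive)
    pyro.sample("weights", dist.Dirichlet(alpha))
    loc_q = pyro.param("loc_q", torch.randn(2))
    pyro.sample("locs", dist.Normal(loc_q, 1.).to_event(1))

pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.05}), TraceEnum_ELBO(max_plate_nesting=1))
for _ in range(100):
    svi.step(data)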
def run_inference(data, gen_model, ode_model, method, iterations = 10000, num_particles = 1, \ num_samples = 1000, warmup_steps = 500, init_scale = 0.1, \ seed = 12, lr = 0.5, return_sites = ("_RETURN",)): torch_data = torch.tensor(data, dtype=torch.float) if isinstance(ode_model,ForwardSensManualJacobians) or \ isinstance(ode_model,ForwardSensTorchJacobians): ode_op = ForwardSensOp elif isinstance(ode_model,AdjointSensManualJacobians) or \ isinstance(ode_model,AdjointSensTorchJacobians): ode_op = AdjointSensOp else: raise ValueError( 'Unknown sensitivity solver: Use "Forward" or "Adjoint"') model = gen_model(ode_op, ode_model) pyro.set_rng_seed(seed) pyro.clear_param_store() if method == 'VI': guide = AutoMultivariateNormal(model, init_scale=init_scale) optim = AdagradRMSProp({"eta": lr}) if num_particles == 1: svi = SVI(model, guide, optim, loss=Trace_ELBO()) else: svi = SVI(model, guide, optim, loss=Trace_ELBO(num_particles=num_particles, vectorize_particles=True)) loss_trace = [] t0 = timer.time() for j in range(iterations): loss = svi.step(torch_data) loss_trace.append(loss) if j % 500 == 0: print("[iteration %04d] loss: %.4f" % (j + 1, np.mean(loss_trace[max(0, j - 1000):j + 1]))) t1 = timer.time() print('VI time: ', t1 - t0) predictive = Predictive( model, guide=guide, num_samples=num_samples, return_sites=return_sites) #"ode_params", "scale", vb_samples = predictive(torch_data) return vb_samples elif method == 'NUTS': nuts_kernel = NUTS(model, adapt_step_size=True, init_strategy=init_to_median) mcmc = MCMC(nuts_kernel, num_samples=iterations, warmup_steps=warmup_steps, num_chains=2) t0 = timer.time() mcmc.run(torch_data) t1 = timer.time() print('NUTS time: ', t1 - t0) hmc_samples = { k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items() } return hmc_samples else: raise ValueError('Unknown method: Use "NUTS" or "VI"')
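# Hedged usage sketch for run_inference above: both branches return a dictionary of
# posterior draws keyed by site name, so downstream code can summarise, e.g., the
# "_RETURN" site. The data, gen_model and ode_model arguments here stand in for the
# project-specific objects expected by the function.
vb_samples = run_inference(data, gen_model, ode_model, method='VI', iterations=2000)
print(vb_samples["_RETURN"].mean(dim=0))  # VI branch: torch tensors, shape (num_samples, ...)

hmc_samples = run_inference(data, gen_model, ode_model, method='NUTS',
                            iterations=1000, warmup_steps=500)
print({k: v.shape for k, v in hmc_samples.items()})  # NUTS branch: numpy arrays per latent site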