Example #1
    def test_elbo_nonreparameterized(self):
        pyro.clear_param_store()

        def model():
            p_latent = pyro.sample("p_latent", dist.beta, self.alpha0, self.beta0)
            pyro.map_data("aaa",
                          self.data, lambda i, x: pyro.observe(
                              "obs_{}".format(i), dist.bernoulli, x, p_latent),
                          batch_size=self.batch_size)
            return p_latent

        def guide():
            alpha_q_log = pyro.param("alpha_q_log",
                                     Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
            beta_q_log = pyro.param("beta_q_log",
                                    Variable(self.log_beta_n.data - 0.143, requires_grad=True))
            alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
            pyro.sample("p_latent", dist.beta, alpha_q, beta_q)
            pyro.map_data("aaa", self.data, lambda i, x: None, batch_size=self.batch_size)

        adam = optim.Adam({"lr": .001, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

        for k in range(10001):
            svi.step()

            alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n)
            beta_error = param_abs_error("beta_q_log", self.log_beta_n)

        self.assertEqual(0.0, alpha_error, prec=0.08)
        self.assertEqual(0.0, beta_error, prec=0.08)
Example #2
    def do_elbo_test(self, reparameterized, n_steps):
        pyro.clear_param_store()
        pt_guide = LogNormalNormalGuide(self.log_mu_n.data + 0.17,
                                        self.log_tau_n.data - 0.143)

        def model():
            mu_latent = pyro.sample("mu_latent", dist.normal,
                                    self.mu0, torch.pow(self.tau0, -0.5))
            sigma = torch.pow(self.tau, -0.5)
            pyro.observe("obs0", dist.lognormal, self.data[0], mu_latent, sigma)
            pyro.observe("obs1", dist.lognormal, self.data[1], mu_latent, sigma)
            return mu_latent

        def guide():
            pyro.module("mymodule", pt_guide)
            mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
            sigma = torch.pow(tau_q, -0.5)
            pyro.sample("mu_latent", dist.Normal(mu_q, sigma, reparameterized=reparameterized))

        adam = optim.Adam({"lr": .0005, "betas": (0.96, 0.999)})
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

        for k in range(n_steps):
            svi.step()

        mu_error = param_abs_error("mymodule$$$mu_q_log", self.log_mu_n)
        tau_error = param_abs_error("mymodule$$$tau_q_log", self.log_tau_n)
        self.assertEqual(0.0, mu_error, prec=0.07)
        self.assertEqual(0.0, tau_error, prec=0.07)
Example #3
    def test_elbo_nonreparameterized(self):
        pyro.clear_param_store()

        def model():
            lambda_latent = pyro.sample("lambda_latent", dist.gamma, self.alpha0, self.beta0)
            pyro.observe("obs0", dist.exponential, self.data[0], lambda_latent)
            pyro.observe("obs1", dist.exponential, self.data[1], lambda_latent)
            return lambda_latent

        def guide():
            alpha_q_log = pyro.param(
                "alpha_q_log",
                Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
            beta_q_log = pyro.param(
                "beta_q_log",
                Variable(self.log_beta_n.data - 0.143, requires_grad=True))
            alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
            pyro.sample("lambda_latent", dist.gamma, alpha_q, beta_q)

        adam = optim.Adam({"lr": .0003, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

        for k in range(10001):
            svi.step()

        alpha_error = param_abs_error("alpha_q_log", self.log_alpha_n)
        beta_error = param_abs_error("beta_q_log", self.log_beta_n)
        self.assertEqual(0.0, alpha_error, prec=0.08)
        self.assertEqual(0.0, beta_error, prec=0.08)
Example #4
def assert_ok(model, guide, elbo):
    """
    Assert that inference works without warnings or errors.
    """
    pyro.clear_param_store()
    inference = SVI(model, guide, Adam({"lr": 1e-6}), elbo)
    inference.step()
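
A minimal sketch of how a helper like assert_ok above might be exercised; the toy model, guide, and imports here are illustrative assumptions rather than part of the original test suite.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO

def toy_model():
    pyro.sample("z", dist.Normal(torch.tensor(0.), torch.tensor(1.)))

def toy_guide():
    loc = pyro.param("loc", torch.tensor(0.))
    pyro.sample("z", dist.Normal(loc, torch.tensor(1.)))

# assert_ok clears the param store, wraps model/guide in SVI, and takes one step.
assert_ok(toy_model, toy_guide, Trace_ELBO())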
Example #5
def test_module_nn(nn_module):
    pyro.clear_param_store()
    nn_module = nn_module()
    assert pyro.get_param_store()._params == {}
    pyro.module("module", nn_module)
    for name in pyro.get_param_store().get_all_param_names():
        assert pyro.params.user_param_name(name) in nn_module.state_dict().keys()
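
For reference, a small sketch of the naming behavior the test above relies on: pyro.module registers each parameter of an nn.Module in the global param store under a prefixed key (the "$$$"-separated form also visible in Example #2), and pyro.params.user_param_name recovers the module-local name. The Linear module here is only an illustrative choice.

import torch.nn as nn
import pyro

pyro.clear_param_store()
net = nn.Linear(3, 1)
pyro.module("net", net)

for full_name in pyro.get_param_store().get_all_param_names():
    # e.g. "net$$$weight" -> "weight", which matches a key of net.state_dict()
    print(full_name, "->", pyro.params.user_param_name(full_name))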
Example #6
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        p = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", Variable(torch.Tensor([0.5]), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(p))

    print("Computing gradients using surrogate loss")
    Elbo = TraceGraph_ELBO if trace_graph else Trace_ELBO
    elbo = Elbo(enum_discrete=enum_discrete,
                num_particles=(1 if enum_discrete else num_particles))
    with xfail_if_not_implemented():
        elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {name: pyro.param(name).grad.clone() for name in params}

    print("Computing gradients using finite difference")
    elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: elbo.loss(model, guide))

    for name in params:
        print("{} {}{}{}".format(name, "-" * 30, actual_grads[name].data,
                                 expected_grads[name].data))
    assert_equal(actual_grads, expected_grads, prec=0.1)
Example #7
def test_iter_discrete_traces_vector(graph_type):
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    assert len(traces) == 2 * ps.size(-1)

    for scale, trace in traces:
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        expected_scale = torch.exp(dist.Bernoulli(p).log_pdf(x) +
                                   dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
Example #8
 def test_random_module(self):
     pyro.clear_param_store()
     lifted_tr = poutine.trace(pyro.random_module("name", self.model, prior=self.prior)).get_trace()
     # random_module should lift every module parameter to a sample site,
     # so the lifted trace should not contain any "param"-typed nodes.
     for name in lifted_tr.nodes.keys():
         if lifted_tr.nodes[name]["type"] == "param":
             assert lifted_tr.nodes[name]["type"] == "sample"
             assert not lifted_tr.nodes[name]["is_observed"]
Example #9
def test_gmm_batch_iter_discrete_traces(model, data_size, graph_type):
    pyro.clear_param_store()
    data = torch.arange(0, data_size)
    model = config_enumerate(model)
    traces = list(iter_discrete_traces(graph_type, model, data=data))
    # This vectorized version is independent of data_size:
    assert len(traces) == 2
Example #10
def test_dynamic_lr(scheduler, num_steps):
    pyro.clear_param_store()

    def model():
        sample = pyro.sample('latent', Normal(torch.tensor(0.), torch.tensor(0.3)))
        return pyro.sample('obs', Normal(sample, torch.tensor(0.2)), obs=torch.tensor(0.1))

    def guide():
        loc = pyro.param('loc', torch.tensor(0.))
        scale = pyro.param('scale', torch.tensor(0.5))
        pyro.sample('latent', Normal(loc, scale))

    svi = SVI(model, guide, scheduler, loss=TraceGraph_ELBO())
    for epoch in range(2):
        scheduler.set_epoch(epoch)
        for _ in range(num_steps):
            svi.step()
        if epoch == 1:
            loc = pyro.param('loc')
            scale = pyro.param('scale')
            opt = scheduler.optim_objs[loc].optimizer
            assert opt.state_dict()['param_groups'][0]['lr'] == 0.02
            assert opt.state_dict()['param_groups'][0]['initial_lr'] == 0.01
            opt = scheduler.optim_objs[scale].optimizer
            assert opt.state_dict()['param_groups'][0]['lr'] == 0.02
            assert opt.state_dict()['param_groups'][0]['initial_lr'] == 0.01
Example #11
File: test_jit.py  Project: lewisKit/pyro
def test_dirichlet_bernoulli(Elbo, vectorized):
    pyro.clear_param_store()
    data = torch.tensor([1.0] * 6 + [0.0] * 4)

    def model1(data):
        concentration0 = torch.tensor([10.0, 10.0])
        f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
        for i in pyro.irange("irange", len(data)):
            pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

    def model2(data):
        concentration0 = torch.tensor([10.0, 10.0])
        f = pyro.sample("latent_fairness", dist.Dirichlet(concentration0))[1]
        pyro.sample("obs", dist.Bernoulli(f).expand_by(data.shape).independent(1),
                    obs=data)

    model = model2 if vectorized else model1

    def guide(data):
        concentration_q = pyro.param("concentration_q", torch.tensor([15.0, 15.0]),
                                     constraint=constraints.positive)
        pyro.sample("latent_fairness", dist.Dirichlet(concentration_q))

    elbo = Elbo(num_particles=7, strict_enumeration_warning=False)
    optim = Adam({"lr": 0.0005, "betas": (0.90, 0.999)})
    svi = SVI(model, guide, optim, elbo)
    for step in range(40):
        svi.step(data)
Example #12
def test_elbo_bern(quantity, enumerate1):
    pyro.clear_param_store()
    num_particles = 1 if enumerate1 else 10000
    prec = 0.001 if enumerate1 else 0.1
    q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
    kl = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(0.25))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(0.25).expand_by([num_particles]))

    @config_enumerate(default=enumerate1)
    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))

    if quantity == "loss":
        actual = elbo.loss(model, guide) / num_particles
        expected = kl.item()
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected),
            "\n  actual = {}".format(actual),
        ]))
    else:
        elbo.loss_and_grads(model, guide)
        actual = q.grad / num_particles
        expected = grad(kl, [q])[0]
        assert_equal(actual, expected, prec=prec, msg="".join([
            "\nexpected = {}".format(expected.detach().cpu().numpy()),
            "\n  actual = {}".format(actual.detach().cpu().numpy()),
        ]))
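
For reference, a standalone sketch of the closed-form KL the test above compares against, using plain torch.distributions and independent of the test harness.

import torch
from torch.distributions import Bernoulli, kl_divergence

q = torch.tensor(0.5)
p = torch.tensor(0.25)
# KL(Bernoulli(q) || Bernoulli(p)) = q*log(q/p) + (1-q)*log((1-q)/(1-p))
manual = q * torch.log(q / p) + (1 - q) * torch.log((1 - q) / (1 - p))
assert torch.allclose(kl_divergence(Bernoulli(q), Bernoulli(p)), manual)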
Example #13
    def do_elbo_test(self, reparameterized, n_steps):
        pyro.clear_param_store()

        def model():
            mu_latent = pyro.sample("mu_latent", dist.normal,
                                    self.mu0, torch.pow(self.lam0, -0.5))
            pyro.map_data("aaa", self.data, lambda i,
                          x: pyro.observe(
                              "obs_%d" % i, dist.normal,
                              x, mu_latent, torch.pow(self.lam, -0.5)),
                          batch_size=self.batch_size)
            return mu_latent

        def guide():
            mu_q = pyro.param("mu_q", Variable(self.analytic_mu_n.data + 0.134 * torch.ones(2),
                                               requires_grad=True))
            log_sig_q = pyro.param("log_sig_q", Variable(
                                   self.analytic_log_sig_n.data - 0.14 * torch.ones(2),
                                   requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            pyro.sample("mu_latent", dist.Normal(mu_q, sig_q, reparameterized=reparameterized))
            pyro.map_data("aaa", self.data, lambda i, x: None,
                          batch_size=self.batch_size)

        adam = optim.Adam({"lr": .001})
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

        for k in range(n_steps):
            svi.step()

            mu_error = param_mse("mu_q", self.analytic_mu_n)
            log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)

        self.assertEqual(0.0, mu_error, prec=0.05)
        self.assertEqual(0.0, log_sig_error, prec=0.05)
Example #14
def test_gmm_iter_discrete_traces(data_size, graph_type, model):
    pyro.clear_param_store()
    data = torch.arange(0, data_size)
    model = config_enumerate(model)
    traces = list(iter_discrete_traces(graph_type, model, data=data, verbose=True))
    # This non-vectorized version is exponential in data_size:
    assert len(traces) == 2**data_size
Example #15
def main(args):
    pyro.clear_param_store()
    data = build_linear_dataset(N, p)
    if args.cuda:
        # make tensors and modules CUDA
        data = data.cuda()
        softplus.cuda()
        regression_model.cuda()
    for j in range(args.num_epochs):
        if args.batch_size == N:
            # use the entire data set
            epoch_loss = svi.step(data)
        else:
            # mini batch
            epoch_loss = 0.0
            perm = torch.randperm(N) if not args.cuda else torch.randperm(N).cuda()
            # shuffle data
            data = data[perm]
            # get indices of each batch
            all_batches = get_batch_indices(N, args.batch_size)
            for ix, batch_start in enumerate(all_batches[:-1]):
                batch_end = all_batches[ix + 1]
                batch_data = data[batch_start: batch_end]
                epoch_loss += svi.step(batch_data)
        if j % 100 == 0:
            print("epoch avg loss {}".format(epoch_loss/float(N)))
Example #16
def test_elbo_hmm_in_guide(enumerate1, num_steps):
    pyro.clear_param_store()
    data = torch.ones(num_steps)
    init_probs = torch.tensor([0.5, 0.5])

    def model(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        emission_probs = pyro.param("emission_probs",
                                    torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                    constraint=constraints.simplex)

        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))
            pyro.sample("y_{}".format(i), dist.Categorical(emission_probs[x]), obs=y)

    @config_enumerate(default=enumerate1)
    def guide(data):
        transition_probs = pyro.param("transition_probs",
                                      torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                      constraint=constraints.simplex)
        x = None
        for i, y in enumerate(data):
            probs = init_probs if x is None else transition_probs[x]
            x = pyro.sample("x_{}".format(i), dist.Categorical(probs))

    elbo = TraceEnum_ELBO(max_iarange_nesting=0)
    elbo.loss_and_grads(model, guide, data)

    # These golden values simply test agreement between parallel and sequential.
    expected_grads = {
        2: {
            "transition_probs": [[0.1029949, -0.1029949], [0.1029949, -0.1029949]],
            "emission_probs": [[0.75, -0.75], [0.25, -0.25]],
        },
        3: {
            "transition_probs": [[0.25748726, -0.25748726], [0.25748726, -0.25748726]],
            "emission_probs": [[1.125, -1.125], [0.375, -0.375]],
        },
        10: {
            "transition_probs": [[1.64832076, -1.64832076], [1.64832076, -1.64832076]],
            "emission_probs": [[3.75, -3.75], [1.25, -1.25]],
        },
        20: {
            "transition_probs": [[3.70781687, -3.70781687], [3.70781687, -3.70781687]],
            "emission_probs": [[7.5, -7.5], [2.5, -2.5]],
        },
    }

    for name, value in pyro.get_param_store().named_parameters():
        actual = value.grad
        expected = torch.tensor(expected_grads[num_steps][name])
        assert_equal(actual, expected, msg=''.join([
            '\nexpected {}.grad = {}'.format(name, expected.cpu().numpy()),
            '\n  actual {}.grad = {}'.format(name, actual.detach().cpu().numpy()),
        ]))
Example #17
    def test_duplicate_obs_name(self):
        pyro.clear_param_store()

        adam = optim.Adam({"lr": .001})
        svi = SVI(self.duplicate_obs, self.guide, adam, loss="ELBO", trace_graph=False)

        with pytest.raises(RuntimeError):
            svi.step()
Example #18
def assert_error(model, guide, elbo):
    """
    Assert that inference fails with an error.
    """
    pyro.clear_param_store()
    inference = SVI(model,  guide, Adam({"lr": 1e-6}), elbo)
    with pytest.raises((NotImplementedError, UserWarning, KeyError, ValueError, RuntimeError)):
        inference.step()
Example #19
    def test_extra_samples(self):
        pyro.clear_param_store()

        adam = optim.Adam({"lr": .001})
        svi = SVI(self.model, self.guide, adam, loss="ELBO", trace_graph=False)

        with pytest.warns(Warning):
            svi.step()
Example #20
def test_svi_step_smoke(model, guide, enum_discrete, trace_graph):
    pyro.clear_param_store()
    data = Variable(torch.Tensor([0, 1, 9]))

    optimizer = pyro.optim.Adam({"lr": .001})
    inference = SVI(model, guide, optimizer, loss="ELBO",
                    trace_graph=trace_graph, enum_discrete=enum_discrete)
    with xfail_if_not_implemented():
        inference.step(data)
Example #21
def test_random_module(nn_module):
    pyro.clear_param_store()
    nn_module = nn_module()
    p = torch.ones(2, 2)
    prior = dist.Bernoulli(p)
    lifted_mod = pyro.random_module("module", nn_module, prior)
    nn_module = lifted_mod()
    for name, parameter in nn_module.named_parameters():
        assert torch.equal(torch.ones(2, 2), parameter.data)
Example #22
    def do_elbo_test(self, reparameterized, n_steps, lr, prec, beta1,
                     difficulty=1.0, model_permutation=False):
        n_repa_nodes = torch.sum(self.which_nodes_reparam) if not reparameterized \
            else len(self.q_topo_sort)
        logger.info((" - - - DO GAUSSIAN %d-LAYERED PYRAMID ELBO TEST " +
                     "(with a total of %d RVs) [reparameterized=%s; %d/%d; perm=%s] - - -") %
                    (self.N, (2 ** self.N) - 1, reparameterized, n_repa_nodes,
                     len(self.q_topo_sort), model_permutation))
        pyro.clear_param_store()

        # check graph structure is as expected but only for N=2
        if self.N == 2:
            guide_trace = pyro.poutine.trace(self.guide,
                                             graph_type="dense").get_trace(reparameterized=reparameterized,
                                                                           model_permutation=model_permutation,
                                                                           difficulty=difficulty)
            expected_nodes = set(['log_sig_1R', 'kappa_1_1L', '_INPUT', 'constant_term_loc_latent_1R', '_RETURN',
                                  'loc_latent_1R', 'loc_latent_1', 'constant_term_loc_latent_1', 'loc_latent_1L',
                                  'constant_term_loc_latent_1L', 'log_sig_1L', 'kappa_1_1R', 'kappa_1R_1L',
                                  'log_sig_1'])
            expected_edges = set([('loc_latent_1R', 'loc_latent_1'), ('loc_latent_1L', 'loc_latent_1R'),
                                  ('loc_latent_1L', 'loc_latent_1')])
            assert expected_nodes == set(guide_trace.nodes)
            assert expected_edges == set(guide_trace.edges)

        adam = optim.Adam({"lr": lr, "betas": (beta1, 0.999)})
        svi = SVI(self.model, self.guide, adam, loss=TraceGraph_ELBO())

        for step in range(n_steps):
            t0 = time.time()
            svi.step(reparameterized=reparameterized, model_permutation=model_permutation, difficulty=difficulty)

            if step % 5000 == 0 or step == n_steps - 1:
                log_sig_errors = []
                for node in self.target_lambdas:
                    target_log_sig = -0.5 * torch.log(self.target_lambdas[node])
                    log_sig_error = param_mse('log_sig_' + node, target_log_sig)
                    log_sig_errors.append(log_sig_error)
                max_log_sig_error = np.max(log_sig_errors)
                min_log_sig_error = np.min(log_sig_errors)
                mean_log_sig_error = np.mean(log_sig_errors)
                leftmost_node = self.q_topo_sort[0]
                leftmost_constant_error = param_mse('constant_term_' + leftmost_node,
                                                    self.target_leftmost_constant)
                almost_leftmost_constant_error = param_mse('constant_term_' + leftmost_node[:-1] + 'R',
                                                           self.target_almost_leftmost_constant)

                logger.debug("[mean function constant errors (partial)]   %.4f  %.4f" %
                             (leftmost_constant_error, almost_leftmost_constant_error))
                logger.debug("[min/mean/max log(scale) errors]   %.4f  %.4f   %.4f" %
                             (min_log_sig_error, mean_log_sig_error, max_log_sig_error))
                logger.debug("[step time = %.3f;  N = %d;  step = %d]\n" % (time.time() - t0, self.N, step))

        assert_equal(0.0, max_log_sig_error, prec=prec)
        assert_equal(0.0, leftmost_constant_error, prec=prec)
        assert_equal(0.0, almost_leftmost_constant_error, prec=prec)
Example #23
File: conftest.py  Project: lewisKit/pyro
def pytest_runtest_setup(item):
    pyro.clear_param_store()
    if item.get_marker("disable_validation"):
        pyro.enable_validation(False)
    else:
        pyro.enable_validation(True)
    test_initialize_marker = item.get_marker("init")
    if test_initialize_marker:
        rng_seed = test_initialize_marker.kwargs["rng_seed"]
        pyro.set_rng_seed(rng_seed)
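
The hook above assumes the "disable_validation" and "init" markers are known to pytest; a minimal sketch (marker names taken from the code above, everything else assumed) of declaring and using them:

# conftest.py (sketch)
def pytest_configure(config):
    config.addinivalue_line("markers",
                            "disable_validation: run the test with pyro validation disabled")
    config.addinivalue_line("markers",
                            "init(rng_seed): seed pyro's RNG before running the test")

# in a test module
import pytest

@pytest.mark.init(rng_seed=123)
def test_is_deterministic():
    ...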
Example #24
def test_svi_step_smoke(model, guide, enumerate1):
    pyro.clear_param_store()
    data = torch.tensor([0.0, 1.0, 9.0])

    guide = config_enumerate(guide, default=enumerate1)
    optimizer = pyro.optim.Adam({"lr": .001})
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    inference = SVI(model, guide, optimizer, loss=elbo)
    inference.step(data)
Example #25
def test_non_mean_field_bern_normal_elbo_gradient(enumerate1, pi1, pi2, pi3, include_z=True):
    pyro.clear_param_store()
    num_particles = 10000

    def model():
        with pyro.iarange("particles", num_particles):
            q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
            y = pyro.sample("y", dist.Bernoulli(q3).expand_by([num_particles]))
            if include_z:
                pyro.sample("z", dist.Normal(0.55 * y + q3, 1.0))

    def guide():
        q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
        q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]), infer={"enumerate": enumerate1})
            if include_z:
                pyro.sample("z", dist.Normal(q2 * y + 0.10, 1.0))

    logger.info("Computing gradients using surrogate loss")
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    elbo.loss_and_grads(model, guide)
    actual_grad_q1 = pyro.param('q1').grad / num_particles
    if include_z:
        actual_grad_q2 = pyro.param('q2').grad / num_particles
    actual_grad_q3 = pyro.param('q3').grad / num_particles

    logger.info("Computing analytic gradients")
    q1 = torch.tensor(pi1, requires_grad=True)
    q2 = torch.tensor(pi2, requires_grad=True)
    q3 = torch.tensor(pi3, requires_grad=True)
    elbo = kl_divergence(dist.Bernoulli(q1), dist.Bernoulli(q3))
    if include_z:
        elbo = elbo + q1 * kl_divergence(dist.Normal(q2 + 0.10, 1.0), dist.Normal(q3 + 0.55, 1.0))
        elbo = elbo + (1.0 - q1) * kl_divergence(dist.Normal(0.10, 1.0), dist.Normal(q3, 1.0))
        expected_grad_q1, expected_grad_q2, expected_grad_q3 = grad(elbo, [q1, q2, q3])
    else:
        expected_grad_q1, expected_grad_q3 = grad(elbo, [q1, q3])

    prec = 0.04 if enumerate1 is None else 0.02

    assert_equal(actual_grad_q1, expected_grad_q1, prec=prec, msg="".join([
        "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()),
        "\nq1   actual = {}".format(actual_grad_q1.data.cpu().numpy()),
    ]))
    if include_z:
        assert_equal(actual_grad_q2, expected_grad_q2, prec=prec, msg="".join([
            "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()),
            "\nq2   actual = {}".format(actual_grad_q2.data.cpu().numpy()),
        ]))
    assert_equal(actual_grad_q3, expected_grad_q3, prec=prec, msg="".join([
        "\nq3 expected = {}".format(expected_grad_q3.data.cpu().numpy()),
        "\nq3   actual = {}".format(actual_grad_q3.data.cpu().numpy()),
    ]))
Example #26
 def test_random_module_prior_dict(self):
     pyro.clear_param_store()
     lifted_nn = pyro.random_module("name", self.model, prior=self.nn_prior)
     lifted_tr = poutine.trace(lifted_nn).get_trace()
     for key_name in lifted_tr.nodes.keys():
         name = pyro.params.user_param_name(key_name)
         if name in {'fc.weight', 'fc.bias'}:
             dist_name = name[3:]
             assert dist_name + "_prior" == lifted_tr.nodes[key_name]['fn'].__name__
             assert lifted_tr.nodes[key_name]["type"] == "sample"
             assert not lifted_tr.nodes[key_name]["is_observed"]
Example #27
def assert_warning(model, guide, elbo):
    """
    Assert that inference works but with a warning.
    """
    pyro.clear_param_store()
    inference = SVI(model,  guide, Adam({"lr": 1e-6}), elbo)
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        inference.step()
        assert len(w), 'No warnings were raised'
        for warning in w:
            logger.info(warning)
Example #28
    def setUp(self):
        pyro.clear_param_store()

        def mu1_prior(tensor, *args, **kwargs):
            flat_tensor = tensor.view(-1)
            m = Variable(torch.zeros(flat_tensor.size(0)))
            s = Variable(torch.ones(flat_tensor.size(0)))
            return Normal(m, s).sample().view(tensor.size())

        def sigma1_prior(tensor, *args, **kwargs):
            flat_tensor = tensor.view(-1)
            m = Variable(torch.zeros(flat_tensor.size(0)))
            s = Variable(torch.ones(flat_tensor.size(0)))
            return Normal(m, s).sample().view(tensor.size())

        def mu2_prior(tensor, *args, **kwargs):
            flat_tensor = tensor.view(-1)
            m = Variable(torch.zeros(flat_tensor.size(0)))
            return Bernoulli(m).sample().view(tensor.size())

        def sigma2_prior(tensor, *args, **kwargs):
            return sigma1_prior(tensor)

        def bias_prior(tensor, *args, **kwargs):
            return mu2_prior(tensor)

        def weight_prior(tensor, *args, **kwargs):
            return sigma1_prior(tensor)

        def stoch_fn(tensor, *args, **kwargs):
            mu = Variable(torch.zeros(tensor.size()))
            sigma = Variable(torch.ones(tensor.size()))
            return pyro.sample("sample", Normal(mu, sigma))

        def guide():
            mu1 = pyro.param("mu1", Variable(torch.randn(2), requires_grad=True))
            sigma1 = pyro.param("sigma1", Variable(torch.ones(2), requires_grad=True))
            pyro.sample("latent1", Normal(mu1, sigma1))

            mu2 = pyro.param("mu2", Variable(torch.randn(2), requires_grad=True))
            sigma2 = pyro.param("sigma2", Variable(torch.ones(2), requires_grad=True))
            latent2 = pyro.sample("latent2", Normal(mu2, sigma2))
            return latent2

        self.model = Model()
        self.guide = guide
        self.prior = mu1_prior
        self.prior_dict = {"mu1": mu1_prior, "sigma1": sigma1_prior, "mu2": mu2_prior, "sigma2": sigma2_prior}
        self.partial_dict = {"mu1": mu1_prior, "sigma1": sigma1_prior}
        self.nn_prior = {"fc.bias": bias_prior, "fc.weight": weight_prior}
        self.fn = stoch_fn
        self.data = Variable(torch.randn(2, 2))
Example #29
    def setUp(self):
        pyro.clear_param_store()

        def loc1_prior(tensor, *args, **kwargs):
            flat_tensor = tensor.view(-1)
            m = torch.zeros(flat_tensor.size(0))
            s = torch.ones(flat_tensor.size(0))
            return Normal(m, s).sample().view(tensor.size())

        def scale1_prior(tensor, *args, **kwargs):
            flat_tensor = tensor.view(-1)
            m = torch.zeros(flat_tensor.size(0))
            s = torch.ones(flat_tensor.size(0))
            return Normal(m, s).sample().view(tensor.size()).exp()

        def loc2_prior(tensor, *args, **kwargs):
            flat_tensor = tensor.view(-1)
            m = torch.zeros(flat_tensor.size(0))
            return Bernoulli(m).sample().view(tensor.size())

        def scale2_prior(tensor, *args, **kwargs):
            return scale1_prior(tensor)

        def bias_prior(tensor, *args, **kwargs):
            return loc2_prior(tensor)

        def weight_prior(tensor, *args, **kwargs):
            return scale1_prior(tensor)

        def stoch_fn(tensor, *args, **kwargs):
            loc = torch.zeros(tensor.size())
            scale = torch.ones(tensor.size())
            return pyro.sample("sample", Normal(loc, scale))

        def guide():
            loc1 = pyro.param("loc1", torch.randn(2, requires_grad=True))
            scale1 = pyro.param("scale1", torch.ones(2, requires_grad=True))
            pyro.sample("latent1", Normal(loc1, scale1))

            loc2 = pyro.param("loc2", torch.randn(2, requires_grad=True))
            scale2 = pyro.param("scale2", torch.ones(2, requires_grad=True))
            latent2 = pyro.sample("latent2", Normal(loc2, scale2))
            return latent2

        self.model = Model()
        self.guide = guide
        self.prior = scale1_prior
        self.prior_dict = {"loc1": loc1_prior, "scale1": scale1_prior, "loc2": loc2_prior, "scale2": scale2_prior}
        self.partial_dict = {"loc1": loc1_prior, "scale1": scale1_prior}
        self.nn_prior = {"fc.bias": bias_prior, "fc.weight": weight_prior}
        self.fn = stoch_fn
        self.data = torch.randn(2, 2)
Example #30
def test_named_dict():
    pyro.clear_param_store()

    def model():
        latent = named.Dict("latent")
        loc = latent["loc"].param_(torch.zeros(1))
        foo = latent["foo"].sample_(dist.Normal(loc, torch.ones(1)))
        latent["bar"].sample_(dist.Normal(loc, torch.ones(1)), obs=foo)
        latent["x"].z.sample_(dist.Normal(loc, torch.ones(1)))

    tr = poutine.trace(model).get_trace()
    assert get_sample_names(tr) == set(["latent['foo']", "latent['x'].z"])
    assert get_observe_names(tr) == set(["latent['bar']"])
    assert get_param_names(tr) == set(["latent['loc']"])
Example #31
def main(smoke_test=False):
    epochs = 2 if smoke_test else 50
    batch_size = 128
    seed = 0

    x_ch = 1
    z_dim = 32

    # Initialize random seeds
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    pyro.set_rng_seed(seed)

    date_and_time = datetime.datetime.now().strftime('%Y-%m%d-%H%M')
    save_root = f'./results/pyro/{date_and_time}'
    if not os.path.exists(save_root):
        os.makedirs(save_root)

    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    pyro.clear_param_store()  # reset Pyro's parameter store
    pyro.enable_validation(smoke_test)  # for debugging: NaN checks, distribution validation, argument/support checks, etc.
    pyro.distributions.enable_validation(False)

    root = '/mnt/hdd/sika/Datasets'
    train_loader, test_loader = make_MNIST_loader(root, batch_size=batch_size)

    # Create an instance of the class that provides the model and guide methods
    vae = VAE(x_ch, z_dim).to(device)

    # Wrap the optimization algorithm with PyroOptim
    optimizer = pyro.optim.PyroOptim(torch.optim.Adam, {'lr': 1e-3})

    # Create an SVI (Stochastic Variational Inference) instance
    svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

    x_fixed, _ = next(iter(test_loader))  # fixed images
    x_fixed = x_fixed[:8].to(device)
    z_fixed = torch.randn([64, z_dim], device=device)  # fixed latent variables
    x_dummy = torch.zeros(64,
                          x_fixed.size(1),
                          x_fixed.size(2),
                          x_fixed.size(3),
                          device=device)  # for sampling

    train_loss_list, test_loss_list = [], []
    for epoch in range(1, epochs + 1):
        train_loss_list.append(
            learn(svi, epoch, train_loader, device, train=True))
        test_loss_list.append(
            learn(svi, epoch, test_loader, device, train=False))

        print(f'    [Epoch {epoch}] train loss {train_loss_list[-1]:.4f}')
        print(f'    [Epoch {epoch}] test  loss {test_loss_list[-1]:.4f}\n')

        # Plot and save the loss curves
        plt.plot(list(range(1, epoch + 1)), train_loss_list, label='train')
        plt.plot(list(range(1, epoch + 1)), test_loss_list, label='test')
        plt.xlabel('epochs')
        plt.ylabel('loss')
        plt.legend()
        plt.savefig(os.path.join(save_root, 'loss.png'))
        plt.close()

        # Reconstructed images
        x_reconst = reconstruct_image(vae.encoder, vae.decoder, x_fixed)
        save_image(torch.cat([x_fixed, x_reconst], dim=0),
                   os.path.join(save_root, f'reconst_{epoch}.png'),
                   nrow=8)

        # Interpolated images
        x_interpol = interpolate_image(vae.encoder, vae.decoder, x_fixed)
        save_image(x_interpol,
                   os.path.join(save_root, f'interpol_{epoch}.png'),
                   nrow=8)

        # Generated images (fixed latent variables)
        x_generate = generate_image(vae.decoder, z_fixed)
        save_image(x_generate,
                   os.path.join(save_root, f'generate_{epoch}.png'),
                   nrow=8)

        # Generated images (random sampling)
        x_sample = sample_image(vae.model, x_dummy)
        save_image(x_sample,
                   os.path.join(save_root, f'sample_{epoch}.png'),
                   nrow=8)
Example #32
def main(args):
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # load data
    if args.dataset == "dipper":
        capture_history_file = os.path.dirname(
            os.path.abspath(__file__)) + '/dipper_capture_history.csv'
    elif args.dataset == "vole":
        capture_history_file = os.path.dirname(
            os.path.abspath(__file__)) + '/meadow_voles_capture_history.csv'
    else:
        raise ValueError("Available datasets are \'dipper\' and \'vole\'.")

    capture_history = torch.tensor(
        np.genfromtxt(capture_history_file, delimiter=',')).float()[:, 1:]
    N, T = capture_history.shape
    print(
        "Loaded {} capture history for {} individuals collected over {} time periods."
        .format(args.dataset, N, T))

    if args.dataset == "dipper" and args.model in ["4", "5"]:
        sex_file = os.path.dirname(
            os.path.abspath(__file__)) + '/dipper_sex.csv'
        sex = torch.tensor(np.genfromtxt(sex_file, delimiter=',')).float()[:, 1]
        print("Loaded dipper sex data.")
    elif args.dataset == "vole" and args.model in ["4", "5"]:
        raise ValueError(
            ("Cannot run model_{} on meadow voles data, since we lack sex " +
             "information for these animals.").format(args.model))
    else:
        sex = None

    model = models[args.model]

    # we use poutine.block to only expose the continuous latent variables
    # in the models to AutoDiagonalNormal (all of which begin with 'phi'
    # or 'rho')
    def expose_fn(msg):
        return msg["name"][0:3] in ['phi', 'rho']

    # we use a mean field diagonal normal variational distributions (i.e. guide)
    # for the continuous latent variables.
    guide = AutoDiagonalNormal(poutine.block(model, expose_fn=expose_fn))

    # since we enumerate the discrete random variables,
    # we need to use TraceEnum_ELBO or TraceTMC_ELBO.
    optim = Adam({'lr': args.learning_rate})
    if args.tmc:
        elbo = TraceTMC_ELBO(max_plate_nesting=1)
        tmc_model = poutine.infer_config(model, lambda msg: {
            "num_samples": args.tmc_num_samples,
            "expand": False
        } if msg["infer"].get("enumerate", None) == "parallel" else {}
                                         )  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        elbo = TraceEnum_ELBO(max_plate_nesting=1,
                              num_particles=20,
                              vectorize_particles=True)
        svi = SVI(model, guide, optim, elbo)

    losses = []

    print(
        "Beginning training of model_{} with Stochastic Variational Inference."
        .format(args.model))

    for step in range(args.num_steps):
        loss = svi.step(capture_history, sex)
        losses.append(loss)
        if step % 20 == 0 and step > 0 or step == args.num_steps - 1:
            print("[iteration %03d] loss: %.3f" %
                  (step, np.mean(losses[-20:])))

    # evaluate final trained model
    elbo_eval = TraceEnum_ELBO(max_plate_nesting=1,
                               num_particles=2000,
                               vectorize_particles=True)
    svi_eval = SVI(model, guide, optim, elbo_eval)
    print("Final loss: %.4f" % svi_eval.evaluate_loss(capture_history, sex))
Example #33
    def do_fit_prior_test(self,
                          reparameterized,
                          n_steps,
                          loss,
                          debug=False,
                          lr=0.001):
        pyro.clear_param_store()

        def model():
            with pyro.plate("samples", self.sample_batch_size):
                pyro.sample(
                    "loc_latent",
                    dist.Normal(
                        torch.stack([self.loc0] * self.sample_batch_size,
                                    dim=0),
                        torch.stack([torch.pow(self.lam0, -0.5)] *
                                    self.sample_batch_size,
                                    dim=0),
                    ).to_event(1),
                )

        def guide():
            loc_q = pyro.param("loc_q", self.loc0.detach() + 0.134)
            log_sig_q = pyro.param(
                "log_sig_q", -0.5 * torch.log(self.lam0).data.detach() - 0.14)
            sig_q = torch.exp(log_sig_q)
            Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
            with pyro.plate("samples", self.sample_batch_size):
                pyro.sample(
                    "loc_latent",
                    Normal(
                        torch.stack([loc_q] * self.sample_batch_size, dim=0),
                        torch.stack([sig_q] * self.sample_batch_size, dim=0),
                    ).to_event(1),
                )

        adam = optim.Adam({"lr": lr})
        svi = SVI(model, guide, adam, loss=loss)

        alpha = 0.99
        for k in range(n_steps):
            svi.step()
            if debug:
                loc_error = param_mse("loc_q", self.loc0)
                log_sig_error = param_mse("log_sig_q",
                                          -0.5 * torch.log(self.lam0))
                with torch.no_grad():
                    if k == 0:
                        (
                            avg_loglikelihood,
                            avg_penalty,
                        ) = loss._differentiable_loss_parts(model, guide)
                        avg_loglikelihood = torch_item(avg_loglikelihood)
                        avg_penalty = torch_item(avg_penalty)
                    loglikelihood, penalty = loss._differentiable_loss_parts(
                        model, guide)
                    avg_loglikelihood = alpha * avg_loglikelihood + (
                        1 - alpha) * torch_item(loglikelihood)
                    avg_penalty = alpha * avg_penalty + (
                        1 - alpha) * torch_item(penalty)
                if k % 100 == 0:
                    print(loc_error, log_sig_error)
                    print(avg_loglikelihood, avg_penalty)
                    print()

        loc_error = param_mse("loc_q", self.loc0)
        log_sig_error = param_mse("log_sig_q", -0.5 * torch.log(self.lam0))
        assert_equal(0.0, loc_error, prec=0.05)
        assert_equal(0.0, log_sig_error, prec=0.05)
Example #34
def run_inference(dataset_obj: SingleCellRNACountsDataset,
                  args) -> RemoveBackgroundPyroModel:
    """Run a full inference procedure, training a latent variable model.

    Args:
        dataset_obj: Input data in the form of a SingleCellRNACountsDataset
            object.
        args: Input command line parsed arguments.

    Returns:
         model: cellbender.model.RemoveBackgroundPyroModel that has had
            inference run.

    """

    # Get the trimmed count matrix (transformed if called for).
    count_matrix = dataset_obj.get_count_matrix()

    # Configure pyro options (skip validations to improve speed).
    pyro.enable_validation(False)
    pyro.distributions.enable_validation(False)
    pyro.set_rng_seed(0)
    pyro.clear_param_store()

    # Set up the variational autoencoder:

    # Encoder.
    encoder_z = EncodeZ(input_dim=count_matrix.shape[1],
                        hidden_dims=args.z_hidden_dims,
                        output_dim=args.z_dim,
                        input_transform='normalize')

    encoder_other = EncodeNonZLatents(
        n_genes=count_matrix.shape[1],
        z_dim=args.z_dim,
        hidden_dims=consts.ENC_HIDDEN_DIMS,
        log_count_crossover=dataset_obj.priors['log_counts_crossover'],
        prior_log_cell_counts=np.log1p(dataset_obj.priors['cell_counts']),
        input_transform='normalize')

    encoder = CompositeEncoder({'z': encoder_z, 'other': encoder_other})

    # Decoder.
    decoder = Decoder(input_dim=args.z_dim,
                      hidden_dims=args.z_hidden_dims[::-1],
                      output_dim=count_matrix.shape[1])

    # Set up the pyro model for variational inference.
    model = RemoveBackgroundPyroModel(model_type=args.model,
                                      encoder=encoder,
                                      decoder=decoder,
                                      dataset_obj=dataset_obj,
                                      use_cuda=args.use_cuda)

    # Load the dataset into DataLoaders.
    frac = args.training_fraction  # Fraction of barcodes to use for training
    batch_size = int(
        min(300, frac * dataset_obj.analyzed_barcode_inds.size / 2))
    train_loader, test_loader = \
        prep_data_for_training(dataset=count_matrix,
                               empty_drop_dataset=
                               dataset_obj.get_count_matrix_empties(),
                               random_state=dataset_obj.random,
                               batch_size=batch_size,
                               training_fraction=frac,
                               fraction_empties=args.fraction_empties,
                               shuffle=True,
                               use_cuda=args.use_cuda)

    # Set up the optimizer.
    optimizer = pyro.optim.clipped_adam.ClippedAdam
    optimizer_args = {'lr': args.learning_rate, 'clip_norm': 10.}

    # Set up a learning rate scheduler.
    minibatches_per_epoch = int(
        np.ceil(len(train_loader) / train_loader.batch_size).item())
    scheduler_args = {
        'optimizer': optimizer,
        'max_lr': args.learning_rate * 10,
        'steps_per_epoch': minibatches_per_epoch,
        'epochs': args.epochs,
        'optim_args': optimizer_args
    }
    scheduler = pyro.optim.OneCycleLR(scheduler_args)

    # Determine the loss function.
    if args.use_jit:

        # Call guide() once as a warm-up.
        model.guide(
            torch.zeros([10, dataset_obj.analyzed_gene_inds.size
                         ]).to(model.device))

        if args.model == "simple":
            loss_function = JitTrace_ELBO()
        else:
            loss_function = JitTraceEnum_ELBO(max_plate_nesting=1,
                                              strict_enumeration_warning=False)
    else:

        if args.model == "simple":
            loss_function = Trace_ELBO()
        else:
            loss_function = TraceEnum_ELBO(max_plate_nesting=1)

    # Set up the inference process.
    svi = SVI(model.model, model.guide, scheduler, loss=loss_function)

    # Run training.
    run_training(model,
                 svi,
                 train_loader,
                 test_loader,
                 epochs=args.epochs,
                 test_freq=5)

    return model
Example #35
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)
Example #36
def main_sVAE(arr):

    X_DIM = 10000
    Y_DIM = 2
    Z_DIM = 16
    ALPHA_ENCO = int("".join(str(i) for i in arr[0:10]), 2)
    BETA_ENCO = int("".join(str(i) for i in arr[10:18]), 2)

    H_DIM_ENCO_1 = ALPHA_ENCO + BETA_ENCO

    H_DIM_ENCO_2 = ALPHA_ENCO

    H_DIM_DECO_1 = ALPHA_ENCO

    H_DIM_DECO_2 = ALPHA_ENCO + BETA_ENCO

    print(str(H_DIM_ENCO_1))
    print(str(H_DIM_ENCO_2))
    print(str(H_DIM_DECO_1))
    print(str(H_DIM_DECO_2))
    print('-----------')

    # Run options
    LEARNING_RATE = 1.0e-3
    USE_CUDA = True

    # Run only for a single iteration for testing
    NUM_EPOCHS = 501
    TEST_FREQUENCY = 5

    train_loader, test_loader = dataloader_first()
    # clear param store
    pyro.clear_param_store()

    # setup the VAE
    vae = VAE(x_dim=X_DIM,
              y_dim=Y_DIM,
              h_dim_enco_1=H_DIM_ENCO_1,
              h_dim_enco_2=H_DIM_ENCO_2,
              h_dim_deco_1=H_DIM_DECO_1,
              h_dim_deco_2=H_DIM_DECO_2,
              z_dim=Z_DIM,
              use_cuda=USE_CUDA)

    # setup the optimizer
    adagrad_params = {"lr": 0.00003}
    optimizer = Adagrad(adagrad_params)

    svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

    train_elbo = []
    test_elbo = []
    # training loop
    for epoch in range(NUM_EPOCHS):
        total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)

        train_elbo.append(-total_epoch_loss_train)

        print("[epoch %03d]  average training loss: %.4f" %
              (epoch, total_epoch_loss_train))

        if epoch == 500:
            # --------------------------Do testing for each epoch here--------------------------------
            # initialize loss accumulator
            test_loss = 0.
            # compute the loss over the entire test set
            for x_test, y_test in test_loader:

                x_test = x_test.cuda()
                y_test = y_test.cuda()
                # compute ELBO estimate and accumulate loss
                labels_y_test = torch.tensor(np.zeros((y_test.shape[0], 2)))
                y_test_2 = torch.Tensor.cpu(
                    y_test.reshape(1,
                                   y_test.size()[0])[0]).numpy().astype(int)
                labels_y_test = np.eye(2)[y_test_2]
                labels_y_test = torch.from_numpy(labels_y_test)

                test_loss += svi.evaluate_loss(
                    x_test.reshape(-1, 10000),
                    labels_y_test.cuda().float()
                )  #Data entry point <---------------------------------Data Entry Point

            normalizer_test = len(test_loader.dataset)
            total_epoch_loss_test = test_loss / normalizer_test
            print("[epoch %03d]  average test loss: %.4f" %
                  (epoch, total_epoch_loss_test))
            return total_epoch_loss_test
Example #37
def main(args):
    # clear param store
    pyro.clear_param_store()

    # setup MNIST data loaders
    # train_loader, test_loader
    train_loader, test_loader = setup_data_loaders(MNIST,
                                                   use_cuda=args.cuda,
                                                   batch_size=256)

    # setup the VAE
    vae = VAE(use_cuda=args.cuda)

    # setup the optimizer
    adam_args = {"lr": args.learning_rate}
    optimizer = Adam(adam_args)

    # setup the inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    # setup visdom for visualization
    if args.visdom_flag:
        vis = visdom.Visdom()

    train_elbo = []
    test_elbo = []
    # training loop
    for epoch in range(args.num_epochs):
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch x returned
        # by the data loader
        for x, _ in train_loader:
            # if on GPU put mini-batch into CUDA memory
            if args.cuda:
                x = x.cuda()
            # do ELBO gradient and accumulate loss
            epoch_loss += svi.step(x)

        # report training diagnostics
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" %
              (epoch, total_epoch_loss_train))

        if epoch % args.test_frequency == 0:
            # initialize loss accumulator
            test_loss = 0.
            # compute the loss over the entire test set
            for i, (x, _) in enumerate(test_loader):
                # if on GPU put mini-batch into CUDA memory
                if args.cuda:
                    x = x.cuda()
                # compute ELBO estimate and accumulate loss
                test_loss += svi.evaluate_loss(x)

                # pick three random test images from the first mini-batch and
                # visualize how well we're reconstructing them
                if i == 0:
                    if args.visdom_flag:
                        plot_vae_samples(vae, vis)
                        reco_indices = np.random.randint(0, x.shape[0], 3)
                        for index in reco_indices:
                            test_img = x[index, :]
                            reco_img = vae.reconstruct_img(test_img)
                            vis.image(test_img.reshape(
                                28, 28).detach().cpu().numpy(),
                                      opts={'caption': 'test image'})
                            vis.image(reco_img.reshape(
                                28, 28).detach().cpu().numpy(),
                                      opts={'caption': 'reconstructed image'})

            # report test diagnostics
            normalizer_test = len(test_loader.dataset)
            total_epoch_loss_test = test_loss / normalizer_test
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d]  average test loss: %.4f" %
                  (epoch, total_epoch_loss_test))

        if epoch == args.tsne_iter:
            mnist_test_tsne(vae=vae, test_loader=test_loader)
            plot_llk(np.array(train_elbo), np.array(test_elbo))

    return vae
Example #38
def train():
    pyro.clear_param_store()
    for j in range(num_iterations):
        loss = svi.step(x_data, y_data)
        if j % 100 == 0:
            print("Iteration %04d loss: %4f" % (j + 1, loss / len(data)))
Example #39
def main():
    pyro.clear_param_store()
    #pyro.get_param_store().load('Pyro_model')
    for j in range(n_epochs):
        loss = 0
        start = time.time()
        for data in train_loader:
            data[0] = Variable(data[0].cuda())  #.view(-1, 28 * 28).cuda())
            data[1] = Variable(data[1].long().cuda())
            loss += svi.step(data)
        print(time.time() - start)
        #if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" %
              (j + 1, loss / float(n_train_batches * batch_size)))
    #for name in pyro.get_param_store().get_all_param_names():
    #    print("[%s]: %.3f" % (name, pyro.param(name).data.numpy()))
    pyro.get_param_store().save('Pyro_model')
    datasets = {'RegularImages_0.0': [test.test_data, test.test_labels]}

    fgsm = glob.glob('fgsm/fgsm_cifar10_examples_x_10000_*'
                     )  #glob.glob('fgsm/fgsm_mnist_adv_x_1000_*')
    fgsm_labels = test.test_labels  #torch.from_numpy(np.argmax(np.load('fgsm/fgsm_mnist_adv_y_1000.npy'), axis=1))
    for file in fgsm:
        parts = file.split('_')
        key = parts[0].split('/')[0] + '_' + parts[-1].split('.npy')[0]

        datasets[key] = [torch.from_numpy(np.load(file)), fgsm_labels]

    #jsma = glob.glob('jsma/jsma_cifar10_adv_x_10000*')
    #jsma_labels = torch.from_numpy(np.argmax(np.load('jsma/jsma_cifar10_adv_y_10000.npy'), axis=1))
    #for file in jsma:
    #    parts = file.split('_')
    #    key = parts[0].split('/')[0] + '_' + parts[-1].split('.npy')[0]
    #    datasets[key] = [torch.from_numpy(np.load(file)), jsma_labels]
    gaussian = glob.glob('gaussian/cifar_gaussian_adv_x_*')
    gaussian_labels = torch.from_numpy(
        np.argmax(np.load('gaussian/cifar_gaussian_adv_y.npy')[0:1000],
                  axis=1))
    for file in gaussian:
        parts = file.split('_')
        key = parts[0].split('/')[0] + '_' + parts[-1].split('.npy')[0]
        datasets[key] = [torch.from_numpy(np.load(file)), gaussian_labels]

    print(datasets.keys())
    print(
        '################################################################################'
    )
    accuracies = {}
    for key, value in datasets.items():
        print(key)
        parts = key.split('_')
        adversary_type = parts[0]
        epsilon = parts[1]
        data = value
        X, y = data[0], data[1]  #.view(-1, 28 * 28), data[1]
        x_data, y_data = Variable(X.float().cuda()), Variable(y.cuda())
        T = 100

        accs = []
        samples = np.zeros((y_data.data.size()[0], T, outputs))
        for i in range(T):
            sampled_model = guide(None)
            pred = sampled_model(x_data)
            samples[:, i, :] = pred.data.cpu().numpy()
            _, out = torch.max(pred, 1)

            acc = np.count_nonzero(
                np.squeeze(out.data.cpu().numpy()) == np.int32(y_data.data.cpu(
                ).numpy().ravel())) / float(y_data.data.size()[0])
            accs.append(acc)

        variationRatio = []
        mutualInformation = []
        predictiveEntropy = []
        predictions = []

        for i in range(0, len(y_data)):
            entry = samples[i, :, :]
            variationRatio.append(Uncertainty.variation_ratio(entry))
            mutualInformation.append(Uncertainty.mutual_information(entry))
            predictiveEntropy.append(Uncertainty.predictive_entropy(entry))
            predictions.append(np.max(entry.mean(axis=0), axis=0))

        uncertainty = {}
        uncertainty['varation_ratio'] = np.array(variationRatio)
        uncertainty['predictive_entropy'] = np.array(predictiveEntropy)
        uncertainty['mutual_information'] = np.array(mutualInformation)
        predictions = np.array(predictions)

        Uncertainty.plot_uncertainty(uncertainty,
                                     predictions,
                                     adversarial_type=adversary_type,
                                     epsilon=float(epsilon),
                                     directory='Results_CIFAR10_PYRO')
        #, directory='Results_CIFAR10_PYRO')

        accs = np.array(accs)
        print('Accuracy mean: {}, Accuracy std: {}'.format(
            accs.mean(), accs.std()))
        #accuracies[key] = {'mean': accs.mean(), 'std': accs.std()}
        accuracies[key] = {'mean': accs.mean(), 'std': accs.std(),  \
                   'variationratio': [uncertainty['varation_ratio'].mean(), uncertainty['varation_ratio'].std()], \
                   'predictiveEntropy': [uncertainty['predictive_entropy'].mean(), uncertainty['predictive_entropy'].std()], \
                   'mutualInformation': [uncertainty['mutual_information'].mean(), uncertainty['mutual_information'].std()]}

    np.save('PyroBNN_accuracies_CIFAR10', accuracies)
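
A note on the Uncertainty helpers used above: they are assumed to implement the standard Monte Carlo predictive-uncertainty measures. A minimal NumPy sketch, under that assumption, of what variation_ratio, predictive_entropy and mutual_information compute from one (T, C) slice of the samples array (T stochastic forward passes, C classes; the entries are taken to be class probabilities, which may require a softmax upstream):

import numpy as np

def variation_ratio(entry):
    # fraction of passes that disagree with the modal predicted class
    T, C = entry.shape
    modal_count = np.bincount(np.argmax(entry, axis=1), minlength=C).max()
    return 1.0 - modal_count / float(T)

def predictive_entropy(entry):
    # entropy of the mean predictive distribution
    p_mean = entry.mean(axis=0)
    return -np.sum(p_mean * np.log(p_mean + 1e-12))

def mutual_information(entry):
    # BALD-style score: predictive entropy minus the mean per-pass entropy
    per_pass_entropy = -np.sum(entry * np.log(entry + 1e-12), axis=1)
    return predictive_entropy(entry) - per_pass_entropy.mean()
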
Example #40
0
def setup():
    pyro.clear_param_store()
Example #41
0
get_ipython().run_cell_magic(
    'time', '',
    '\n### HMC ###\npyro.clear_param_store()\n\n# Set random seed for reproducibility.\npyro.set_rng_seed(2)\n\n# Set up HMC sampler.\nkernel = HMC(gpc, step_size=0.05, trajectory_length=1, \n             adapt_step_size=False, adapt_mass_matrix=False, jit_compile=True)\nhmc = MCMC(kernel, num_samples=500, warmup_steps=500)\nhmc.run(X, y.double())\n\n# Get posterior samples\nhmc_posterior_samples = hmc.get_samples()'
)
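
For readability, the escaped notebook cell above corresponds to roughly the following code (gpc, X and y are defined in cells elided here):

### HMC ###
pyro.clear_param_store()

# Set random seed for reproducibility.
pyro.set_rng_seed(2)

# Set up HMC sampler.
kernel = HMC(gpc, step_size=0.05, trajectory_length=1,
             adapt_step_size=False, adapt_mass_matrix=False, jit_compile=True)
hmc = MCMC(kernel, num_samples=500, warmup_steps=500)
hmc.run(X, y.double())

# Get posterior samples.
hmc_posterior_samples = hmc.get_samples()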

# In[59]:

plot_uq(hmc_posterior_samples, X, Xnew, "HMC")

# ## NUTS

# In[60]:

### NUTS ###
pyro.clear_param_store()
pyro.set_rng_seed(2)

nuts = MCMC(NUTS(gpc,
                 target_accept_prob=0.8,
                 max_tree_depth=10,
                 jit_compile=True),
            num_samples=500,
            warmup_steps=500)
nuts.run(X, y.double())

nuts_posterior_samples = nuts.get_samples()

# In[61]:

plot_uq(nuts_posterior_samples, X, Xnew, "NUTS")
Example #42
0
    def do_fit_prior_test(self, reparameterized, n_steps, loss, debug=False):
        pyro.clear_param_store()
        Beta = dist.Beta if reparameterized else fakes.NonreparameterizedBeta

        def model():
            with pyro.plate("samples", self.sample_batch_size):
                pyro.sample(
                    "p_latent",
                    Beta(
                        torch.stack([torch.stack([self.alpha0])] *
                                    self.sample_batch_size),
                        torch.stack([torch.stack([self.beta0])] *
                                    self.sample_batch_size),
                    ).to_event(1),
                )

        def guide():
            alpha_q_log = pyro.param("alpha_q_log",
                                     torch.log(self.alpha0) + 0.17)
            beta_q_log = pyro.param("beta_q_log",
                                    torch.log(self.beta0) - 0.143)
            alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
            with pyro.plate("samples", self.sample_batch_size):
                pyro.sample(
                    "p_latent",
                    Beta(
                        torch.stack([torch.stack([alpha_q])] *
                                    self.sample_batch_size),
                        torch.stack([torch.stack([beta_q])] *
                                    self.sample_batch_size),
                    ).to_event(1),
                )

        adam = optim.Adam({"lr": 0.001, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss=loss)

        alpha = 0.99
        for k in range(n_steps):
            svi.step()
            if debug:
                alpha_error = param_abs_error("alpha_q_log",
                                              torch.log(self.alpha0))
                beta_error = param_abs_error("beta_q_log",
                                             torch.log(self.beta0))
                with torch.no_grad():
                    if k == 0:
                        (
                            avg_loglikelihood,
                            avg_penalty,
                        ) = loss._differentiable_loss_parts(model, guide)
                        avg_loglikelihood = torch_item(avg_loglikelihood)
                        avg_penalty = torch_item(avg_penalty)
                    loglikelihood, penalty = loss._differentiable_loss_parts(
                        model, guide)
                    avg_loglikelihood = alpha * avg_loglikelihood + (
                        1 - alpha) * torch_item(loglikelihood)
                    avg_penalty = alpha * avg_penalty + (
                        1 - alpha) * torch_item(penalty)
                if k % 100 == 0:
                    print(alpha_error, beta_error)
                    print(avg_loglikelihood, avg_penalty)
                    print()

        alpha_error = param_abs_error("alpha_q_log", torch.log(self.alpha0))
        beta_error = param_abs_error("beta_q_log", torch.log(self.beta0))
        assert_equal(0.0, alpha_error, prec=0.08)
        assert_equal(0.0, beta_error, prec=0.08)
Example #43
0
def test_exponential_gamma(gamma_dist, n_steps, elbo_impl):
    pyro.clear_param_store()

    # gamma prior hyperparameter
    alpha0 = torch.tensor(1.0)
    # gamma prior hyperparameter
    beta0 = torch.tensor(1.0)
    n_data = 2
    data = torch.tensor([3.0, 2.0])  # two observations
    alpha_n = alpha0 + torch.tensor(float(n_data))  # posterior alpha
    beta_n = beta0 + torch.sum(data)  # posterior beta
    prec = 0.2 if gamma_dist.has_rsample else 0.25

    def model(alpha0, beta0, alpha_n, beta_n):
        lambda_latent = pyro.sample("lambda_latent", gamma_dist(alpha0, beta0))
        with pyro.plate("data", n_data):
            pyro.sample("obs", dist.Exponential(lambda_latent), obs=data)
        return lambda_latent

    def guide(alpha0, beta0, alpha_n, beta_n):
        alpha_q = pyro.param("alpha_q",
                             alpha_n * math.exp(0.17),
                             constraint=constraints.positive)
        beta_q = pyro.param("beta_q",
                            beta_n / math.exp(0.143),
                            constraint=constraints.positive)
        pyro.sample("lambda_latent", gamma_dist(alpha_q, beta_q))

    adam = optim.Adam({"lr": 0.0003, "betas": (0.97, 0.999)})
    if elbo_impl is RenyiELBO:
        elbo = elbo_impl(
            alpha=0.2,
            num_particles=3,
            max_plate_nesting=1,
            strict_enumeration_warning=False,
        )
    elif elbo_impl is ReweightedWakeSleep:
        if gamma_dist is ShapeAugmentedGamma:
            pytest.xfail(
                reason=
                "ShapeAugmentedGamma not supported for ReweightedWakeSleep")
        else:
            elbo = elbo_impl(num_particles=3,
                             max_plate_nesting=1,
                             strict_enumeration_warning=False)
    else:
        elbo = elbo_impl(max_plate_nesting=1, strict_enumeration_warning=False)
    svi = SVI(model, guide, adam, loss=elbo)

    with xfail_if_not_implemented():
        for k in range(n_steps):
            svi.step(alpha0, beta0, alpha_n, beta_n)

    assert_equal(
        pyro.param("alpha_q"),
        alpha_n,
        prec=prec,
        msg="{} vs {}".format(
            pyro.param("alpha_q").detach().cpu().numpy(),
            alpha_n.detach().cpu().numpy()),
    )
    assert_equal(
        pyro.param("beta_q"),
        beta_n,
        prec=prec,
        msg="{} vs {}".format(
            pyro.param("beta_q").detach().cpu().numpy(),
            beta_n.detach().cpu().numpy()),
    )
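
The posterior hyperparameters in the test above follow from Gamma-Exponential conjugacy; a quick standalone check of the arithmetic (NumPy only, mirroring the constants above):

import numpy as np

# Prior Gamma(alpha0, beta0); likelihood prod_i Exponential(lambda) = lambda^n * exp(-lambda * sum(x)).
# The posterior density is proportional to lambda^(alpha0 + n - 1) * exp(-(beta0 + sum(x)) * lambda),
# i.e. Gamma(alpha0 + n, beta0 + sum(x)).
alpha0, beta0 = 1.0, 1.0
data = np.array([3.0, 2.0])
alpha_n = alpha0 + len(data)  # 3.0, matches alpha_n in the test
beta_n = beta0 + data.sum()   # 6.0, matches beta_n in the test
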
Example #44
0
    def do_fit_prior_test(self,
                          reparameterized,
                          n_steps,
                          loss,
                          debug=False,
                          lr=0.0002):
        pyro.clear_param_store()
        Gamma = dist.Gamma if reparameterized else fakes.NonreparameterizedGamma

        def model():
            with pyro.plate("samples", self.sample_batch_size):
                pyro.sample(
                    "lambda_latent",
                    Gamma(
                        torch.stack([torch.stack([self.alpha0])] *
                                    self.sample_batch_size),
                        torch.stack([torch.stack([self.beta0])] *
                                    self.sample_batch_size),
                    ).to_event(1),
                )

        def guide():
            alpha_q = pyro.param(
                "alpha_q",
                self.alpha0.detach() + math.exp(0.17),
                constraint=constraints.positive,
            )
            beta_q = pyro.param(
                "beta_q",
                self.beta0.detach() / math.exp(0.143),
                constraint=constraints.positive,
            )
            with pyro.plate("samples", self.sample_batch_size):
                pyro.sample(
                    "lambda_latent",
                    Gamma(
                        torch.stack([torch.stack([alpha_q])] *
                                    self.sample_batch_size),
                        torch.stack([torch.stack([beta_q])] *
                                    self.sample_batch_size),
                    ).to_event(1),
                )

        adam = optim.Adam({"lr": lr, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss)

        alpha = 0.99
        for k in range(n_steps):
            svi.step()
            if debug:
                alpha_error = param_mse("alpha_q", self.alpha0)
                beta_error = param_mse("beta_q", self.beta0)
                with torch.no_grad():
                    if k == 0:
                        (
                            avg_loglikelihood,
                            avg_penalty,
                        ) = loss._differentiable_loss_parts(
                            model, guide, (), {})
                        avg_loglikelihood = torch_item(avg_loglikelihood)
                        avg_penalty = torch_item(avg_penalty)
                    loglikelihood, penalty = loss._differentiable_loss_parts(
                        model, guide, (), {})
                    avg_loglikelihood = alpha * avg_loglikelihood + (
                        1 - alpha) * torch_item(loglikelihood)
                    avg_penalty = alpha * avg_penalty + (
                        1 - alpha) * torch_item(penalty)
                if k % 100 == 0:
                    print(alpha_error, beta_error)
                    print(avg_loglikelihood, avg_penalty)
                    print()

        assert_equal(
            pyro.param("alpha_q"),
            self.alpha0,
            prec=0.2,
            msg="{} vs {}".format(
                pyro.param("alpha_q").detach().cpu().numpy(),
                self.alpha0.detach().cpu().numpy(),
            ),
        )
        assert_equal(
            pyro.param("beta_q"),
            self.beta0,
            prec=0.15,
            msg="{} vs {}".format(
                pyro.param("beta_q").detach().cpu().numpy(),
                self.beta0.detach().cpu().numpy(),
            ),
        )
Example #45
0
    def do_elbo_test(
        self,
        repa1,
        repa2,
        n_steps,
        prec,
        lr,
        use_nn_baseline,
        use_decaying_avg_baseline,
    ):
        logger.info(" - - - - - DO NORMALNORMALNORMAL ELBO TEST - - - - - -")
        logger.info(
            "[reparameterized = %s, %s; nn_baseline = %s, decaying_baseline = %s]"
            % (repa1, repa2, use_nn_baseline, use_decaying_avg_baseline))
        pyro.clear_param_store()
        Normal1 = dist.Normal if repa1 else fakes.NonreparameterizedNormal
        Normal2 = dist.Normal if repa2 else fakes.NonreparameterizedNormal

        if use_nn_baseline:

            class VanillaBaselineNN(nn.Module):
                def __init__(self, dim_input, dim_h):
                    super().__init__()
                    self.lin1 = nn.Linear(dim_input, dim_h)
                    self.lin2 = nn.Linear(dim_h, 2)
                    self.sigmoid = nn.Sigmoid()

                def forward(self, x):
                    h = self.sigmoid(self.lin1(x))
                    return self.lin2(h)

            loc_prime_baseline = pyro.module("loc_prime_baseline",
                                             VanillaBaselineNN(2, 5))
        else:
            loc_prime_baseline = None

        def model():
            with pyro.plate("plate", 2):
                loc_latent_prime = pyro.sample(
                    "loc_latent_prime",
                    Normal1(self.loc0, torch.pow(self.lam0, -0.5)))
                loc_latent = pyro.sample(
                    "loc_latent",
                    Normal2(loc_latent_prime, torch.pow(self.lam0, -0.5)))
                with pyro.plate("data", len(self.data)):
                    pyro.sample(
                        "obs",
                        dist.Normal(loc_latent, torch.pow(
                            self.lam, -0.5)).expand_by(self.data.shape[:1]),
                        obs=self.data,
                    )
            return loc_latent

        # note that the exact posterior is not mean field!
        def guide():
            loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.334)
            log_sig_q = pyro.param("log_sig_q",
                                   self.analytic_log_sig_n.expand(2) - 0.29)
            loc_q_prime = pyro.param("loc_q_prime", torch.tensor([-0.34,
                                                                  0.52]))
            kappa_q = pyro.param("kappa_q", torch.tensor([0.74]))
            log_sig_q_prime = pyro.param("log_sig_q_prime",
                                         -0.5 * torch.log(1.2 * self.lam0))
            sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(
                log_sig_q_prime)
            with pyro.plate("plate", 2):
                loc_latent = pyro.sample(
                    "loc_latent",
                    Normal2(loc_q, sig_q),
                    infer=dict(baseline=dict(
                        use_decaying_avg_baseline=use_decaying_avg_baseline)),
                )
                pyro.sample(
                    "loc_latent_prime",
                    Normal1(
                        kappa_q.expand_as(loc_latent) * loc_latent +
                        loc_q_prime,
                        sig_q_prime,
                    ),
                    infer=dict(baseline=dict(
                        nn_baseline=loc_prime_baseline,
                        nn_baseline_input=loc_latent,
                        use_decaying_avg_baseline=use_decaying_avg_baseline,
                    )),
                )
                with pyro.plate("data", len(self.data)):
                    pass

            return loc_latent

        adam = optim.Adam({"lr": 0.0015, "betas": (0.97, 0.999)})
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

        for k in range(n_steps):
            svi.step()

            loc_error = param_mse("loc_q", self.analytic_loc_n)
            log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
            loc_prime_error = param_mse("loc_q_prime", 0.5 * self.loc0)
            kappa_error = param_mse("kappa_q", 0.5 * torch.ones(1))
            log_sig_prime_error = param_mse("log_sig_q_prime",
                                            -0.5 * torch.log(2.0 * self.lam0))

            if k % 500 == 0:
                logger.debug("errors:  %.4f, %.4f" %
                             (loc_error, log_sig_error))
                logger.debug(", %.4f, %.4f" %
                             (loc_prime_error, log_sig_prime_error))
                logger.debug(", %.4f" % kappa_error)

        assert_equal(0.0, loc_error, prec=prec)
        assert_equal(0.0, log_sig_error, prec=prec)
        assert_equal(0.0, loc_prime_error, prec=prec)
        assert_equal(0.0, log_sig_prime_error, prec=prec)
        assert_equal(0.0, kappa_error, prec=prec)
Example #46
0
    def update_noise_svi(self, obs_data, intervened_model=None):
        """
        Use SVI to infer the mu and sigma of the noise distributions for the
        condition outlined in obs_data.
        """
        
        def guide(noise):
            """
            The guide serves as an approximation to the posterior p(z|x):
            it provides a valid joint probability density over all the
            latent random variables in the model.
            
            https://pyro.ai/examples/svi_part_i.html
            """
            # create params with constraints
            mu = {
                'N_X': pyro.param('N_X_mu', 0.5*torch.ones(self.image_dim),constraint = constraints.interval(0., 1.)),
                'N_Z': pyro.param('N_Z_mu', torch.zeros(self.z_dim),constraint = constraints.interval(-3., 3.)),
                'N_Y_1': pyro.param('N_Y_1_mu', 0.5*torch.ones(self.label_dims[1]),constraint = constraints.interval(0., 1.)),
                'N_Y_2': pyro.param('N_Y_2_mu', 0.5*torch.ones(self.label_dims[2]),constraint = constraints.interval(0., 1.)),
                'N_Y_3': pyro.param('N_Y_3_mu', 0.5*torch.ones(self.label_dims[3]),constraint = constraints.interval(0., 1.)),
                'N_Y_4': pyro.param('N_Y_4_mu', 0.5*torch.ones(self.label_dims[4]),constraint = constraints.interval(0., 1.)),
                'N_Y_5': pyro.param('N_Y_5_mu', 0.5*torch.ones(self.label_dims[5]),constraint = constraints.interval(0., 1.))
                }
            sigma = {
                'N_X': pyro.param('N_X_sigma', 0.1*torch.ones(self.image_dim),constraint = constraints.interval(0.0001, 0.5)),
                'N_Z': pyro.param('N_Z_sigma', torch.ones(self.z_dim),constraint = constraints.interval(0.0001, 3.)),
                'N_Y_1': pyro.param('N_Y_1_sigma', 0.1*torch.ones(self.label_dims[1]),constraint = constraints.interval(0.0001, 0.5)),
                'N_Y_2': pyro.param('N_Y_2_sigma', 0.1*torch.ones(self.label_dims[2]),constraint = constraints.interval(0.0001, 0.5)),
                'N_Y_3': pyro.param('N_Y_3_sigma', 0.1*torch.ones(self.label_dims[3]),constraint = constraints.interval(0.0001, 0.5)),
                'N_Y_4': pyro.param('N_Y_4_sigma', 0.1*torch.ones(self.label_dims[4]),constraint = constraints.interval(0.0001, 0.5)),
                'N_Y_5': pyro.param('N_Y_5_sigma', 0.1*torch.ones(self.label_dims[5]),constraint = constraints.interval(0.0001, 0.5))
                }
            for noise_term in noise.keys():
                pyro.sample(noise_term, dist.Normal(mu[noise_term], sigma[noise_term]).to_event(1))
        
        # Condition the model
        if intervened_model is not None:
          obs_model = pyro.condition(intervened_model, obs_data)
        else:
          obs_model = pyro.condition(self.model, obs_data)
          
        pyro.clear_param_store()

        # Once we've specified a guide, we're ready to proceed to inference.
        # This is an optimization problem where each iteration of training
        # takes a step that moves the guide closer to the exact posterior.
        
        # https://arxiv.org/pdf/1601.00670.pdf
        svi = SVI(
            model= obs_model,
            guide= guide,
            optim= SGD({"lr": 1e-5, 'momentum': 0.1}),
            loss=Trace_ELBO(retain_graph=True)
        )
        
        num_steps = 1500
        samples = defaultdict(list)
        for t in range(num_steps):
            loss = svi.step(self.init_noise)
#             if t % 100 == 0:
#                 print("step %d: loss of %.2f" % (t, loss))
            for noise in self.init_noise.keys():
                mu = '{}_mu'.format(noise)
                sigma = '{}_sigma'.format(noise)
                samples[mu].append(pyro.param(mu).detach().numpy())
                samples[sigma].append(pyro.param(sigma).detach().numpy())
        means = {k: torch.tensor(np.array(v).mean(axis=0)) for k, v in samples.items()}
        
        # update the inferred noise
        updated_noise = {
            'N_X'  : dist.Normal(means['N_X_mu'], means['N_X_sigma']),
            'N_Z'  : dist.Normal(means['N_Z_mu'], means['N_Z_sigma']),
            'N_Y_1': dist.Normal(means['N_Y_1_mu'], means['N_Y_1_sigma']),
            'N_Y_2': dist.Normal(means['N_Y_2_mu'], means['N_Y_2_sigma']),
            'N_Y_3': dist.Normal(means['N_Y_3_mu'], means['N_Y_3_sigma']),
            'N_Y_4': dist.Normal(means['N_Y_4_mu'], means['N_Y_4_sigma']),
            'N_Y_5': dist.Normal(means['N_Y_5_mu'], means['N_Y_5_sigma']),
        }
        return updated_noise
Example #47
0
    def _test_plate_in_elbo(self,
                            n_superfluous_top,
                            n_superfluous_bottom,
                            n_steps,
                            lr=0.0012):
        pyro.clear_param_store()
        self.data_tensor = torch.zeros(9, 2)
        for _out in range(self.n_outer):
            for _in in range(self.n_inner):
                self.data_tensor[3 * _out + _in, :] = self.data[_out][_in]
        self.data_as_list = [
            self.data_tensor[0:4, :],
            self.data_tensor[4:7, :],
            self.data_tensor[7:9, :],
        ]

        def model():
            loc_latent = pyro.sample(
                "loc_latent",
                fakes.NonreparameterizedNormal(self.loc0,
                                               torch.pow(self.lam0,
                                                         -0.5)).to_event(1),
            )

            for i in pyro.plate("outer", 3):
                x_i = self.data_as_list[i]
                with pyro.plate("inner_%d" % i, x_i.size(0)):
                    for k in range(n_superfluous_top):
                        z_i_k = pyro.sample(
                            "z_%d_%d" % (i, k),
                            fakes.NonreparameterizedNormal(0, 1).expand_by(
                                [4 - i]),
                        )
                        assert z_i_k.shape == (4 - i, )
                    obs_i = pyro.sample(
                        "obs_%d" % i,
                        dist.Normal(loc_latent, torch.pow(self.lam,
                                                          -0.5)).to_event(1),
                        obs=x_i,
                    )
                    assert obs_i.shape == (4 - i, 2)
                    for k in range(n_superfluous_top,
                                   n_superfluous_top + n_superfluous_bottom):
                        z_i_k = pyro.sample(
                            "z_%d_%d" % (i, k),
                            fakes.NonreparameterizedNormal(0, 1).expand_by(
                                [4 - i]),
                        )
                        assert z_i_k.shape == (4 - i, )

        pt_loc_baseline = torch.nn.Linear(1, 1)
        pt_superfluous_baselines = []
        for k in range(n_superfluous_top + n_superfluous_bottom):
            pt_superfluous_baselines.extend([
                torch.nn.Linear(2, 4),
                torch.nn.Linear(2, 3),
                torch.nn.Linear(2, 2)
            ])

        def guide():
            loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094)
            log_sig_q = pyro.param("log_sig_q",
                                   self.analytic_log_sig_n.expand(2) - 0.07)
            sig_q = torch.exp(log_sig_q)
            trivial_baseline = pyro.module("loc_baseline", pt_loc_baseline)
            baseline_value = trivial_baseline(torch.ones(1)).squeeze()
            loc_latent = pyro.sample(
                "loc_latent",
                fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1),
                infer=dict(baseline=dict(baseline_value=baseline_value)),
            )

            for i in pyro.plate("outer", 3):
                with pyro.plate("inner_%d" % i, 4 - i):
                    for k in range(n_superfluous_top + n_superfluous_bottom):
                        z_baseline = pyro.module(
                            "z_baseline_%d_%d" % (i, k),
                            pt_superfluous_baselines[3 * k + i],
                        )
                        baseline_value = z_baseline(loc_latent.detach())
                        mean_i = pyro.param("mean_%d_%d" % (i, k),
                                            0.5 * torch.ones(4 - i))
                        z_i_k = pyro.sample(
                            "z_%d_%d" % (i, k),
                            fakes.NonreparameterizedNormal(mean_i, 1),
                            infer=dict(baseline=dict(
                                baseline_value=baseline_value)),
                        )
                        assert z_i_k.shape == (4 - i, )

        def per_param_callable(param_name):
            if "baseline" in param_name:
                return {"lr": 0.010, "betas": (0.95, 0.999)}
            else:
                return {"lr": lr, "betas": (0.95, 0.999)}

        adam = optim.Adam(per_param_callable)
        svi = SVI(model, guide, adam, loss=TraceGraph_ELBO())

        for step in range(n_steps):
            svi.step()

            loc_error = param_abs_error("loc_q", self.analytic_loc_n)
            log_sig_error = param_abs_error("log_sig_q",
                                            self.analytic_log_sig_n)

            if n_superfluous_top > 0 or n_superfluous_bottom > 0:
                superfluous_errors = []
                for k in range(n_superfluous_top + n_superfluous_bottom):
                    mean_0_error = torch.sum(
                        torch.pow(pyro.param("mean_0_%d" % k), 2.0))
                    mean_1_error = torch.sum(
                        torch.pow(pyro.param("mean_1_%d" % k), 2.0))
                    mean_2_error = torch.sum(
                        torch.pow(pyro.param("mean_2_%d" % k), 2.0))
                    superfluous_error = torch.max(
                        torch.max(mean_0_error, mean_1_error), mean_2_error)
                    superfluous_errors.append(
                        superfluous_error.detach().cpu().numpy())

            if step % 500 == 0:
                logger.debug("loc error, log(scale) error:  %.4f, %.4f" %
                             (loc_error, log_sig_error))
                if n_superfluous_top > 0 or n_superfluous_bottom > 0:
                    logger.debug("superfluous error: %.4f" %
                                 np.max(superfluous_errors))

        assert_equal(0.0, loc_error, prec=0.04)
        assert_equal(0.0, log_sig_error, prec=0.05)
        if n_superfluous_top > 0 or n_superfluous_bottom > 0:
            assert_equal(0.0, np.max(superfluous_errors), prec=0.04)
Example #48
0
File: train.py Project: bahammel/bayes
def test_model(model, guide, loss, x_data, y_data):
    pyro.clear_param_store()
    loss.loss(model, guide, x_data, y_data)
"""
# Run options
LEARNING_RATE = 1.0e-3
USE_CUDA = False

# Run only for a single iteration for testing
NUM_EPOCHS = 1 if smoke_test else 100
TEST_FREQUENCY = 5

# Get data
train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)
pyro.clear_param_store()

# Initialize instance of the VAE class
vae = VAE()

# Setup Adam optimizer (an algorithm for first-order gradient-based optimization)
optimizer = Adam({"lr": 1.0e-3})

# SVI: stochastic variational inference - a scalable algorithm for approximating posterior distributions.
# Trace_ELBO: top-level interface for stochastic variational inference via optimization of the evidence lower bound.
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())


#set_trace()

train_elbo = []
test_elbo = []

# Training loop
for epoch in range(NUM_EPOCHS):
    print("1")
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    #print("2")
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    if epoch % TEST_FREQUENCY == 0:
        # report test diagnostics
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
"""
Example #50
0
File: conftest.py Project: ludvb/xfuse
def pytest_runtest_setup(item):
    # pylint: disable=missing-function-docstring
    pyro.clear_param_store()
    reset_state()
    if item.get_closest_marker("fix_rng") is not None:
        torch.manual_seed(0)
Example #51
0
def main(args):
    # clear param store
    pyro.clear_param_store()

    ### SETUP
    train_loader, test_loader = get_data()

    # setup the VAE
    vae = VAE(use_cuda=args.cuda)

    # setup the optimizer
    adam_args = {"lr": args.learning_rate}
    optimizer = Adam(adam_args)

    # setup the inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    inputSize = 0

    # setup visdom for visualization
    if args.visdom_flag:
        vis = visdom.Visdom()

    train_elbo = []
    test_elbo = []

    for epoch in range(args.num_epochs):
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch x returned
        # by the data loader

        for step, batch in enumerate(train_loader):
            x, adj = 0, 0
            # if on GPU put mini-batch into CUDA memory
            if args.cuda:
                x = batch['x'].cuda()
                adj = batch['edge_index'].cuda()
            else:

                x = batch['x']
                adj = batch['edge_index']
            print("x_shape", x.shape)
            print("adj_shape", adj.shape)

            inputSize = x.shape[0] * x.shape[1]
            epoch_loss += svi.step(x, adj)

        # report training diagnostics
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" %
              (epoch, total_epoch_loss_train))

        if True:
            # if epoch % args.test_frequency == 0:
            # initialize loss accumulator
            test_loss = 0.
            # compute the loss over the entire test set
            for step, batch in enumerate(test_loader):
                x, adj = 0, 0
                # if on GPU put mini-batch into CUDA memory
                if args.cuda:
                    x = batch['x'].cuda()
                    adj = batch['edge_index'].cuda()
                else:
                    x = batch['x']
                    adj = batch['edge_index']

                # compute ELBO estimate and accumulate loss
                # print('before evaluating test loss')
                test_loss += svi.evaluate_loss(x, adj)
                # print('after evaluating test loss')

                # pick three random test images from the first mini-batch and
                # visualize how well we're reconstructing them
                # if i == 0:
                #     if args.visdom_flag:
                #         plot_vae_samples(vae, vis)
                #         reco_indices = np.random.randint(0, x.shape[0], 3)
                #         for index in reco_indices:
                #             test_img = x[index, :]
                #             reco_img = vae.reconstruct_img(test_img)
                #             vis.image(test_img.reshape(28, 28).detach().cpu().numpy(),
                #                       opts={'caption': 'test image'})
                #             vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(),
                #                       opts={'caption': 'reconstructed image'})

                if args.visdom_flag:
                    plot_vae_samples(vae, vis)
                    reco_indices = np.random.randint(0, x.shape[0], 3)
                    for index in reco_indices:
                        test_img = x[index, :]
                        reco_img = vae.reconstruct_graph(test_img)
                        vis.image(test_img.reshape(28,
                                                   28).detach().cpu().numpy(),
                                  opts={'caption': 'test image'})
                        vis.image(reco_img.reshape(28,
                                                   28).detach().cpu().numpy(),
                                  opts={'caption': 'reconstructed image'})

            # report test diagnostics
            normalizer_test = len(test_loader.dataset)
            total_epoch_loss_test = test_loss / normalizer_test
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d]  average test loss: %.4f" %
                  (epoch, total_epoch_loss_test))

        # if epoch == args.tsne_iter:
        #     mnist_test_tsne(vae=vae, test_loader=test_loader)
        #     plot_llk(np.array(train_elbo), np.array(test_elbo))

    if args.save:
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': vae.state_dict(),
                'optimizer_state_dict': optimizer.get_state(),
                'train_loss': total_epoch_loss_train,
                'test_loss': total_epoch_loss_test
            }, 'vae_' + args.name + str(args.time) + '.pt')

    return vae


def rejection_sample_feasible_tree(num_attempts=999):
    ''' Repeatedly samples trees from the grammar until
    one satisfies some hand-coded constraints.

    This will be simplified when constraint specification
    and sampling machinery is generalized. For now, this is
    hard-coded to work for the kitchen example. '''

    for attempt_k in range(num_attempts):
        start = time.time()
        pyro.clear_param_store()
        scene_tree = ParseTree.generate_from_root_type(root_node_type=Kitchen)
        end = time.time()

        print("Generated tree in %f seconds." % (end - start))

        # Enforce that there is exactly one cabinet
        num_cabinets = len([node for node in scene_tree.nodes if isinstance(node, Cabinet)])
        if num_cabinets != 1:
            continue
        
        # Enforce that there are at least 5 objects on the tables
        tables = scene_tree.find_nodes_by_type(Table)
        table_children = sum([scene_tree.get_recursive_children_of_node(node) for node in tables], [])
        num_objects_on_tables = len([node for node in table_children if isinstance(node, KitchenObject)])
        print("Num objs on table: ", num_objects_on_tables)
        if num_objects_on_tables < 5:
            continue

        # Enforce that there are at least 2 objects in cabinets
        #cabinets = scene_tree.find_nodes_by_type(Cabinet)
        #table_children = sum([scene_tree.get_recursive_children_of_node(node) for node in cabinets], [])
        #num_objects_in_cabinets = len([node for node in table_children if isinstance(node, KitchenObject)])
        #print("Num objs in cabinets: ", num_objects_in_cabinets)
        #if num_objects_in_cabinets < 2:
        #    continue

        
        # Do Collision checking on the clearance geometry, and reject
        # scenes where the collision geometry is in collision.
        # (This could be done at subtree level, and eventually I'll do that --
        # but for this scene it doesn't matter b/c clearance geometry is all
        # furniture level anyway.
        # TODO: What if I did rejection sampling for nonpenetration at the
        # container level? Is that legit as a sampling strategy?)
        builder_clearance, mbp_clearance, sg_clearance = \
            compile_scene_tree_clearance_geometry_to_mbp_and_sg(scene_tree)
        mbp_clearance.Finalize()
        diagram_clearance = builder_clearance.Build()
        diagram_context = diagram_clearance.CreateDefaultContext()
        mbp_context = diagram_clearance.GetMutableSubsystemContext(mbp_clearance, diagram_context)
        constraint = build_clearance_nonpenetration_constraint(
            mbp_clearance, mbp_context, -0.01)
        constraint.Eval(mbp_clearance.GetPositions(mbp_context))

        q0 = mbp_clearance.GetPositions(mbp_context)
        print("CONSTRAINT EVAL: %f <= %f <= %f" % (
              constraint.lower_bound(),
              constraint.Eval(mbp_clearance.GetPositions(mbp_context)),
              constraint.upper_bound()))
        print(len(get_collisions(mbp_clearance, mbp_context)), " bodies in collision")

        # We can draw clearance geometry for debugging.
        #draw_clearance_geometry_meshcat(scene_tree, alpha=0.3)

        # If we failed the initial clearance check, resample.
        if not constraint.CheckSatisfied(q0):
            continue
            
        # Good solution!
        return scene_tree, True

    # Bad solution :(
    return scene_tree, False
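
A minimal usage sketch for the sampler above (the Kitchen grammar node types and Drake helpers are defined in code elided here):

scene_tree, satisfied = rejection_sample_feasible_tree(num_attempts=100)
if not satisfied:
    print("No feasible kitchen scene found within the attempt budget.")
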
Example #53
0
def train(num_iterations, svi):
    pyro.clear_param_store()
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)


def train(args, DATA_PATH):
    # clear param store
    pyro.clear_param_store()
    #pyro.enable_validation(True)

    # train_loader, test_loader
    transform = {}
    transform["train"] = transforms.Compose([
        transforms.Resize((400, 400)),
        transforms.ToTensor(),
    ])
    transform["test"] = transforms.Compose(
        [transforms.Resize((400, 400)),
         transforms.ToTensor()])

    train_loader, test_loader = setup_data_loaders(
        dataset=GameCharacterFullData,
        root_path=DATA_PATH,
        batch_size=32,
        transforms=transform)

    # setup the VAE
    vae = VAE(use_cuda=args.cuda, num_labels=17)

    # setup the exponential learning rate scheduler
    optimizer = torch.optim.Adam
    scheduler = pyro.optim.ExponentialLR({
        'optimizer': optimizer,
        'optim_args': {
            'lr': args.learning_rate
        },
        'gamma': 0.1
    })

    # setup the inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, scheduler, loss=elbo)

    # setup visdom for visualization
    if args.visdom_flag:
        vis = visdom.Visdom(port='8097')

    train_elbo = []
    test_elbo = []
    # training loop
    for epoch in range(args.num_epochs):
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch x returned
        # by the data loader
        for x, y, actor, reactor, actor_type, reactor_type, action, reaction in train_loader:
            # if on GPU put mini-batch into CUDA memory
            if args.cuda:
                x = x.cuda()
                y = y.cuda()
                actor = actor.cuda()
                reactor = reactor.cuda()
                actor_type = actor_type.cuda()
                reactor_type = reactor_type.cuda()
                action = action.cuda()
                reaction = reaction.cuda()
            # do ELBO gradient and accumulate loss
            epoch_loss += svi.step(x, y, actor, reactor, actor_type,
                                   reactor_type, action, reaction)

        # report training diagnostics
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" %
              (epoch, total_epoch_loss_train))

        if epoch % args.test_frequency == 0:
            # initialize loss accumulator
            test_loss = 0.
            # compute the loss over the entire test set
            for i, (x, y, actor, reactor, actor_type, reactor_type, action,
                    reaction) in enumerate(test_loader):
                # if on GPU put mini-batch into CUDA memory
                if args.cuda:
                    x = x.cuda()
                    y = y.cuda()
                    actor = actor.cuda()
                    reactor = reactor.cuda()
                    actor_type = actor_type.cuda()
                    reactor_type = reactor_type.cuda()
                    action = action.cuda()
                    reaction = reaction.cuda()
                # compute ELBO estimate and accumulate loss
                test_loss += svi.evaluate_loss(x, y, actor, reactor,
                                               actor_type, reactor_type,
                                               action, reaction)
                # pick three random test images from the first mini-batch and
                # visualize how well we're reconstructing them
                if i == 0:
                    if args.visdom_flag:
                        plot_vae_samples(vae, vis)
                        reco_indices = np.random.randint(0, x.shape[0], 3)
                        for index in reco_indices:
                            test_img = x[index, :]
                            reco_img = vae.reconstruct_img(test_img)
                            vis.image(test_img.reshape(
                                400, 400).detach().cpu().numpy(),
                                      opts={'caption': 'test image'})
                            vis.image(reco_img.reshape(
                                400, 400).detach().cpu().numpy(),
                                      opts={'caption': 'reconstructed image'})
            # report test diagnostics
            normalizer_test = len(test_loader.dataset)
            total_epoch_loss_test = test_loss / normalizer_test
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d]  average test loss: %.4f" %
                  (epoch, total_epoch_loss_test))

    return vae, optimizer
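
One detail about the scheduler above: pyro.optim.ExponentialLR only decays the learning rate when it is explicitly stepped. A minimal sketch of the usual pattern, assuming a per-epoch decay is the intent here (not part of the original code):

# ... inside the training loop, after the minibatch pass for an epoch:
scheduler.step()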
Example #55
0
    def setUp(self):
        pyro.clear_param_store()

        def loc1_prior(tensor, *args, **kwargs):
            flat_tensor = tensor.view(-1)
            m = torch.zeros(flat_tensor.size(0))
            s = torch.ones(flat_tensor.size(0))
            return Normal(m, s).sample().view(tensor.size())

        def scale1_prior(tensor, *args, **kwargs):
            flat_tensor = tensor.view(-1)
            m = torch.zeros(flat_tensor.size(0))
            s = torch.ones(flat_tensor.size(0))
            return Normal(m, s).sample().view(tensor.size()).exp()

        def loc2_prior(tensor, *args, **kwargs):
            flat_tensor = tensor.view(-1)
            m = torch.zeros(flat_tensor.size(0))
            return Bernoulli(m).sample().view(tensor.size())

        def scale2_prior(tensor, *args, **kwargs):
            return scale1_prior(tensor)

        def bias_prior(tensor, *args, **kwargs):
            return loc2_prior(tensor)

        def weight_prior(tensor, *args, **kwargs):
            return scale1_prior(tensor)

        def stoch_fn(tensor, *args, **kwargs):
            loc = torch.zeros(tensor.size())
            scale = torch.ones(tensor.size())
            return pyro.sample("sample", Normal(loc, scale))

        def guide():
            loc1 = pyro.param("loc1", torch.randn(2, requires_grad=True))
            scale1 = pyro.param("scale1", torch.ones(2, requires_grad=True))
            pyro.sample("latent1", Normal(loc1, scale1))

            loc2 = pyro.param("loc2", torch.randn(2, requires_grad=True))
            scale2 = pyro.param("scale2", torch.ones(2, requires_grad=True))
            latent2 = pyro.sample("latent2", Normal(loc2, scale2))
            return latent2

        def dup_param_guide():
            a = pyro.param("loc")
            b = pyro.param("loc")
            assert a == b

        self.model = Model()
        self.guide = guide
        self.dup_param_guide = dup_param_guide
        self.prior = scale1_prior
        self.prior_dict = {
            "loc1": loc1_prior,
            "scale1": scale1_prior,
            "loc2": loc2_prior,
            "scale2": scale2_prior
        }
        self.partial_dict = {"loc1": loc1_prior, "scale1": scale1_prior}
        self.nn_prior = {"fc.bias": bias_prior, "fc.weight": weight_prior}
        self.fn = stoch_fn
        self.data = torch.randn(2, 2)
Example #56
0
def svi(data,
        assets,
        iter,
        num_samples,
        seed,
        autoguide=None,
        optim=None,
        subsample_size=None):
    assert type(data) == dict
    assert type(assets) == Assets
    assert type(iter) == int
    assert type(num_samples) == int
    assert seed is None or type(seed) == int
    assert autoguide is None or callable(autoguide)

    N = next(data.values().__iter__()).shape[0]
    assert all(arr.shape[0] == N for arr in data.values())
    assert (subsample_size is None
            or type(subsample_size) == int and 0 < subsample_size < N)

    # TODO: Fix that this interface doesn't work for
    # `AutoLaplaceApproximation`, which requires different functions
    # to be used for optimisation / collecting samples.
    autoguide = AutoMultivariateNormal if autoguide is None else autoguide
    optim = Adam({'lr': 1e-3}) if optim is None else optim

    guide = autoguide(assets.fn)
    svi = SVI(assets.fn, guide, optim, loss=Trace_ELBO())
    pyro.clear_param_store()

    t0 = time.time()
    max_iter_str_width = len(str(iter))
    max_out_len = 0

    with seed_ctx_mgr(seed):

        for i in range(iter):
            if subsample_size is None:
                dfN = None
                subsample = None
                data_for_step = data
            else:
                dfN = N
                subsample = torch.randint(0, N, (subsample_size, )).long()
                data_for_step = {
                    k: get_mini_batch(arr, subsample)
                    for k, arr in data.items()
                }
            loss = svi.step(dfN=dfN, subsample=subsample, **data_for_step)
            t1 = time.time()
            if t1 - t0 > 0.5 or (i + 1) == iter:
                iter_str = str(i + 1).rjust(max_iter_str_width)
                out = 'iter: {} | loss: {:.3f}'.format(iter_str, loss)
                max_out_len = max(max_out_len, len(out))
                # Sending the ANSI code to clear the line doesn't seem to
                # work in notebooks, so instead we pad the output with
                # enough spaces to ensure we overwrite all previous input.
                print('\r{}'.format(out.ljust(max_out_len)),
                      end='',
                      file=stderr)
                t0 = t1
        print(file=stderr)

        # We run the guide to generate traces from the (approx.)
        # posterior. We also run the model against those traces in order
        # to compute transformed parameters, such as `b`, etc.
        def get_model_trace():
            guide_tr = poutine.trace(guide).get_trace()
            model_tr = poutine.trace(poutine.replay(
                assets.fn, trace=guide_tr)).get_trace(mode='prior_only',
                                                      **data)
            return model_tr

        # Represent the posterior as a bunch of samples, ignoring the
        # possibility that we might plausibly be able to figure out e.g.
        # posterior marginals from the variational parameters.
        samples = [get_model_trace() for _ in range(num_samples)]

    # Unlike the NUTS case, we don't eagerly compute `mu` (for the
    # data set used for inference) when building `Samples#raw_samples`.
    # (This is because it's possible that N is very large since we
    # support subsampling.) Therefore `loc` always computes `mu` from
    # the data and the samples here.
    def loc(d):
        return location(assets.fn, samples, d)

    return Samples(samples, partial(get_param, samples), loc)
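
A hypothetical invocation of the helper above, assuming data is a dict of equal-length tensors and assets wraps the model function (both constructed by code elided here):

posterior = svi(data, assets, iter=2000, num_samples=500, seed=0,
                subsample_size=64)  # subsample_size=None would use the full data set each step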
Example #57
0
def test():
    parser = argparse.ArgumentParser(description='Train VAE.')
    parser.add_argument('-c', '--config', help='Config file.')
    args = parser.parse_args()
    print(args)
    c = json.load(open(args.config))
    print(c)

    # clear param store
    pyro.clear_param_store()

    # batch_size = 64
    # root_dir = r'D:\projects\trading\mlbootcamp\tickers'
    # series_length = 60
    lookback = 50  # 160
    input_dim = 1

    test_start_date = datetime.strptime(
        c['test_start_date'], '%Y/%m/%d') if c['test_start_date'] else None
    test_end_date = datetime.strptime(
        c['test_end_date'], '%Y/%m/%d') if c['test_end_date'] else None
    min_sequence_length_test = 2 * (c['series_length'] + lookback)
    max_n_files = None

    out_path = Path(c['out_dir'])
    out_path.mkdir(exist_ok=True)

    # load_path = 'out_saved/checkpoint_0035.pt'

    dataset_test = create_ticker_dataset(
        c['in_dir'],
        c['series_length'],
        lookback,
        min_sequence_length_test,
        start_date=test_start_date,
        end_date=test_end_date,
        fixed_start_date=True,
        normalised_returns=c['normalised_returns'],
        max_n_files=max_n_files)
    test_loader = DataLoader(dataset_test,
                             batch_size=c['batch_size'],
                             shuffle=False,
                             num_workers=0,
                             drop_last=True)

    # N_train_data = len(dataset_train)
    N_test_data = len(dataset_test)
    # N_mini_batches = N_train_data // c['batch_size']
    # N_train_time_slices = c['batch_size'] * N_mini_batches

    print(f'N_test_data: {N_test_data}')

    # setup the VAE
    vae = VAE(c['series_length'], z_dim=c['z_dim'], use_cuda=c['cuda'])

    # setup the optimizer
    # adam_args = {"lr": args.learning_rate}
    # optimizer = Adam(adam_args)

    # setup the inference algorithm
    # elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    # svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    if c['checkpoint_load']:
        checkpoint = torch.load(c['checkpoint_load'])
        vae.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if 1:
        find_similar(vae, test_loader, c['cuda'])

    # Visualise first batch.
    batch = next(iter(test_loader))
    x = batch['series']
    if c['cuda']:
        x = x.cuda()
    x = x.float()
    x_reconst = vae.reconstruct_img(x)
    x = x.cpu().numpy()
    x_reconst = x_reconst.cpu().detach().numpy()

    n = min(5, x.shape[0])
    fig, axes = plt.subplots(n, 1, squeeze=False)
    for s in range(n):
        ax = axes[s, 0]
        ax.plot(x[s])
        ax.plot(x_reconst[s])
    fig.savefig(out_path / f'test_batch.png')
Example #58
0
    return site_stats


# Prepare training data
df = rugged_data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)]
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
train = torch.tensor(df.values, dtype=torch.float)

svi = SVI(model,
          guide,
          optim.Adam({"lr": .005}),
          loss=Trace_ELBO(),
          num_samples=1000)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]
pyro.clear_param_store()
num_iters = 8000 if not smoke_test else 2
for i in range(num_iters):
    elbo = svi.step(is_cont_africa, ruggedness, log_gdp)
    if i % 500 == 0:
        logging.info("Elbo loss: {}".format(elbo))

posterior = svi.run(log_gdp, is_cont_africa, ruggedness)

sites = ["a", "bA", "bR", "bAR", "sigma"]

for site, values in summary(posterior, sites).items():
    print("Site: {}".format(site))
    print(values, "\n")

Example #59
0
File: hmm.py Project: pyro-ppl/pyro
def main(args):
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    logging.info("Loading data")
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info("-" * 40)
    model = models[args.model]
    logging.info("Training {} on {} sequences".format(
        model.__name__, len(data["train"]["sequences"])))
    sequences = data["train"]["sequences"]
    lengths = data["train"]["sequence_lengths"]

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(
        poutine.block(model,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim),
                           guide_trace)).get_trace(sequences,
                                                   lengths,
                                                   args=args,
                                                   batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({"lr": args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError(
                "jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {
                "num_samples": args.tmc_num_samples,
                "expand": False
            } if msg["infer"].get("enumerate", None) == "parallel" else {},
        )  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(
            max_plate_nesting=1 if model is model_0 else 2,
            strict_enumeration_warning=(model is not model_7),
            jit_options={"time_compilation": args.time_compilation},
        )
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info("Step\tLoss")
    for step in range(args.num_steps):
        loss = svi.step(sequences,
                        lengths,
                        args=args,
                        batch_size=args.batch_size)
        logging.info("{: >5d}\t{}".format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        logging.debug("time to compile: {} s.".format(
            elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model,
                           guide,
                           sequences,
                           lengths,
                           args,
                           include_prior=False)
    logging.info("training loss = {}".format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info("-" * 40)
    logging.info("Evaluating on {} test sequences".format(
        len(data["test"]["sequences"])))
    sequences = data["test"]["sequences"][..., present_notes]
    lengths = data["test"]["sequence_lengths"]
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model,
                          guide,
                          sequences,
                          lengths,
                          args=args,
                          include_prior=False)
    logging.info("test loss = {}".format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(
        value.reshape(-1).size(0) for value in pyro.get_param_store().values())
    logging.info("{} capacity = {} parameters".format(model.__name__,
                                                      capacity))
Example #60
0
File: inference.py Project: sg5g10/VBODE
def run_inference(data, gen_model, ode_model, method, iterations = 10000, num_particles = 1, \
     num_samples = 1000, warmup_steps = 500, init_scale = 0.1, \
         seed = 12, lr = 0.5, return_sites = ("_RETURN",)):

    torch_data = torch.tensor(data, dtype=torch.float)
    if isinstance(ode_model,ForwardSensManualJacobians) or \
        isinstance(ode_model,ForwardSensTorchJacobians):
        ode_op = ForwardSensOp
    elif isinstance(ode_model,AdjointSensManualJacobians) or \
        isinstance(ode_model,AdjointSensTorchJacobians):
        ode_op = AdjointSensOp
    else:
        raise ValueError(
            'Unknown sensitivity solver: Use "Forward" or "Adjoint"')
    model = gen_model(ode_op, ode_model)
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    if method == 'VI':

        guide = AutoMultivariateNormal(model, init_scale=init_scale)
        optim = AdagradRMSProp({"eta": lr})
        if num_particles == 1:
            svi = SVI(model, guide, optim, loss=Trace_ELBO())
        else:
            svi = SVI(model,
                      guide,
                      optim,
                      loss=Trace_ELBO(num_particles=num_particles,
                                      vectorize_particles=True))
        loss_trace = []
        t0 = timer.time()
        for j in range(iterations):
            loss = svi.step(torch_data)
            loss_trace.append(loss)

            if j % 500 == 0:
                print("[iteration %04d] loss: %.4f" %
                      (j + 1, np.mean(loss_trace[max(0, j - 1000):j + 1])))
        t1 = timer.time()
        print('VI time: ', t1 - t0)
        predictive = Predictive(
            model,
            guide=guide,
            num_samples=num_samples,
            return_sites=return_sites)  #"ode_params", "scale",
        vb_samples = predictive(torch_data)
        return vb_samples

    elif method == 'NUTS':

        nuts_kernel = NUTS(model,
                           adapt_step_size=True,
                           init_strategy=init_to_median)

        mcmc = MCMC(nuts_kernel,
                    num_samples=iterations,
                    warmup_steps=warmup_steps,
                    num_chains=2)
        t0 = timer.time()
        mcmc.run(torch_data)
        t1 = timer.time()
        print('NUTS time: ', t1 - t0)
        hmc_samples = {
            k: v.detach().cpu().numpy()
            for k, v in mcmc.get_samples().items()
        }
        return hmc_samples
    else:
        raise ValueError('Unknown method: Use "NUTS" or "VI"')