コード例 #1
0
ファイル: test_tracegraph_elbo.py プロジェクト: pyro-ppl/pyro
        def guide():
            """Guide in which every non-reparameterized site carries a neural baseline.

            ``loc_latent`` takes its baseline from ``pt_loc_baseline`` evaluated on
            a constant ``torch.ones(1)`` input. Each superfluous site ``z_{i}_{k}``
            takes its baseline from ``pt_superfluous_baselines[3 * k + i]``
            evaluated on the detached ``loc_latent`` sample.
            """
            # Variational parameters, initialised at fixed offsets from the
            # analytic values stored on the test case.
            q_loc = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094)
            q_log_scale = pyro.param(
                "log_sig_q", self.analytic_log_sig_n.expand(2) - 0.07
            )
            q_scale = torch.exp(q_log_scale)

            loc_bl_net = pyro.module("loc_baseline", pt_loc_baseline)
            loc_bl = loc_bl_net(torch.ones(1)).squeeze()
            q_loc_dist = fakes.NonreparameterizedNormal(q_loc, q_scale).to_event(1)
            latent = pyro.sample(
                "loc_latent",
                q_loc_dist,
                infer={"baseline": {"baseline_value": loc_bl}},
            )

            n_superfluous = n_superfluous_top + n_superfluous_bottom
            for outer in pyro.plate("outer", 3):
                inner_size = 4 - outer
                with pyro.plate("inner_%d" % outer, inner_size):
                    for sup in range(n_superfluous):
                        site_ids = (outer, sup)
                        bl_net = pyro.module(
                            "z_baseline_%d_%d" % site_ids,
                            pt_superfluous_baselines[3 * sup + outer],
                        )
                        # Detached so gradients flowing through this baseline
                        # input do not reach the loc_latent guide parameters.
                        bl_value = bl_net(latent.detach())
                        z_mean = pyro.param(
                            "mean_%d_%d" % site_ids, 0.5 * torch.ones(inner_size)
                        )
                        z_draw = pyro.sample(
                            "z_%d_%d" % site_ids,
                            fakes.NonreparameterizedNormal(z_mean, 1),
                            infer={"baseline": {"baseline_value": bl_value}},
                        )
                        assert z_draw.shape == (inner_size,)
コード例 #2
0
ファイル: test_tracegraph_elbo.py プロジェクト: pyro-ppl/pyro
        def model():
            """Model with unused latents sampled before and after each observation.

            For each outer index ``i`` (a sequential plate of size 3), the inner
            plate first samples ``n_superfluous_top`` N(0, 1) sites that are never
            used, then scores ``self.data_as_list[i]``, and finally samples
            ``n_superfluous_bottom`` more unused sites.
            """
            prior_scale = torch.pow(self.lam0, -0.5)
            loc_latent = pyro.sample(
                "loc_latent",
                fakes.NonreparameterizedNormal(self.loc0, prior_scale).to_event(1),
            )

            def sample_superfluous(i, ks):
                # Draw one unused z_{i}_{k} site per k in ``ks``, shape (4 - i,).
                for k in ks:
                    z = pyro.sample(
                        "z_%d_%d" % (i, k),
                        fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]),
                    )
                    assert z.shape == (4 - i, )

            n_superfluous = n_superfluous_top + n_superfluous_bottom
            for i in pyro.plate("outer", 3):
                x_i = self.data_as_list[i]
                with pyro.plate("inner_%d" % i, x_i.size(0)):
                    sample_superfluous(i, range(n_superfluous_top))
                    obs_scale = torch.pow(self.lam, -0.5)
                    likelihood = dist.Normal(loc_latent, obs_scale).to_event(1)
                    obs_i = pyro.sample("obs_%d" % i, likelihood, obs=x_i)
                    assert obs_i.shape == (4 - i, 2)
                    sample_superfluous(i, range(n_superfluous_top, n_superfluous))
コード例 #3
0
ファイル: test_tracegraph_elbo.py プロジェクト: zyxue/pyro
 def model():
     """Model with a global latent location and doubly nested sequential plates.

     Each ``self.data[i][j]`` is scored under Normal(loc_latent, lam ** -0.5)
     with one event dimension.
     """
     prior = fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5)).to_event(1)
     loc_latent = pyro.sample("loc_latent", prior)
     for outer in pyro.plate("outer", self.n_outer):
         for inner in pyro.plate("inner_%d" % outer, self.n_inner):
             obs_scale = torch.pow(self.lam, -0.5)
             likelihood = dist.Normal(loc_latent, obs_scale).to_event(1)
             pyro.sample("obs_%d_%d" % (outer, inner), likelihood,
                         obs=self.data[outer][inner])
コード例 #4
0
ファイル: test_tracegraph_elbo.py プロジェクト: zyxue/pyro
        def guide():
            """Guide for ``loc_latent`` that relies on a decaying-average baseline.

            Only ``loc_latent`` is sampled here. The nested sequential plates are
            entered but contain no guide sites.
            """
            # Initial values sit at fixed offsets from the analytic values on the test case.
            init_loc = self.analytic_loc_n.expand(2) + 0.234
            init_log_scale = self.analytic_log_sig_n.expand(2) - 0.27
            loc_q = pyro.param("loc_q", init_loc)
            log_sig_q = pyro.param("log_sig_q", init_log_scale)
            scale_q = log_sig_q.exp()
            q_dist = fakes.NonreparameterizedNormal(loc_q, scale_q).to_event(1)
            pyro.sample("loc_latent", q_dist,
                        infer={"baseline": {"use_decaying_avg_baseline": True}})

            # Presumably this mirrors the model's plate structure; nothing is sampled inside.
            for outer in pyro.plate("outer", self.n_outer):
                for _ in pyro.plate("inner_%d" % outer, self.n_inner):
                    pass
コード例 #5
0
        def guide():
            """Guide for ``loc_latent`` that relies on a decaying-average baseline.

            Only ``loc_latent`` is sampled. The nested sequential plates are
            entered but contain no guide sites.
            """
            # Initial values are passed to pyro.param directly, as the other
            # guides here do. Wrapping an existing tensor in
            # torch.tensor(..., requires_grad=True) copy-constructs it and makes
            # PyTorch emit a UserWarning.
            loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.234)
            log_sig_q = pyro.param("log_sig_q",
                                   self.analytic_log_sig_n.expand(2) - 0.27)
            sig_q = torch.exp(log_sig_q)
            # .to_event(1) replaces the deprecated .independent(1): the last
            # dimension is treated as a single event.
            pyro.sample(
                "loc_latent",
                fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1),
                infer=dict(baseline=dict(use_decaying_avg_baseline=True)))

            # pyro.plate supersedes the removed pyro.irange for sequential plates.
            for i in pyro.plate("outer", self.n_outer):
                for j in pyro.plate("inner_%d" % i, self.n_inner):
                    pass