def guide():
    """Guide with a learned baseline attached to every non-reparameterized site.

    Registers the ``loc_q`` / ``log_sig_q`` params for the global
    ``loc_latent`` site, whose baseline is the ``loc_baseline`` module
    evaluated on a constant ``torch.ones(1)`` input.  Then, for each of the
    3 sequential ``outer`` iterations ``i``, it enters an ``inner_i`` plate
    of size ``4 - i``.  Inside that plate it draws
    ``n_superfluous_top + n_superfluous_bottom`` sites ``z_i_k``.  Each site
    has its own ``mean_i_k`` param and its own ``z_baseline_i_k`` module,
    fed a detached ``loc_latent``.
    """
    q_loc = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.094)
    q_log_scale = pyro.param("log_sig_q", self.analytic_log_sig_n.expand(2) - 0.07)
    q_scale = torch.exp(q_log_scale)

    # Baseline for the global site: the module's output on a constant input.
    loc_net = pyro.module("loc_baseline", pt_loc_baseline)
    loc_baseline = loc_net(torch.ones(1)).squeeze()
    loc_dist = fakes.NonreparameterizedNormal(q_loc, q_scale).to_event(1)
    loc_latent = pyro.sample(
        "loc_latent",
        loc_dist,
        infer={"baseline": {"baseline_value": loc_baseline}},
    )

    n_superfluous = n_superfluous_top + n_superfluous_bottom
    for i in pyro.plate("outer", 3):
        inner_size = 4 - i
        with pyro.plate("inner_%d" % i, inner_size):
            for k in range(n_superfluous):
                # 3 * k + i is unique per (i, k) because the outer plate has 3 iterations.
                z_net = pyro.module(
                    "z_baseline_%d_%d" % (i, k),
                    pt_superfluous_baselines[3 * k + i],
                )
                # detach(): the baseline net's input carries no gradient back into loc_latent.
                z_baseline = z_net(loc_latent.detach())
                z_mean = pyro.param("mean_%d_%d" % (i, k), 0.5 * torch.ones(inner_size))
                z_i_k = pyro.sample(
                    "z_%d_%d" % (i, k),
                    fakes.NonreparameterizedNormal(z_mean, 1),
                    infer={"baseline": {"baseline_value": z_baseline}},
                )
                assert z_i_k.shape == (inner_size,)
def model():
    """Model that wraps each observation in superfluous standard-normal sites.

    Samples a global ``loc_latent`` from its prior.  Then, for each of 3
    sequential ``outer`` iterations ``i``, it enters an ``inner_i`` plate
    sized by ``self.data_as_list[i]``.  Inside the plate it draws
    ``n_superfluous_top`` sites ``z_i_k`` and then observes ``obs_i``.
    Finally it draws ``n_superfluous_bottom`` more ``z_i_k`` sites, with
    ``k`` continuing the count.
    """
    loc_latent = pyro.sample(
        "loc_latent",
        fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5)).to_event(1),
    )

    def _draw_superfluous(i, k):
        # Standard-normal site independent of the data, batched to 4 - i.
        z = pyro.sample(
            "z_%d_%d" % (i, k),
            fakes.NonreparameterizedNormal(0, 1).expand_by([4 - i]),
        )
        assert z.shape == (4 - i,)

    for i in pyro.plate("outer", 3):
        x_i = self.data_as_list[i]
        with pyro.plate("inner_%d" % i, x_i.size(0)):
            for k in range(n_superfluous_top):
                _draw_superfluous(i, k)
            likelihood = dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1)
            obs_i = pyro.sample("obs_%d" % i, likelihood, obs=x_i)
            assert obs_i.shape == (4 - i, 2)
            for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom):
                _draw_superfluous(i, k)
def model():
    """Model with observations nested in two sequential plates.

    Samples a global ``loc_latent`` from its prior.  It then conditions one
    site ``obs_i_j`` on ``self.data[i][j]`` for every pair drawn from the
    ``self.n_outer`` x ``self.n_inner`` sequential iterations.
    """
    prior = fakes.NonreparameterizedNormal(self.loc0, torch.pow(self.lam0, -0.5)).to_event(1)
    loc_latent = pyro.sample("loc_latent", prior)
    for outer_idx in pyro.plate("outer", self.n_outer):
        for inner_idx in pyro.plate("inner_%d" % outer_idx, self.n_inner):
            likelihood = dist.Normal(loc_latent, torch.pow(self.lam, -0.5)).to_event(1)
            pyro.sample(
                "obs_%d_%d" % (outer_idx, inner_idx),
                likelihood,
                obs=self.data[outer_idx][inner_idx],
            )
def guide():
    """Guide for ``loc_latent`` that uses a decaying-average baseline.

    Only the global ``loc_latent`` site is sampled.  The nested
    ``outer`` / ``inner_i`` sequential plates are entered with no sites
    inside.  NOTE(review): presumably this is done to mirror the model's
    plate structure — confirm.
    """
    loc_init = self.analytic_loc_n.expand(2) + 0.234
    log_scale_init = self.analytic_log_sig_n.expand(2) - 0.27
    q_loc = pyro.param("loc_q", loc_init)
    q_scale = torch.exp(pyro.param("log_sig_q", log_scale_init))
    pyro.sample(
        "loc_latent",
        fakes.NonreparameterizedNormal(q_loc, q_scale).to_event(1),
        infer={"baseline": {"use_decaying_avg_baseline": True}},
    )
    for i in pyro.plate("outer", self.n_outer):
        for _ in pyro.plate("inner_%d" % i, self.n_inner):
            pass
def guide():
    """Guide for ``loc_latent`` that uses a decaying-average baseline.

    Only the global ``loc_latent`` site is sampled.  The nested
    ``outer`` / ``inner_i`` sequential plates are entered with no sites
    inside.  NOTE(review): presumably this is done to mirror the model's
    plate structure — confirm.
    """
    # pyro.param registers the initial value in the param store as an
    # optimizable parameter.  Re-wrapping an existing tensor with
    # torch.tensor(..., requires_grad=True) only copies it and triggers
    # torch's copy-construct UserWarning.
    loc_q = pyro.param("loc_q", self.analytic_loc_n.expand(2) + 0.234)
    log_sig_q = pyro.param("log_sig_q", self.analytic_log_sig_n.expand(2) - 0.27)
    sig_q = torch.exp(log_sig_q)
    # .to_event(1) is the current spelling of the deprecated .independent(1).
    pyro.sample(
        "loc_latent",
        fakes.NonreparameterizedNormal(loc_q, sig_q).to_event(1),
        infer=dict(baseline=dict(use_decaying_avg_baseline=True)),
    )
    # pyro.plate used as an iterator replaces the deprecated pyro.irange.
    for i in pyro.plate("outer", self.n_outer):
        for j in pyro.plate("inner_%d" % i, self.n_inner):
            pass