Exemplo n.º 1
0
 def guide():
     """Register the log-space guide parameters and sample "mu_latent".

     Two parameters, "mu_q_log" and "tau_q_log", are registered with
     ``pyro.param``.  Their initial values are ``self.log_mu_n.data``
     shifted by +0.17 and ``self.log_tau_n.data`` shifted by -0.143, each
     wrapped in a ``Variable`` with ``requires_grad=True``.  Both are
     exponentiated, and a ``DiagNormal`` is built from ``exp(mu_q_log)``
     and ``exp(tau_q_log) ** -0.5``.

     ``self`` and ``reparametrized`` are read from the enclosing scope.
     Nothing is returned; the effect is the ``pyro.param`` and
     ``pyro.sample`` calls.
     """
     # NOTE(review): the constant shifts presumably start the parameters
     # away from the reference values stored on ``self`` -- confirm.
     log_mu_init = Variable(self.log_mu_n.data + 0.17, requires_grad=True)
     log_mu = pyro.param("mu_q_log", log_mu_init)

     log_tau_init = Variable(self.log_tau_n.data - 0.143, requires_grad=True)
     log_tau = pyro.param("tau_q_log", log_tau_init)

     # Leave log space before building the distribution.
     mu = torch.exp(log_mu)
     tau = torch.exp(log_tau)

     # Second DiagNormal argument is tau ** -0.5 (presumably a precision
     # converted to a standard deviation -- confirm against DiagNormal).
     inv_sqrt_tau = torch.pow(tau, -0.5)
     guide_dist = DiagNormal(mu, inv_sqrt_tau)
     guide_dist.reparametrized = reparametrized

     pyro.sample("mu_latent", guide_dist)
Exemplo n.º 2
0
 def guide():
     """Register the guide parameters, sample "mu_latent", and map the data.

     Two parameters are registered with ``pyro.param``: "mu_q", initialised
     to ``self.analytic_mu_n.data + 0.134 * torch.ones(2)``, and
     "log_sig_q", initialised to
     ``self.analytic_log_sig_n.data - 0.09 * torch.ones(2)``.  Both initial
     values are wrapped in a ``Variable`` with ``requires_grad=True``.  A
     ``DiagNormal`` is built from ``mu_q`` and ``exp(log_sig_q)`` and
     sampled at the site "mu_latent".  Finally ``pyro.map_data`` is called
     over ``self.data`` with a callback that does nothing and
     ``batch_size=1``.

     ``self`` and ``reparametrized`` are read from the enclosing scope.
     Nothing is returned.
     """

     def _ignore(i, x):
         # The guide has nothing to do per data point.
         return None

     # NOTE(review): the constant shifts presumably start the parameters
     # away from the analytic values stored on ``self`` -- confirm.
     mu_init = Variable(
         self.analytic_mu_n.data + 0.134 * torch.ones(2),
         requires_grad=True,
     )
     loc = pyro.param("mu_q", mu_init)

     log_sig_init = Variable(
         self.analytic_log_sig_n.data - 0.09 * torch.ones(2),
         requires_grad=True,
     )
     log_sig = pyro.param("log_sig_q", log_sig_init)

     # The second parameter is stored as a log; exponentiate before use.
     guide_dist = DiagNormal(loc, torch.exp(log_sig))
     guide_dist.reparametrized = reparametrized

     pyro.sample("mu_latent", guide_dist)
     # NOTE(review): presumably mirrors a map_data call in the model so
     # both see the same subsampling -- confirm against the model.
     pyro.map_data(self.data, _ignore, batch_size=1)