Example #1
def iarange_model(subsample_size):
    loc = torch.zeros(20)
    scale = torch.ones(20)
    with pyro.iarange('iarange', 20, subsample_size) as batch:
        pyro.sample("x", dist.Normal(loc[batch], scale[batch]))
        result = list(batch.data)
    return result
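A minimal driver for the snippet above — a usage sketch assuming a Pyro 0.2-era install where pyro.iarange still exists; each call draws a fresh random subsample of the 20 indices:

import torch
import pyro
import pyro.distributions as dist

batch = iarange_model(subsample_size=5)      # indices picked by pyro.iarange
assert len(batch) == 5                       # exactly subsample_size indices survive
assert all(0 <= int(i) < 20 for i in batch)  # all indices point into the size-20 site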
Example #2
 def guide(subsample_size):
     mu = pyro.param("mu", lambda: Variable(torch.zeros(len(data)), requires_grad=True))
     sigma = pyro.param("sigma", lambda: Variable(torch.ones(1), requires_grad=True))
     with pyro.iarange("data", len(data), subsample_size) as ind:
         mu = mu[ind]
         sigma = sigma.expand(subsample_size)
         pyro.sample("z", dist.Normal(mu, sigma, reparameterized=reparameterized))
Example #3
 def guide():
     mu_q = pyro.param("mu_q", Variable(self.analytic_mu_n.data + 0.334 * torch.ones(2),
                                        requires_grad=True))
     log_sig_q = pyro.param("log_sig_q", Variable(
                            self.analytic_log_sig_n.data - 0.29 * torch.ones(2),
                            requires_grad=True))
     mu_q_prime = pyro.param("mu_q_prime", Variable(torch.Tensor([-0.34, 0.52]),
                             requires_grad=True))
     kappa_q = pyro.param("kappa_q", Variable(torch.Tensor([0.74]),
                          requires_grad=True))
     log_sig_q_prime = pyro.param("log_sig_q_prime",
                                  Variable(-0.5 * torch.log(1.2 * self.lam0.data),
                                           requires_grad=True))
     sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(log_sig_q_prime)
     mu_latent_dist = dist.Normal(mu_q, sig_q, reparameterized=repa2)
     mu_latent = pyro.sample("mu_latent", mu_latent_dist,
                             baseline=dict(use_decaying_avg_baseline=use_decaying_avg_baseline))
     mu_latent_prime_dist = dist.Normal(kappa_q.expand_as(mu_latent) * mu_latent + mu_q_prime,
                                        sig_q_prime,
                                        reparameterized=repa1)
     pyro.sample("mu_latent_prime",
                 mu_latent_prime_dist,
                 baseline=dict(nn_baseline=mu_prime_baseline,
                               nn_baseline_input=mu_latent,
                               use_decaying_avg_baseline=use_decaying_avg_baseline))

     return mu_latent
Example #4
 def guide(self, x):
     # register PyTorch module `encoder` with Pyro
     pyro.module("encoder", self.encoder)
     # use the encoder to get the parameters used to define q(z|x)
     z_mu, z_sigma = self.encoder.forward(x)
     # sample the latent code z
     pyro.sample("latent", dist.normal, z_mu, z_sigma)
Example #5
    def model(self):
        self.set_mode("model")

        Xu = self.get_param("Xu")
        u_loc = self.get_param("u_loc")
        u_scale_tril = self.get_param("u_scale_tril")

        M = Xu.shape[0]
        Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
        Luu = Kuu.potrf(upper=False)

        zero_loc = Xu.new_zeros(u_loc.shape)
        u_name = param_with_module_name(self.name, "u")
        if self.whiten:
            Id = torch.eye(M, out=Xu.new_empty(M, M))
            pyro.sample(u_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Id)
                            .independent(zero_loc.dim() - 1))
        else:
            pyro.sample(u_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Luu)
                            .independent(zero_loc.dim() - 1))

        f_loc, f_var = conditional(self.X, Xu, self.kernel, u_loc, u_scale_tril,
                                   Luu, full_cov=False, whiten=self.whiten,
                                   jitter=self.jitter)

        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            with poutine.scale(None, self.num_data / self.X.shape[0]):
                return self.likelihood(f_loc, f_var, self.y)
Example #6
 def guide(subsample):
     loc = pyro.param("loc", lambda: torch.zeros(len(data), requires_grad=True))
     scale = pyro.param("scale", lambda: torch.tensor([1.0], requires_grad=True))
     with pyro.iarange("particles", num_particles):
         with pyro.iarange("data", len(data), subsample_size, subsample) as ind:
             loc_ind = loc[ind].unsqueeze(-1).expand(-1, num_particles)
             pyro.sample("z", Normal(loc_ind, scale))
Example #7
 def guide():
     p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
     outer_irange = pyro.irange("irange_0", 3, subsample_size)
     inner_irange = pyro.irange("irange_1", 3, subsample_size)
     for j in inner_irange:
         for i in outer_irange:
             pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(p))
Example #8
 def model():
     p = torch.tensor(0.5)
     outer_irange = pyro.irange("irange_0", 3, subsample_size)
     inner_irange = pyro.irange("irange_1", 3, subsample_size)
     for i in outer_irange:
         for j in inner_irange:
             pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(p))
Example #9
 def guide():
     pyro.module("mymodule", pt_guide)
     mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
     sigma = torch.pow(tau_q, -0.5)
     pyro.sample("mu_latent",
                 dist.Normal(mu_q, sigma, reparameterized=reparameterized),
                 baseline=dict(use_decaying_avg_baseline=True))
Example #10
 def sample_ws(name, width):
     alpha_w_q = pyro.param("log_alpha_w_q_%s" % name,
                            lambda: rand_tensor((width), self.alpha_init, self.sigma_init))
     mean_w_q = pyro.param("log_mean_w_q_%s" % name,
                           lambda: rand_tensor((width), self.mean_init, self.sigma_init))
     alpha_w_q, mean_w_q = self.softplus(alpha_w_q), self.softplus(mean_w_q)
     pyro.sample("w_%s" % name, Gamma(alpha_w_q, alpha_w_q / mean_w_q))
Example #11
 def model(self, data):
     loc = self.loc_0
     lambda_prec = self.lambda_prec
     for i in range(1, self.chain_len + 1):
         loc = pyro.sample('loc_{}'.format(i),
                           dist.Normal(loc=loc, scale=lambda_prec))
     pyro.sample('obs', dist.Normal(loc, lambda_prec), obs=data)
Example #12
    def _register_param(self, param, mode="model"):
        """
        Registers a parameter to Pyro. It can be seen as a wrapper for
        :func:`pyro.param` and :func:`pyro.sample` primitives.

        :param str param: Name of the parameter.
        :param str mode: Either "model" or "guide".
        """
        if param in self._fixed_params:
            self._registered_params[param] = self._fixed_params[param]
            return
        prior = self._priors.get(param)
        if self.name is None:
            param_name = param
        else:
            param_name = param_with_module_name(self.name, param)

        if prior is None:
            constraint = self._constraints.get(param)
            default_value = getattr(self, param)
            if constraint is None:
                p = pyro.param(param_name, default_value)
            else:
                p = pyro.param(param_name, default_value, constraint=constraint)
        elif mode == "model":
            p = pyro.sample(param_name, prior)
        else:  # prior is not None and mode == "guide"
            MAP_param_name = param_name + "_MAP"
            # TODO: consider initializing the parameter from a prior draw instead of the mean
            MAP_param = pyro.param(MAP_param_name, prior.mean.detach())
            p = pyro.sample(param_name, dist.Delta(MAP_param))

        self._registered_params[param] = p
Example #13
 def guide():
     q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
     q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
     with pyro.iarange("particles", num_particles):
         y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]), infer={"enumerate": enumerate1})
         if include_z:
             pyro.sample("z", dist.Normal(q2 * y + 0.10, 1.0))
Example #14
    def model(self):
        self.set_mode("model")

        f_loc = self.get_param("f_loc")
        f_scale_tril = self.get_param("f_scale_tril")

        N = self.X.shape[0]
        Kff = self.kernel(self.X) + (torch.eye(N, out=self.X.new_empty(N, N)) *
                                     self.jitter)
        Lff = Kff.potrf(upper=False)

        zero_loc = self.X.new_zeros(f_loc.shape)
        f_name = param_with_module_name(self.name, "f")

        if self.whiten:
            Id = torch.eye(N, out=self.X.new_empty(N, N))
            pyro.sample(f_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Id)
                            .independent(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(f_scale_tril)
        else:
            pyro.sample(f_name,
                        dist.MultivariateNormal(zero_loc, scale_tril=Lff)
                            .independent(zero_loc.dim() - 1))

        f_var = f_scale_tril.pow(2).sum(dim=-1)

        if self.whiten:
            f_loc = Lff.matmul(f_loc.unsqueeze(-1)).squeeze(-1)
        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)
Example #15
 def model(num_particles):
     with pyro.iarange("particles", num_particles):
         q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
         q4 = pyro.param("q4", torch.tensor(0.5 * (pi1 + pi2), requires_grad=True))
         z = pyro.sample("z", dist.Normal(q3, 1.0).expand_by([num_particles]))
         zz = torch.exp(z) / (1.0 + torch.exp(z))
         pyro.sample("y", dist.Bernoulli(q4 * zz))
Example #16
 def guide(num_particles):
     q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
     q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
     with pyro.iarange("particles", num_particles):
         z = pyro.sample("z", dist.Normal(q2, 1.0).expand_by([num_particles]))
         zz = torch.exp(z) / (1.0 + torch.exp(z))
         pyro.sample("y", dist.Bernoulli(q1 * zz))
Example #17
 def sample_zs(name, width):
     alpha_z_q = pyro.param("log_alpha_z_q_%s" % name,
                            lambda: rand_tensor((x_size, width), self.alpha_init, self.sigma_init))
     mean_z_q = pyro.param("log_mean_z_q_%s" % name,
                           lambda: rand_tensor((x_size, width), self.mean_init, self.sigma_init))
     alpha_z_q, mean_z_q = self.softplus(alpha_z_q), self.softplus(mean_z_q)
     pyro.sample("z_%s" % name, Gamma(alpha_z_q, alpha_z_q / mean_z_q).independent(1))
Example #18
    def guide(self, xs, ys=None):
        """
        The guide corresponds to the following:
        q(y|x) = categorical(alpha(x))              # infer digit from an image
        q(z|x,y) = normal(mu(x,y),sigma(x,y))       # infer handwriting style from an image and the digit
        mu, sigma are given by a neural network `encoder_z`
        alpha is given by a neural network `encoder_y`
        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.iarange("independent"):

            # if the class label (the digit) is not supervised, sample
            # (and score) the digit with the variational distribution
            # q(y|x) = categorical(alpha(x))
            if ys is None:
                alpha = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.categorical, alpha)

            # sample (and score) the latent handwriting-style with the variational
            # distribution q(z|x,y) = normal(mu(x,y),sigma(x,y))
            mu, sigma = self.encoder_z.forward([xs, ys])
            zs = pyro.sample("z", dist.normal, mu, sigma)   # noqa: F841
Example #19
 def model():
     with pyro.iarange("num_particles", 10, dim=-3):
         with pyro.iarange("components", 2, dim=-1):
             p = pyro.sample("p", dist.Beta(torch.tensor(1.1), torch.tensor(1.1)))
             assert p.shape == torch.Size((10, 1, 2))
         with pyro.iarange("data", data.shape[0], dim=-2):
             pyro.sample("obs", dist.Bernoulli(p), obs=data)
Example #20
    def guide_step(self, t, n, prev, inputs):

        rnn_input = torch.cat((inputs['embed'], prev.z_where, prev.z_what, prev.z_pres), 1)
        h, c = self.rnn(rnn_input, (prev.h, prev.c))
        z_pres_p, z_where_loc, z_where_scale = self.predict(h)

        # Compute baseline estimates for discrete choice z_pres.
        bl_value, bl_h, bl_c = self.baseline_step(prev, inputs)

        # Sample presence.
        z_pres = pyro.sample('z_pres_{}'.format(t),
                             dist.Bernoulli(z_pres_p * prev.z_pres).independent(1),
                             infer=dict(baseline=dict(baseline_value=bl_value.squeeze(-1))))

        sample_mask = z_pres if self.use_masking else torch.tensor(1.0)

        z_where = pyro.sample('z_where_{}'.format(t),
                              dist.Normal(z_where_loc + self.z_where_loc_prior,
                                          z_where_scale * self.z_where_scale_prior)
                                  .mask(sample_mask)
                                  .independent(1))

        # Figure 2 of [1] shows x_att depending on z_where and h,
        # rather than z_where and x as here, but I think this is
        # correct.
        x_att = image_to_window(z_where, self.window_size, self.x_size, inputs['raw'])

        # Encode attention windows.
        z_what_loc, z_what_scale = self.encode(x_att)

        z_what = pyro.sample('z_what_{}'.format(t),
                             dist.Normal(z_what_loc, z_what_scale)
                                 .mask(sample_mask)
                                 .independent(1))
        return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c, z_pres=z_pres, z_where=z_where, z_what=z_what)
Example #21
 def model():
     p2 = torch.ones(2) / 2
     p3 = torch.ones(3) / 3
     x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
     x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
     assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
     assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape
Example #22
def bernoulli_normal_model():
    bern_0 = pyro.sample('bern_0', dist.Bernoulli(torch.ones(1) * 1e-2))  # small nonzero success probability
    loc = torch.ones(1) if bern_0.item() else -torch.ones(1)
    normal_0 = torch.ones(1)
    pyro.sample('normal_0', dist.Normal(loc, torch.ones(1) * 1e-2),
                obs=normal_0)
    return [bern_0, normal_0]
Example #23
 def guide():
     alpha_q_log = pyro.param("alpha_q_log",
                              Variable(self.log_alpha_n.data + 0.17, requires_grad=True))
     beta_q_log = pyro.param("beta_q_log",
                             Variable(self.log_beta_n.data - 0.143, requires_grad=True))
     alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log)
     pyro.sample("p_latent", dist.beta, alpha_q, beta_q)
     pyro.map_data("aaa", self.data, lambda i, x: None, batch_size=self.batch_size)
Example #24
 def guide(self, x):
     # register PyTorch module `encoder` with Pyro
     pyro.module("encoder", self.encoder)
     with pyro.iarange("data", x.size(0)):
         # use the encoder to get the parameters used to define q(z|x)
         z_loc, z_scale = self.encoder.forward(x)
         # sample the latent code z
         pyro.sample("latent", dist.Normal(z_loc, z_scale).independent(1))
Example #25
 def guide():
     q = pyro.param("q")
     with pyro.iarange("particles", num_particles):
         pyro.sample("x", dist.Bernoulli(q).expand_by([num_particles]),
                     infer={"enumerate": enumerate1})
         for i in pyro.irange("irange", irange_dim):
             pyro.sample("y_{}".format(i), dist.Bernoulli(q).expand_by([num_particles]),
                         infer={"enumerate": enumerate2})
Example #26
 def model_sample(self, batch_size=1):
     # sample the handwriting style from the constant prior distribution
     prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
     prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
     zs = pyro.sample("z", dist.normal, prior_mu, prior_sigma)
     mu = self.decoder.forward(zs)
     xs = pyro.sample("sample", dist.bernoulli, mu)
     return xs, mu
Example #27
def irange_model(subsample_size):
    loc = torch.zeros(20)
    scale = torch.ones(20)
    result = []
    for i in pyro.irange('irange', 20, subsample_size):
        pyro.sample("x_{}".format(i), dist.Normal(loc[i], scale[i]))
        result.append(i)
    return result
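A matching driver for irange_model, under the same Pyro 0.2-era assumption; pyro.irange yields the retained indices one at a time, so the model gets one named sample site per index:

indices = irange_model(subsample_size=5)
assert len(indices) == 5       # one "x_{i}" site per retained index
assert len(set(indices)) == 5  # subsampling draws indices without replacement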
Example #28
def model(data):
    idka = torch.tensor(4.0)  # placeholder hyperparameters; a neural network could produce these
    idkb = torch.tensor(4.0)
    # one Gamma-distributed rate per gene
    genelambdas = pyro.sample('genelambdas',
                              dist.Gamma(idka, idkb).expand_by([19795]).independent(1))
    for celltype in range(data.size(0)):  # one pass per cell type (56 in this dataset)
        with pyro.iarange('observe_{}'.format(celltype)):
            pyro.sample('indiv', dist.Poisson(genelambdas), obs=data[celltype])
Example #29
 def guide():
     q = pyro.param("q")
     with pyro.iarange("particles", num_particles):
         pyro.sample("y", dist.Bernoulli(q).expand_by([num_particles]),
                     infer={"enumerate": enumerate1})
         with pyro.iarange("iarange", iarange_dim):
             pyro.sample("z", dist.Bernoulli(q).expand_by([iarange_dim, num_particles]),
                         infer={"enumerate": enumerate2})
Example #30
 def obs_inner(i, _i, _x):
     for k in range(n_superfluous_top):
         pyro.sample("z_%d_%d" % (i, k),
                     dist.Normal(ng_zeros(4 - i, 1), ng_ones(4 - i, 1), reparameterized=False))
     pyro.observe("obs_%d" % i, dist.normal, _x, mu_latent, torch.pow(self.lam, -0.5))
     for k in range(n_superfluous_top, n_superfluous_top + n_superfluous_bottom):
         pyro.sample("z_%d_%d" % (i, k),
                     dist.Normal(ng_zeros(4 - i, 1), ng_ones(4 - i, 1), reparameterized=False))
Example #31
 def guide():
     pyro.module("mymodule", pt_guide)
     mu_q, tau_q = torch.exp(pt_guide.mu_q_log), torch.exp(pt_guide.tau_q_log)
     sigma = torch.pow(tau_q, -0.5)
     pyro.sample("mu_latent", dist.Normal(mu_q, sigma, reparameterized=reparameterized))
Example #32
def guide(x):
    x = torch.reshape(x, [320, 4096])
    
    with pyro.plate("w_top_plate", 4000):
        #============ sample_ws
        alpha_w_q =\
            pyro.param("log_alpha_w_q_top",
                       alpha_init * torch.ones(4000) +
                       sigma_init * torch.randn(4000))
        mean_w_q =\
            pyro.param("log_mean_w_q_top",
                       mean_init * torch.ones(4000) +
                       sigma_init * torch.randn(4000)) 
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q  = softplus(mean_w_q)
        pyro.sample("w_top", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        #============ sample_ws

    with pyro.plate("w_mid_plate", 600):
        #============ sample_ws
        alpha_w_q =\
            pyro.param("log_alpha_w_q_mid",
                       alpha_init * torch.ones(600) +
                       sigma_init * torch.randn(600)) 
        mean_w_q =\
            pyro.param("log_mean_w_q_mid",
                       mean_init * torch.ones(600) +
                       sigma_init * torch.randn(600)) 
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q  = softplus(mean_w_q)
        pyro.sample("w_mid", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        #============ sample_ws

    with pyro.plate("w_bottom_plate", 61440):
        #============ sample_ws
        alpha_w_q =\
            pyro.param("log_alpha_w_q_bottom",
                       alpha_init * torch.ones(61440) +
                       sigma_init * torch.randn(61440)) 
        mean_w_q =\
            pyro.param("log_mean_w_q_bottom",
                       mean_init * torch.ones(61440) +
                       sigma_init * torch.randn(61440)) 
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q  = softplus(mean_w_q)
        pyro.sample("w_bottom", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        #============ sample_ws

    with pyro.plate("data", 320):
        #============ sample_zs
        alpha_z_q =\
            pyro.param("log_alpha_z_q_top",
                       alpha_init * torch.ones(320, 100) +
                       sigma_init * torch.randn(320, 100)) 
        mean_z_q =\
            pyro.param("log_mean_z_q_top",
                       mean_init * torch.ones(320, 100) +
                       sigma_init * torch.randn(320, 100))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q  = softplus(mean_z_q)
        pyro.sample("z_top", Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        #============ sample_zs
        #============ sample_zs
        alpha_z_q =\
            pyro.param("log_alpha_z_q_mid",
                       alpha_init * torch.ones(320, 40) +
                       sigma_init * torch.randn(320, 40)) 
        mean_z_q =\
            pyro.param("log_mean_z_q_mid",
                       mean_init * torch.ones(320, 40) +
                       sigma_init * torch.randn(320, 40))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q  = softplus(mean_z_q)
        pyro.sample("z_mid", Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        #============ sample_zs
        #============ sample_zs
        alpha_z_q =\
            pyro.param("log_alpha_z_q_bottom",
                       alpha_init * torch.ones(320, 15) +
                       sigma_init * torch.randn(320, 15)) 
        mean_z_q =\
            pyro.param("log_mean_z_q_bottom",
                       mean_init * torch.ones(320, 15) +
                       sigma_init * torch.randn(320, 15))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q  = softplus(mean_z_q)
        pyro.sample("z_bottom", Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
Example #33
    def model(self):

        # If the Kernel IS NOT random, we declare the kernel within the model
        # if (not self.random_kernel):
        #     self.kernel = pydmn.kernels.RBF()

        # Covariance matrix of observed times entailed by our kernel
        # Kff = self.kernel(self.Y_time.reshape(-1,1))
        # Kff.view(-1)[::self.T_net + 1] += self.jitter  # add jitter to the diagonal
        # Lff = Kff.cholesky() # cholesky lower triangular
        Lff = self.Lff_ini

        ## Sampling system-wide connectivity and average weights ##
        with pyro.plate('gp_system_all', self.K_net*self.n_w ):
            # Mean function of the GPs
            gp_system_mean = pyro.sample( "gp_system_mean",
                                    dist.Normal( torch.zeros( (self.K_net*self.n_w) ),
                                                torch.tensor([0.1]) ) )
            # Demeaned GPs
            gp_system_demean = pyro.sample( "gp_system_demean",
                                            dist.MultivariateNormal( torch.zeros( (self.K_net*self.n_w, self.T_net) ),
                                                                        scale_tril=Lff ) )
        gp_system_mean = gp_system_mean.reshape(self.K_net,self.n_w)
        gp_system_demean = gp_system_demean.reshape(self.K_net, self.n_w, self.T_net)
        # Latent systemic evolution
        gp_system = gp_system_mean.expand(self.T_net, self.K_net, self.n_w).permute(1,2,0) + gp_system_demean

        ## Sampling latent coordinates ##
        if self.coord:
            with pyro.plate('gp_coord_all', self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w ):
                # Mean function of the GPs
                gp_coord_mean = pyro.sample( "gp_coord_mean",
                                        dist.Normal( torch.zeros( (self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w) ),
                                                    torch.tensor([0.1]) ) )
                # Demeaned GPs
                gp_coord_demean = pyro.sample( "gp_coord_demean",
                                                dist.MultivariateNormal( torch.zeros( (self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w, self.T_net) ),
                                                                            scale_tril=Lff ) )

            gp_coord_mean = gp_coord_mean.reshape(self.V_net,self.H_dim,self.K_net,self.n_dir,self.n_w)
            gp_coord_demean = gp_coord_demean.reshape(self.V_net, self.H_dim, self.K_net, self.n_dir, self.n_w, self.T_net)
            # Latent coordinates
            gp_coord = gp_coord_mean.expand(self.T_net, self.V_net, self.H_dim, self.K_net, self.n_dir, self.n_w).permute(1,2,3,4,5,0) + gp_coord_demean

        ## Sampling Sociability and Popularity terms ##
        if self.socpop:
            with pyro.plate('gp_socpop_all', self.V_net*self.K_net*self.n_dir*self.n_w ):
                # Mean function of the GPs
                gp_socpop_mean = pyro.sample( "gp_socpop_mean",
                                        dist.Normal( torch.zeros( (self.V_net*self.K_net*self.n_dir*self.n_w) ),
                                                    torch.tensor([0.1]) ) )
                # Demeaned GPs
                gp_socpop_demean = pyro.sample( "gp_socpop_demean",
                                                dist.MultivariateNormal( torch.zeros( (self.V_net*self.K_net*self.n_dir*self.n_w, self.T_net) ),
                                                                            scale_tril=Lff ) )

            gp_socpop_mean = gp_socpop_mean.reshape(self.V_net,self.K_net,self.n_dir,self.n_w)
            gp_socpop_demean = gp_socpop_demean.reshape(self.V_net, self.K_net, self.n_dir, self.n_w, self.T_net)
            # Latent coordinates
            gp_socpop = gp_socpop_mean.expand(self.T_net, self.V_net, self.K_net, self.n_dir, self.n_w).permute(1,2,3,4,0) + gp_socpop_demean

        ### Linear Predictor ###
        # Systemic component
        Y_linpred = gp_system.expand(self.V_net, self.V_net, self.K_net, self.n_w, self.T_net).permute(0,1,4,2,3)
        # Distance between agents
        if self.coord:
            Y_linpred = Y_linpred + torch.einsum('uhkwt,vhkwt->uvtkw', gp_coord[:,:,:,0,:,:], gp_coord[:,:,:,self.n_dir-1,:,:])
        # Sociability and Popularity effects
        if self.socpop:
            gp_soc = gp_socpop[:,:,0,:,:].expand(self.V_net, self.V_net, self.K_net, self.n_w, self.T_net).transpose(0,1)
            gp_pop = gp_socpop[:,:,self.n_dir-1,:,:].expand(self.V_net, self.V_net, self.K_net, self.n_w,self.T_net)
            Y_linpred = Y_linpred + gp_soc.permute(0,1,4,2,3) + gp_pop.permute(0,1,4,2,3)

        ### Link propensity (probability of occur) ###
        Y_link_prob = torch.sigmoid(Y_linpred[:,:,:,:,0])
        Y_link_prob_valid = Y_link_prob.flatten()[self.Y_valid_id.flatten()==1]

        with pyro.plate( "data", Y_link_prob_valid.shape[0]):
            pyro.sample( "obs", dist.Bernoulli(Y_link_prob_valid), obs=self.Y_link.flatten()[self.cond_Y_link] )

        ### Link expected weight (weight given occurrence) ###
        if self.weighted:
            with pyro.plate( "sigma_k_ind", self.K_net):
                sigma_k = pyro.sample( 'sigma_k', dist.InverseGamma( self.sigma_k_prior_param[0].expand(self.K_net), self.sigma_k_prior_param[1].expand(self.K_net) ) )
            Y_link_SDw = sigma_k.expand(self.V_net,self.V_net,self.T_net,self.K_net)

            Y_link_Ew = Y_linpred[:,:,:,:,1]
            # cond_Y_w: condition of being positive and valid weights (defined in set_data())
            Y_link_Ew_valid = Y_link_Ew.flatten()[self.cond_Y_w]
            Y_link_SDw_valid = Y_link_SDw.flatten()[self.cond_Y_w]
            with pyro.plate( "data_w", Y_link_Ew_valid.shape[0] ):
                pyro.sample( "obs_w", dist.Normal(Y_link_Ew_valid,Y_link_SDw_valid), obs=self.Y.flatten()[self.cond_Y_w] )
Example #34
 def model():
     p = pyro.param("p", Variable(torch.Tensor([0.05])))
     ps = pyro.param("ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4])))
     x = pyro.sample("x", dist.Bernoulli(p))
     y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
     return dict(x=x, y=y)
Example #35
 def guide():
     p = pyro.param(
         "p", Variable(torch.Tensor([0.0, 0.5, 1.0]), requires_grad=True))
     pyro.sample("z", dist.Bernoulli(p))
Example #36
def init_params(data):
    params = {}
    params["beta"] = init_vector("beta", dims=(2)) # vector
    params["sigma"] = pyro.sample("sigma", dist.Uniform(0., 1000.))

    return params
Example #37
def generate_number_grammar(input_symbols, grammar=None):
    zeroRule = False
    connectingWordsProb = 0.2
    ZeroProb = 0.2
    exceptionProb = 0.3
    tenThousandWordProb = 0.3

    input_symbol_options = copy.deepcopy(input_symbols)
    #random.shuffle(input_symbol_options)

    rules, intRules, input_symbol_options, oneWord = generateOneToTen(
        input_symbol_options)

    for base in [10000, 1000, 100, 10]:

        if base == 10 and tp == 'irregular':
            tp = 'irregular'
        else:
            if base in [100, 10]:
                tp = selectFromList(['regular', 'irregular'],
                                    f"regularity_{base}",
                                    obs=None)
            else:
                tp = 'regular'

        # IMPORTANT: if 100s are irregular, 10s can't be regular ... probably fine
        # TODO teens and twenties?
        # maybe teens are irregular if 10s are?

        if tp == 'irregular':
            for i in range(1, 10):
                num = str(i * base)
                word, input_symbol_options = popFromList(
                    input_symbol_options,
                    name=f"irreg_word_{base}_{i}",
                    obs=None)
                rules.append(Rule(word, num))
                intRules.append([str(num), '->', word])
                #french situation??

            #do irregular teens half the time:
            if base == 10 and pyro.sample('irreg_teens',
                                          pyro.distributions.Bernoulli(0.5)):
                for i in range(11, 20):
                    num = str(i)
                    word, input_symbol_options = popFromList(
                        input_symbol_options, name=f"teen_{i}", obs=None)
                    rules.append(Rule(word, num))
                    intRules.append([str(num), '->', word])

            intRules.append([
                '>' + str(base), '->', '[x - x%' + str(base) + ']',
                '[x%' + str(base) + ']'
            ])

        elif tp == 'regular':

            #we have a word for ten thousand sometimes
            if base == 10000:
                if not pyro.sample(
                        f'ten_thousand_word',
                        pyro.distributions.Bernoulli(tenThousandWordProb),
                        obs=None):
                    base = 1000000

            baseWord, input_symbol_options = popFromList(input_symbol_options,
                                                         name=f"word_{base}",
                                                         obs=None)

            rule = Rule('x1 ' + baseWord + ' y1',
                        '[x1]*' + str(base) + ' [y1]')
            rules.append(rule)
            intRules.append([
                '>' + str(base), '->', '[x//' + str(base) + ']', baseWord,
                '[x%' + str(base) + ']'
            ])

            if pyro.sample(f'one_exception_{base}',
                           pyro.distributions.Bernoulli(exceptionProb),
                           obs=None):  #for now, only one exception
                oneException = True
                if pyro.sample(f'one_exception_change_{base}',
                               pyro.distributions.Bernoulli(0.3),
                               obs=None):  #for now, only one exception
                    oneExceptionWord, input_symbol_options = popFromList(
                        input_symbol_options, name=f"oneWord_{base}", obs=None)
                    exceptionRule = Rule(' '.join([oneExceptionWord, 'y1']),
                                         str(base) + '* 1' + ' [y1]')
                    intRule = [
                        '//' + str(base) + '==' + str(1), '->',
                        oneExceptionWord, '[x%' + str(base) + ']'
                    ]
                else:
                    exceptionRule = Rule(' '.join([baseWord, 'y1']),
                                         str(base) + '* 1' + ' [y1]')
                    intRule = [
                        '//' + str(base) + '==' + str(1), '->', baseWord,
                        '[x%' + str(base) + ']'
                    ]

                rules.insert(-1, exceptionRule)
                intRules.insert(-1, intRule)

                explicitExceptionRule = Rule(baseWord, str(base))
            else:
                explicitExceptionRule = Rule(' '.join([oneWord, baseWord]),
                                             str(base))
            rules.insert(0, explicitExceptionRule)

            if pyro.sample(f'exception_{base}',
                           pyro.distributions.Bernoulli(exceptionProb),
                           obs=None):  #for now, only one exception
                exceptionNum = selectFromList(list(range(2, 10)),
                                              name=f"exception_num_{base}",
                                              obs=None)
                exceptionWord, input_symbol_options = popFromList(
                    input_symbol_options,
                    name=f"exception_name_{base}",
                    obs=None)
                exceptionRule = Rule(
                    ' '.join([exceptionWord, baseWord, 'y1']),
                    str(base) + '* ' + str(exceptionNum) + ' [y1]')
                intRule = [
                    '//' + str(base) + '==' + str(exceptionNum), '->',
                    exceptionWord, baseWord, '[x%' + str(base) + ']'
                ]
                #TODO other direction

                #TODO
                rules.insert(-1, exceptionRule)
                intRules.insert(-1, intRule)
        else:
            assert False

    if pyro.sample(f'connecting_word',
                   pyro.distributions.Bernoulli(connectingWordsProb),
                   obs=None):
        connectingWord, input_symbol_options = popFromList(
            input_symbol_options, name=f"connecting_word_val", obs=None)
        rules.append(Rule('u1 ' + connectingWord + ' x1', '[u1] [x1]'))
        intRules.append(['>10', '->', '[x - x%10]', connectingWord, '[x%10]'])
        #assert False, "need to figure out where to put connectingWord"

    if pyro.sample(f'zero', pyro.distributions.Bernoulli(ZeroProb), obs=None):
        zeroWord, input_symbol_options = popFromList(input_symbol_options,
                                                     name=f"zero_word",
                                                     obs=None)
        rules.append(Rule(zeroWord, '0'))
        intRules.append(['0', '->', zeroWord])
        zeroRule = True

    concatRule = Rule('u1 x1', '[u1] [x1]')
    rules.append(concatRule)

    #can shuffle those rules, but it doesn't matter, do that at example sampling time
    return NumberGrammar(rules, input_symbols), IntGrammar(
        intRules, zeroRule)  #makeIntG(intRules, intG)
Example #38
def init_vector(name, dims=None):
    return pyro.sample(
        name,
        dist.Normal(torch.zeros(dims), 0.2 * torch.ones(dims)).to_event(1))
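Example #36 above calls this helper; a sketch of how the two fit together, assuming both functions are defined in the same module:

import torch
import pyro
import pyro.distributions as dist

params = init_params(data=None)                 # init_params ignores its argument here
assert params["beta"].shape == (2,)             # "beta" is a length-2 vector drawn from the prior
assert params["sigma"].shape == torch.Size([])  # "sigma" is a scalar from Uniform(0, 1000)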
Example #39
 def forward(self, x, y=None):
     y_pr = self.seq(x)
     if y is not None:
         with pyro.plate("data", y.shape[0]):
             pyro.sample("obs", dist.Normal(y, self.target_std).to_event(1), obs=y_pr)
     return y_pr.detach()
Example #40
 def guide(self, data):
     encoder = pyro.module('encoder', self.vae_encoder)
     with pyro.plate('data', data.size(0)):
         z_mean, z_var = encoder.forward(data)
         pyro.sample('latent', Normal(z_mean, z_var.sqrt()).to_event(1))
Example #41
def model(game, observer, action):
    if game.turn == observer:
        return None
    known_cards = get_known_cards(game, observer)
    if action:
        known_cards.update(action.cards)
    unknown_cards = list(set(game.unused_cards) - known_cards)

    idx_to_id = idx_2_id(observer)
    id_to_idx = id_2_idx(observer)

    num_cards_in_hand = {
        i: len(set(game.players[idx_to_id[i]].hand) - known_cards)
        for i in idx_to_id
    }

    probs = []
    for card in unknown_cards:
        """
        FIXME: prior distribution of cards?
        """
        theta = [torch.tensor(global_card_dist[card][i]) for i in idx_to_id]
        player_probs = pyro.sample('{}_probs'.format(card),
                                   dist.Dirichlet(torch.stack(theta)))
        normalized_player_probs = player_probs  #/ torch.sum(player_probs)
        probs.append(normalized_player_probs)

    probs = torch.stack(probs)
    hands = {i: list() for i in idx_to_id}
    card_probs = {i: list() for i in idx_to_id}
    for i, card in random.sample(tuple(enumerate(unknown_cards)),
                                 len(unknown_cards)):
        assigned = False
        while not assigned:
            player = torch.distributions.Categorical(probs=probs[i]).sample()
            #player = pyro.sample('{}_locs'.format(card), dist.Categorical(probs=probs[i]))
            if len(hands[int(player)]) < num_cards_in_hand[int(player)]:
                hands[int(player)].append(card)
                card_probs[int(player)].append(probs[i][int(player)])
                assigned = True
    """
    for i in idx_to_id:
        pyro.sample(
            '{}_card_assignment'.format(i),
            dist.Bernoulli(probs=torch.prod(torch.tensor(card_probs[i]))),
            obs=torch.tensor(1.)
        )
    """

    ai_player = GreedyPlayer(game, game.turn)
    ai_player.hand = hands[id_to_idx[game.turn]]
    if action:
        ai_player.hand += action.cards
    #print("===============card===============")
    #for card in ai_player.hand:
    #    print(card)
    #print("==================================")
    actions, action_probs = tuple(
        [list(t) for t in zip(*ai_player.action_probs())])
    #print(actions, action_probs)
    #print(actions, action_probs)
    #print([str(a) for a in actions])
    #print(action)
    action_dist = dist.Categorical(probs=torch.tensor(action_probs))
    pyro.sample('action', action_dist, obs=torch.tensor(actions.index(action)))
    return hands
Example #42
def normal_product(loc, scale):
    z1 = pyro.sample("z1", pyro.distributions.Normal(loc, scale))
    z2 = pyro.sample("z2", pyro.distributions.Normal(loc, scale))
    y = z1 * z2
    return y
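A quick driver for normal_product; z1 and z2 are independent draws, so repeated calls return different products:

import torch
import pyro

y = normal_product(torch.tensor(0.0), torch.tensor(1.0))
assert y.shape == torch.Size([])  # scalar sample times scalar sample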
Example #43
 def model_1():
     a = pyro.sample("a", dist.Normal(0, 1))
     pyro.sample("b", dist.Normal(a, 1), obs=torch.tensor(0.0))
Example #44
 def model(data):
     w = pyro.sample("w", dist.Normal(0, 1))
     with pyro.plate("p", len(data)):
         x = pyro.sample("x", dist.Normal(0, 1))
         y = pyro.sample("y", dist.Normal(0, 1))
         pyro.sample("z", dist.Normal(w + x + y, 1), obs=data)
Example #45
 def model():
     p = Variable(torch.Tensor([0.0, 0.5, 1.0]))
     pyro.sample("z", dist.Bernoulli(p))
Example #46
 def model(data):
     with pyro.plate("p", len(data)):
         x = pyro.sample("x", dist.Normal(0, 1))
         y = pyro.sample("y", dist.Normal(0, 1))
     pyro.sample("z", dist.Normal(x.sum(), y.sum().exp()), obs=data.sum())
Example #47
    def guide(self):

        # Posterior Covariance of the GP
        # if self.random_kernel:
        #     self.kernel_param = pyro.param("kernel_param", 50*torch.ones((2,2)), constraint=constraints.positive)
        #     pyro.sample( "kernel.lengthscale", dist.InverseGamma( self.kernel_param[0,0], self.kernel_param[0,1] ) )
        #     pyro.sample( "kernel.variance", dist.InverseGamma( self.kernel_param[1,0], self.kernel_param[1,1] ) )

        # Sampling Systemic components #
        self.gp_system_mean_loc = pyro.param("gp_system_mean_loc", self.gp_system_mean_ini )
        self.gp_system_mean_scale = pyro.param("gp_system_mean_scale", 0.1*torch.ones((self.K_net,self.n_w)), constraint=constraints.positive)
        self.gp_system_demean = pyro.param( f"gp_system_demean_loc", self.gp_system_demean_ini )
        # Posterior Covariance of the GP
        self.gp_system_cov_tril = pyro.param( "gp_system_cov_tril", self.Lff_ini.expand(self.K_net,self.n_w,self.T_net,self.T_net),
                                        constraint=constraints.lower_cholesky )
        with pyro.plate('gp_system_all', self.K_net*self.n_w ):
            # Posterior GP (mean function params) #
            pyro.sample( "gp_system_mean", dist.Normal( self.gp_system_mean_loc.reshape(self.K_net*self.n_w),
                                                self.gp_system_mean_scale.reshape(self.K_net*self.n_w) ) )
            # Posterior GP (demeaned) #
            pyro.sample( f"gp_system_demean",
                                    dist.MultivariateNormal( self.gp_system_demean.reshape(self.K_net*self.n_w , self.T_net),
                                                            scale_tril=self.gp_system_cov_tril.reshape(self.K_net*self.n_w , self.T_net, self.T_net) ) )

        # Sampling coordinates #
        if self.coord:
            self.gp_coord_mean_loc = pyro.param("gp_coord_mean_loc", self.gp_coord_mean_ini )
            self.gp_coord_mean_scale = pyro.param("gp_coord_mean_scale", 0.1*torch.ones((self.V_net,self.H_dim,self.K_net,self.n_dir,self.n_w)), constraint=constraints.positive)
            self.gp_coord_demean = pyro.param( f"gp_coord_demean_loc", self.gp_coord_demean_ini )
            # Posterior Covariance of the GP
            self.gp_coord_cov_tril = pyro.param( "gp_coord_cov_tril", self.Lff_ini.expand(self.V_net,self.H_dim,self.K_net,self.n_dir,self.n_w,self.T_net,self.T_net),
                                            constraint=constraints.lower_cholesky )
            with pyro.plate('gp_coord_all', self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w ):
                # Posterior GP (mean function params) #
                pyro.sample( "gp_coord_mean", dist.Normal( self.gp_coord_mean_loc.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w),
                                                    self.gp_coord_mean_scale.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w) ) )
                # Posterior GP (demeaned) #
                pyro.sample( f"gp_coord_demean",
                                        dist.MultivariateNormal( self.gp_coord_demean.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w , self.T_net),
                                                                scale_tril=self.gp_coord_cov_tril.reshape(self.V_net*self.H_dim*self.K_net*self.n_dir*self.n_w , self.T_net, self.T_net) ) )

        # Sampling sociability and popularity #
        if self.socpop:
            self.gp_socpop_mean_loc = pyro.param("gp_socpop_mean_loc", self.gp_socpop_mean_ini )
            self.gp_socpop_mean_scale = pyro.param("gp_socpop_mean_scale", 0.1*torch.ones((self.V_net,self.K_net,self.n_dir,self.n_w)), constraint=constraints.positive)
            self.gp_socpop_demean = pyro.param( f"gp_socpop_demean_loc", self.gp_socpop_demean_ini )
            # Posterior Covariance of the GP
            self.gp_socpop_cov_tril = pyro.param( "gp_socpop_cov_tril", self.Lff_ini.expand(self.V_net,self.K_net,self.n_dir,self.n_w,self.T_net,self.T_net),
                                            constraint=constraints.lower_cholesky )
            with pyro.plate('gp_socpop_all', self.V_net*self.K_net*self.n_dir*self.n_w ):
                # Posterior GP (mean function params) #
                pyro.sample( "gp_socpop_mean", dist.Normal( self.gp_socpop_mean_loc.reshape(self.V_net*self.K_net*self.n_dir*self.n_w),
                                                    self.gp_socpop_mean_scale.reshape(self.V_net*self.K_net*self.n_dir*self.n_w) ) )
                # Posterior GP (demeaned) #
                pyro.sample( f"gp_socpop_demean",
                                        dist.MultivariateNormal( self.gp_socpop_demean.reshape(self.V_net*self.K_net*self.n_dir*self.n_w , self.T_net),
                                                                scale_tril=self.gp_socpop_cov_tril.reshape(self.V_net*self.K_net*self.n_dir*self.n_w , self.T_net, self.T_net) ) )

        # pyro.sample( "kernel.variance", dist.InverseGamma( self.kernel_param[1,0], self.kernel_param[1,1] ) )

        # Sampling variance of weights
        if self.weighted:
            self.sigma_k_post_loc = pyro.param("sigma_k_post_loc", torch.ones([1]), constraint=constraints.positive )
            self.sigma_k_post_scale = pyro.param("sigma_k_post_scale", torch.ones([1]), constraint=constraints.positive )
            with pyro.plate( "sigma_k_ind", self.K_net):
                sigma_k = pyro.sample( 'sigma_k', dist.InverseGamma( self.sigma_k_post_loc, self.sigma_k_post_scale ) )
Example #48
 def guide():
     p = pyro.param("p", Variable(torch.ones(1), requires_grad=True))
     pyro.sample("mu_q", dist.normal, ng_zeros(1), p)
     pyro.sample("mu_q_2", dist.normal, ng_zeros(1), p)
Example #49
def wrapped_model(x_data, y_data):
    pyro.sample("prediction", Delta(model(x_data, y_data)))
Example #50
 def model_obs_dup():
     pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
     pyro.observe("mu_q", dist.normal, ng_zeros(1), ng_ones(1), ng_zeros(1))
Example #51
 def model():
     a = pyro.sample("a", dist.Dirichlet(torch.ones(3)))
     b = pyro.sample("b", dist.Categorical(a))
     c = pyro.sample("c", dist.Normal(torch.zeros(3), 1).to_event(1))
     d = pyro.sample("d", dist.Poisson(c[b].exp()))
     pyro.sample("e", dist.Normal(d, 1), obs=torch.ones(()))
Example #52
def make_normal_normal():
    mu_latent = pyro.sample("mu_latent", pyro.distributions.Normal(0, 1))
    fn = lambda scale: normal_product(mu_latent, scale)
    return fn
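The returned closure shares a single draw of mu_latent across calls; a quick driver, assuming normal_product from Example #42 is in scope:

import torch

fn = make_normal_normal()    # mu_latent is drawn once, here
y1 = fn(torch.tensor(1.0))   # product of two Normal(mu_latent, 1.0) draws
y2 = fn(torch.tensor(0.1))   # same mu_latent, tighter scale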
Example #53
 def model():
     pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
Example #54
 def model():
     a = pyro.sample("a", dist.Normal(0, 1))
     pyro.factor("b", torch.tensor(0.0))
     pyro.factor("c", a)
Example #55
 def model_dup():
     pyro.param("mu_q", Variable(torch.ones(1), requires_grad=True))
     pyro.sample("mu_q", dist.normal, ng_zeros(1), ng_ones(1))
Example #56
 def model_3():
     with pyro.plate("p", 5):
         a = pyro.sample("a", dist.Normal(0, 1))
     pyro.sample("b", dist.Normal(a.sum(), 1), obs=torch.tensor(0.0))
Example #57
 def model():
     lambda_latent = pyro.sample("lambda_latent", dist.gamma, self.alpha0, self.beta0)
     pyro.observe("obs0", dist.exponential, self.data[0], lambda_latent)
     pyro.observe("obs1", dist.exponential, self.data[1], lambda_latent)
     return lambda_latent
Example #58
 def model_2():
     a = pyro.sample("a", dist.Normal(0, 1))
     b = pyro.sample("b", dist.LogNormal(0, 1))
     c = pyro.sample("c", dist.Normal(a, b))
     pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.0))
Example #59
def init_params(data):
    params = {}
    params["sigma_a1"] = pyro.sample("sigma_a1", dist.HalfCauchy(2.5))
    params["sigma_a2"] = pyro.sample("sigma_a2", dist.HalfCauchy(2.5))
    params["sigma_y"] = pyro.sample("sigma_y", dist.HalfCauchy(2.5))
    return params
Example #60
def ice_cream_sales():
    cloudy, temp = weather()
    expected_sales = 200. if cloudy == 'sunny' and temp > 80.0 else 50.
    ice_cream = pyro.sample('ice_cream', pyro.distributions.Normal(expected_sales, 10.0))
    return ice_cream
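ice_cream_sales assumes a weather() helper that returns a 'sunny'/'cloudy' label and a temperature; in the Pyro intro tutorial it looks roughly like this:

def weather():
    cloudy = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))
    cloudy = 'cloudy' if cloudy.item() == 1.0 else 'sunny'
    mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
    scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
    temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
    return cloudy, temp.item()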