Example #1
def model(doc_word_data=None, category_data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        # topic_weights is not part of the standard LDA plate notation; it
        # appears to encode the relative importance of topics (possibly following the amortized LDA paper).
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    with pyro.plate("categories", args.num_categories):
        category_weights = pyro.sample(
            "category_weights", dist.Gamma(1. / args.num_categories, 1.))
        # TODO category weights might not be necessary in our model
        category_topics = pyro.sample("category_topics",
                                      dist.Dirichlet(topic_weights))

    doc_category_list = []
    doc_word_list = []

    # Locals.
    for index, doc in enumerate(pyro.plate("documents", args.num_docs)):
        if doc_word_data is not None:
            cur_doc_word_data = doc_word_data[doc]
        else:
            cur_doc_word_data = None

        if category_data is not None:
            cur_category_data = category_data[doc]
        else:
            cur_category_data = None

        doc_category_list.append(
            pyro.sample("doc_categories_{}".format(doc),
                        dist.Categorical(category_weights),
                        obs=cur_category_data))

        with pyro.plate("words_{}".format(doc), args.num_words_per_doc[doc]):
            word_topics = pyro.sample(
                "word_topics_{}".format(doc),
                dist.Categorical(category_topics[int(
                    doc_category_list[index].item())]))
            # TODO Enum parallel/sequential optimizing?

            doc_word_list.append(
                pyro.sample("doc_words_{}".format(doc),
                            dist.Categorical(topic_words[word_topics]),
                            obs=cur_doc_word_data))

    results = {
        "topic_weights": topic_weights,
        "topic_words": topic_words,
        "doc_word_data": doc_word_list,
        "category_weights": category_weights,
        "category_topics": category_topics,
        "doc_category_data": doc_category_list
    }

    return results
Example #2
def dp_sb_gmm(y, num_components):
    # Constants
    N = y.shape[0]
    K = num_components

    # Priors
    # NOTE: In Pyro, the Gamma distribution is parameterized with shape and rate.
    # Hence, Gamma(shape, rate) => mean = shape/rate
    alpha = pyro.sample('alpha', dist.Gamma(1, 10))

    with pyro.plate('mixture_weights', K - 1):
        v = pyro.sample('v', dist.Beta(1, alpha))

    eta = stickbreak(v)

    with pyro.plate('components', K):
        mu = pyro.sample('mu', dist.Normal(0., 3.))
        sigma = pyro.sample('sigma', dist.Gamma(1, 10))

    with pyro.plate('data', N):
        # Mixture version.
        pyro.sample('obs',
                    dist.MixtureSameFamily(dist.Categorical(eta),
                                           dist.Normal(mu, sigma)),
                    obs=y)
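The stickbreak helper called above is not included in the snippet. A minimal sketch of the standard stick-breaking construction it presumably implements, mapping the K - 1 Beta draws v to K mixture weights:

import torch

def stickbreak(v):
    # Assumed implementation: weight_k = v_k * prod_{j<k} (1 - v_j),
    # with the last component taking the remainder of the stick.
    one = torch.ones(v.shape[:-1] + (1,))
    c1 = torch.cat([v, one], dim=-1)                    # v_1, ..., v_{K-1}, 1
    c0 = torch.cat([one, (1 - v).cumprod(-1)], dim=-1)  # 1, (1-v_1), (1-v_1)(1-v_2), ...
    return c1 * c0                                      # sums to 1 along the last dim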
Example #3
def logistic_regression_model(x, y, x_a0, x_a1, y_a0_vs_a1, M, beta, tau):
    # beta is preference observation "inverse temperature"
    if isinstance(tau, tuple):
        tau_ = pyro.sample("tau", dist.Gamma(tau[0], tau[1]))
    else:
        tau_ = tau
    if isinstance(beta, tuple):
        beta_ = pyro.sample("beta", dist.Gamma(beta[0], beta[1]))
    else:
        beta_ = beta

    w_ = pyro.sample(
        "w",
        dist.Normal(torch.zeros(M, dtype=torch.double), tau_).to_event(1))
    if y.size()[0] > 0:
        # direct observations
        probs = x @ w_
        pyro.sample("y", dist.Bernoulli(logits=probs), obs=y)
    if y_a0_vs_a1.size()[0] > 0:
        # pairwise preference observations
        prob_a0_vs_a1 = (x_a1 - x_a0) @ (beta_ * w_)

        pyro.sample("y_a0_vs_a1",
                    dist.Bernoulli(logits=prob_a0_vs_a1),
                    obs=y_a0_vs_a1)

    return w_, tau_, beta_
Example #4
def guide(y, BATCHES, SAMPLES):
    arg_1 = pyro.param('arg_1', torch.ones(1))
    arg_2 = pyro.param('arg_2',
                       torch.ones(1),
                       constraint=constraints.positive)
    theta = pyro.sample('theta', dist.Normal(arg_1, arg_2))
    arg_3 = pyro.param('arg_3',
                       torch.ones(1),
                       constraint=constraints.positive)
    arg_4 = pyro.param('arg_4',
                       torch.ones(1),
                       constraint=constraints.positive)
    tau_between = pyro.sample('tau_between', dist.Gamma(arg_3, arg_4))
    arg_5 = pyro.param('arg_5',
                       torch.ones(1),
                       constraint=constraints.positive)
    arg_6 = pyro.param('arg_6',
                       torch.ones(1),
                       constraint=constraints.positive)
    tau_within = pyro.sample('tau_within', dist.Gamma(arg_5, arg_6))
    arg_7 = pyro.param('arg_7',
                       torch.ones(BATCHES),
                       constraint=constraints.positive)
    arg_8 = pyro.param('arg_8',
                       torch.ones(BATCHES),
                       constraint=constraints.positive)
    with pyro.plate('mu_plate', BATCHES):
        mu = pyro.sample('mu', dist.Gamma(arg_7, arg_8))
Example #5
 def model_hierarchical(self, models, items, obs):
     mu_b = pyro.sample(
         'mu_b',
         dist.Normal(torch.tensor(0., device=self.device),
                     torch.tensor(1.e6, device=self.device)))
     u_b = pyro.sample(
         'u_b',
         dist.Gamma(torch.tensor(1., device=self.device),
                    torch.tensor(1., device=self.device)))
     mu_theta = pyro.sample(
         'mu_theta',
         dist.Normal(torch.tensor(0., device=self.device),
                     torch.tensor(1.e6, device=self.device)))
     u_theta = pyro.sample(
         'u_theta',
         dist.Gamma(torch.tensor(1., device=self.device),
                    torch.tensor(1., device=self.device)))
     mu_a = pyro.sample(
         'mu_a',
         dist.Normal(torch.tensor(0., device=self.device),
                     torch.tensor(1.e6, device=self.device)))
     u_a = pyro.sample(
         'u_a',
         dist.Gamma(torch.tensor(1., device=self.device),
                    torch.tensor(1., device=self.device)))
     with pyro.plate('thetas', self.num_models, device=self.device):
         ability = pyro.sample('theta', dist.Normal(mu_theta, 1. / u_theta))
     with pyro.plate('bs', self.num_items, device=self.device):
         diff = pyro.sample('b', dist.Normal(mu_b, 1. / u_b))
         slope = pyro.sample('a', dist.Normal(mu_a, 1. / u_a))
     with pyro.plate('observe_data', obs.size(0)):
         pyro.sample("obs",
                     dist.Bernoulli(logits=slope[items] *
                                    (ability[models] - diff[items])),
                     obs=obs)
Example #6
def logistic_regression_mixture_obs_model_mla(x, y, x_arms, P, M, beta, tau,
                                              N_lookahead):
    # assumes x and y have been augmented with N_lookahead samples for the
    # lookahead which is always a direct observation

    # beta is preference observation "inverse temperature"
    if isinstance(tau, tuple):
        tau_ = pyro.sample("tau", dist.Gamma(tau[0], tau[1]))
    else:
        tau_ = tau
    if isinstance(beta, tuple):
        beta_ = pyro.sample("beta", dist.Gamma(beta[0], beta[1]))
    else:
        beta_ = beta

    w_ = pyro.sample(
        "w",
        dist.Normal(torch.zeros(M, dtype=torch.double), tau_).to_event(1))
    a_ = pyro.sample(
        "alpha",
        dist.Beta(torch.tensor(1.0, dtype=torch.double),
                  torch.tensor(1.0, dtype=torch.double)),
    )
    logits = x @ w_

    N = y.numel()
    N_not_la = N - N_lookahead
    if N_not_la > 0:
        # multistep lookahead observations
        with torch.no_grad():
            x_a0 = x_arms.new_zeros(N_not_la, M)
            x_a1 = x_arms.new_zeros(N_not_la, M)

            inds = list(range(P[0].size()[0]))
            n_branching = len(inds) // 2
            inds_0 = inds[0:n_branching]
            inds_1 = inds[n_branching:len(inds)]

            logits_tmp = x_arms @ w_

            for i in range(len(P)):
                i0 = torch.argmax(P[i][inds_0, ] @ logits_tmp)
                i1 = torch.argmax(P[i][inds_1, ] @ logits_tmp)

                x_a0[i, ] = P[i][inds_0[i0], ] @ x_arms
                x_a1[i, ] = P[i][inds_1[i1], ] @ x_arms

        logits_a0_vs_a1 = (x_a1 - x_a0) @ (beta_ * w_)

        pyro.sample(
            "y",
            MixtureObsDistribution(a_, logits[0:N_not_la], logits_a0_vs_a1),
            obs=y[0:N_not_la],
        )
    if N_lookahead > 0:
        pyro.sample("y_lookahead",
                    dist.Bernoulli(logits=logits[N_not_la:N]),
                    obs=y[N_not_la:N])

    return w_, tau_, beta_
Example #7
def parametrized_guide(doc_word_data, category_data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param("topic_weights_posterior",
                                         lambda: torch.ones(args.num_topics),
                                         constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    category_weights_posterior = pyro.param(
        "category_weights_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    category_topics_posterior = pyro.param(
        "category_topics_posterior",
        lambda: torch.ones(args.num_categories, args.num_topics),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("categories", args.num_categories):
        pyro.sample("category_weights",
                    dist.Gamma(category_weights_posterior, 1.))
        pyro.sample("category_topics",
                    dist.Dirichlet(category_topics_posterior))

    doc_category_posterior = pyro.param(
        "doc_category_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        pyro.sample("doc_categories", dist.Categorical(doc_category_posterior))
Example #8
    def model(self, X, Y=None):
        '''Emax (Hill-equation) dose-response model with sampled priors.'''
        E0_mean, E0_std, alpha_emax, beta_emax, alpha_H, beta_H, log10_ec50_mean, log10_ec50_std, alpha_obs, beta_obs = self.get_priors(
        )

        E0 = pyro.sample('E0', dist.Normal(E0_mean, E0_std))

        Emax = pyro.sample('Emax', dist.Beta(alpha_emax, beta_emax))

        H = pyro.sample('H', dist.Gamma(alpha_H, beta_H))

        EC50 = 10**pyro.sample('log_EC50',
                               dist.Normal(log10_ec50_mean, log10_ec50_std))

        obs_sigma = pyro.sample("obs_sigma", dist.Gamma(alpha_obs, beta_obs))

        obs_mean = E0 + (Emax - E0) / (1 + (EC50 / X)**H)

        with pyro.plate("data", X.shape[0]):
            obs = pyro.sample("obs",
                              dist.Normal(obs_mean.squeeze(-1), obs_sigma),
                              obs=Y)

        return obs_mean
Example #9
 def model_hierarchical(self, models, items, obs):
     """Initialize a 1PL model with hierarchical priors"""
     mu_b = pyro.sample(
         "mu_b",
         dist.Normal(torch.tensor(0.0, device=self.device),
                     torch.tensor(1.0e6, device=self.device)),
     )
     u_b = pyro.sample(
         "u_b",
         dist.Gamma(torch.tensor(1.0, device=self.device),
                    torch.tensor(1.0, device=self.device)),
     )
     mu_theta = pyro.sample(
         "mu_theta",
         dist.Normal(torch.tensor(0.0, device=self.device),
                     torch.tensor(1.0e6, device=self.device)),
     )
     u_theta = pyro.sample(
         "u_theta",
         dist.Gamma(torch.tensor(1.0, device=self.device),
                    torch.tensor(1.0, device=self.device)),
     )
     with pyro.plate("thetas", self.num_subjects, device=self.device):
         ability = pyro.sample("theta", dist.Normal(mu_theta,
                                                    1.0 / u_theta))
     with pyro.plate("bs", self.num_items, device=self.device):
         diff = pyro.sample("b", dist.Normal(mu_b, 1.0 / u_b))
     with pyro.plate("observe_data", obs.size(0)):
         pyro.sample("obs",
                     dist.Bernoulli(logits=ability[models] - diff[items]),
                     obs=obs)
Example #10
    def model(self, *args, **kwargs):
        I, N = self.params['data'].shape
        weights = pyro.sample('mixture_weights',
                              dist.Dirichlet(self.params['mixture']))
        with pyro.plate('segments', I):
            mu = pyro.sample(
                'gene_basal',
                dist.Gamma(self.params['theta_scale'],
                           self.params['theta_rate']))
            with pyro.plate('components', self.params['K']):
                cc = pyro.sample(
                    'cnv_probs',
                    dist.LogNormal(np.log(self.params['cnv_mean']),
                                   self.params['cnv_var']))

        with pyro.plate('data', N, self.params['batch_size']):
            assignment = pyro.sample('assignment',
                                     dist.Categorical(weights),
                                     infer={"enumerate": "parallel"})
            theta = pyro.sample(
                'norm_factor',
                dist.Gamma(self.params['theta_scale'],
                           self.params['theta_rate']))
            for i in pyro.plate('segments2', I):
                pyro.sample(
                    'obs_{}'.format(i),
                    dist.Poisson((Vindex(cc)[assignment, i] * theta * mu[i]) +
                                 1e-8),
                    obs=self.params['data'][i, :])
Example #11
def model_multi_obs_dim(obsmat):
    num_topics = tm.K
    nparticipants = data.shape[0]
    nfeatures = data.shape[1]  # number of rows in each person's matrix
    ncol = data.shape[2]

    # This is a reasonable prior for dirichlet concentrations
    gamma_prior = dist.Gamma(2 * torch.ones(nfeatures, ncol),
                             1 / 3 * torch.ones(nfeatures, ncol)).to_event(2)

    with pyro.plate('topic', num_topics):
        # sample a weight and value for each topic
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / num_topics, 1.))
        topic_a = pyro.sample("topic_a", gamma_prior)
        topic_b = pyro.sample("topic_b", gamma_prior)

    # sample new participant's idiosyncratic topic mixture
    participant_topics = pyro.sample("new_participant_topic",
                                     dist.Dirichlet(topic_weights))

    # we parallelize over the possible topics and pyro automatically weights them by their probs
    transition_topics = pyro.sample("new_transition_topic",
                                    dist.Categorical(participant_topics),
                                    infer={"enumerate": "parallel"})

    # expand assignment to make dimensions match
    for r in np.arange(obsmat.shape[0]):
        rowind = obsmat[r, 1].type(torch.long)
        colind = obsmat[r, 2].type(torch.long)
        d = dist.Beta(topic_a[transition_topics, rowind, colind],
                      topic_b[transition_topics, rowind, colind])
        pyro.sample('obs_{}'.format(r), d, obs=obsmat[r, 0])
Example #12
    def guide_hierarchical(self, models, items, obs):
        loc_mu_b_param = pyro.param('loc_mu_b',
                                    torch.tensor(0., device=self.device))
        scale_mu_b_param = pyro.param('scale_mu_b',
                                      torch.tensor(1.e2, device=self.device),
                                      constraint=constraints.positive)
        loc_mu_theta_param = pyro.param('loc_mu_theta',
                                        torch.tensor(0., device=self.device))
        scale_mu_theta_param = pyro.param('scale_mu_theta',
                                          torch.tensor(1.e2, device=self.device),
                                          constraint=constraints.positive)
        alpha_b_param = pyro.param('alpha_b',
                                   torch.tensor(1., device=self.device),
                                   constraint=constraints.positive)
        beta_b_param = pyro.param('beta_b',
                                  torch.tensor(1., device=self.device),
                                  constraint=constraints.positive)
        alpha_theta_param = pyro.param('alpha_theta',
                                       torch.tensor(1., device=self.device),
                                       constraint=constraints.positive)
        beta_theta_param = pyro.param('beta_theta',
                                      torch.tensor(1., device=self.device),
                                      constraint=constraints.positive)
        m_theta_param = pyro.param(
            'loc_ability', torch.zeros(self.num_models, device=self.device))
        s_theta_param = pyro.param(
            'scale_ability',
            torch.ones(self.num_models, device=self.device),
            constraint=constraints.positive)
        m_b_param = pyro.param(
            'loc_diff', torch.zeros(self.num_items, device=self.device))
        s_b_param = pyro.param(
            'scale_diff',
            torch.ones(self.num_items, device=self.device),
            constraint=constraints.positive)

        # sample statements
        pyro.sample('mu_b', dist.Normal(loc_mu_b_param, scale_mu_b_param))
        pyro.sample('u_b', dist.Gamma(alpha_b_param, beta_b_param))
        pyro.sample('mu_theta',
                    dist.Normal(loc_mu_theta_param, scale_mu_theta_param))
        pyro.sample('u_theta', dist.Gamma(alpha_theta_param, beta_theta_param))
        with pyro.plate('thetas', self.num_models, device=self.device):
            pyro.sample('theta', dist.Normal(m_theta_param, s_theta_param))
        with pyro.plate('bs', self.num_items, device=self.device):
            pyro.sample('b', dist.Normal(m_b_param, s_b_param))
Example #13
def model(y, BATCHES, SAMPLES):
    theta = pyro.sample(
        'theta',
        dist.Normal(torch.zeros(1), 100000.0 * torch.ones(1)))
    tau_between = pyro.sample(
        'tau_between',
        dist.Gamma(0.001 * torch.ones(1), 0.001 * torch.ones(1)))
    tau_within = pyro.sample(
        'tau_within',
        dist.Gamma(0.001 * torch.ones(1), 0.001 * torch.ones(1)))
    sigma_between = 1 / torch.sqrt(tau_between)
    sigma_within = 1 / torch.sqrt(tau_within)
    with pyro.plate('mu_plate', BATCHES):
        mu = pyro.sample(
            'mu',
            dist.Normal(theta * torch.ones(BATCHES),
                        sigma_between * torch.ones(BATCHES)))
    for n in range(1, BATCHES + 1):
        pyro.sample('obs_{}'.format(n),
                    dist.Normal(mu[n - 1], sigma_within),
                    obs=y[n - 1])
Example #14
    def _param_map_estimates(
            self, data: torch.Tensor,
            chi_ambient: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Calculate MAP estimates of mu, the mean of the true count matrix, and
        lambda, the rate parameter of the Poisson background counts.

        Args:
            data: Dense tensor minibatch of cell by gene count data.
            chi_ambient: Point estimate of inferred ambient gene expression.

        Returns:
            mu_map: Dense tensor of Negative Binomial means for true counts.
            lambda_map: Dense tensor of Poisson rate params for noise counts.
            alpha_map: Dense tensor of Dirichlet concentration params that
                inform the overdispersion of the Negative Binomial.

        """

        # Encode latents.
        enc = self.vi_model.encoder.forward(x=data, chi_ambient=chi_ambient)
        z_map = enc['z']['loc']

        chi_map = self.vi_model.decoder.forward(z_map)
        phi_loc = pyro.param('phi_loc')
        phi_scale = pyro.param('phi_scale')
        phi_conc = phi_loc.pow(2) / phi_scale.pow(2)
        phi_rate = phi_loc / phi_scale.pow(2)
        alpha_map = 1. / dist.Gamma(phi_conc, phi_rate).mean

        y = (enc['p_y'] > 0).float()
        d_empty = dist.LogNormal(loc=pyro.param('d_empty_loc'),
                                 scale=pyro.param('d_empty_scale')).mean
        d_cell = dist.LogNormal(loc=enc['d_loc'],
                                scale=pyro.param('d_cell_scale')).mean
        epsilon = dist.Gamma(enc['epsilon'] * self.vi_model.epsilon_prior,
                             self.vi_model.epsilon_prior).mean

        if self.vi_model.include_rho:
            rho = pyro.param("rho_alpha") / (pyro.param("rho_alpha") +
                                             pyro.param("rho_beta"))
        else:
            rho = None

        # Calculate MAP estimates of mu and lambda.
        mu_map = self.vi_model.calculate_mu(epsilon=epsilon,
                                            d_cell=d_cell,
                                            chi=chi_map,
                                            y=y,
                                            rho=rho)
        lambda_map = self.vi_model.calculate_lambda(
            epsilon=epsilon,
            chi_ambient=chi_ambient,
            d_empty=d_empty,
            y=y,
            d_cell=d_cell,
            rho=rho,
            chi_bar=self.vi_model.avg_gene_expression)

        return {'mu': mu_map, 'lam': lambda_map, 'alpha': alpha_map}
Example #15
 def model(data):
     alpha_prior = pyro.sample("alpha", dist.Gamma(concentration=1.0, rate=1.0))
     beta_prior = pyro.sample("beta", dist.Gamma(concentration=1.0, rate=1.0))
     pyro.sample(
         "x",
         dist.Beta(concentration1=alpha_prior, concentration0=beta_prior),
         obs=data,
     )
Example #16
 def model(data):
     alpha_prior = pyro.sample('alpha', dist.Gamma(concentration=1.,
                                                   rate=1.))
     beta_prior = pyro.sample('beta', dist.Gamma(concentration=1., rate=1.))
     pyro.sample('x',
                 dist.Beta(concentration1=alpha_prior,
                           concentration0=beta_prior),
                 obs=data)
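Examples #15 and #16 are the same Beta likelihood with Gamma priors on both concentration parameters. A minimal sketch of running this model with NUTS on synthetic data (sample counts and data are placeholders):

import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

data = dist.Beta(2.0, 5.0).sample((100,))  # synthetic observations in (0, 1)
mcmc = MCMC(NUTS(model), num_samples=500, warmup_steps=200)
mcmc.run(data)
posterior = mcmc.get_samples()  # dict holding the 'alpha' and 'beta' draws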
Example #17
    def __init__(self, ode_op, ode_model):
        super(SIRGenModel, self).__init__()
        self._ode_op = ode_op
        self._ode_model = ode_model

        self.ode_params1 = PyroSample(dist.Gamma(2, 1))
        self.ode_params2 = PyroSample(dist.Gamma(2, 1))
        self.ode_params3 = PyroSample(dist.Beta(0.5, 0.5))
Example #18
def model():
    # Learn the MAP of this, starting from this prior distribution
    log_mean = pyro.sample('average_price_log_mean',
                           dist.Gamma(torch.tensor(7.5), torch.tensor(1.)))
    log_var = pyro.sample('average_price_log_var',
                          dist.Gamma(torch.tensor(7.5), torch.tensor(1.)))

    pyro.sample('average_price', dist.LogNormal(log_mean, log_var))
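Given the comment about learning a MAP estimate, one way to do that in Pyro is an AutoDelta guide, which fits a point mass to every latent. A sketch; the learning rate and step count are placeholders:

from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

guide = AutoDelta(model)  # point-mass guide, so SVI maximizes the joint density
svi = SVI(model, guide, Adam({"lr": 0.05}), loss=Trace_ELBO())
for step in range(500):
    svi.step()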
Example #19
def model(doc_word_data=None, category_data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        # topic_weights is not part of the standard LDA plate notation; it
        # appears to encode the relative importance of topics (possibly following the amortized LDA paper).
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    with pyro.plate("categories", args.num_categories):
        category_weights = pyro.sample(
            "category_weights", dist.Gamma(1. / args.num_categories, 1.))
        category_topics = pyro.sample("category_topics",
                                      dist.Dirichlet(topic_weights))

    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if doc_word_data is not None:
            with pyro.util.ignore_jit_warnings():
                assert doc_word_data.shape == (args.num_words_per_doc,
                                               args.num_docs
                                               )  # Forces the 64x1000 shape
            doc_word_data = doc_word_data[:, ind]

        if category_data is not None:
            category_data = category_data[ind]

        category_data = pyro.sample("doc_categories",
                                    dist.Categorical(category_weights),
                                    obs=category_data)

        with pyro.plate("words", args.num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics",
                                      dist.Categorical(
                                          category_topics[category_data]),
                                      infer={"enumerate": "parallel"})
            doc_word_data = pyro.sample("doc_words",
                                        dist.Categorical(
                                            topic_words[word_topics]),
                                        obs=doc_word_data)

    results = {
        "topic_weights": topic_weights,
        "topic_words": topic_words,
        "doc_word_data": doc_word_data,
        "category_weights": category_weights,
        "category_topics": category_topics,
        "category_data": category_data
    }

    return results
Example #20
 def __init__(self, ode_op, ode_model):
     super(PlantModel, self).__init__()
     self._ode_op = ode_op
     self._ode_model = ode_model
     # TODO: Incorporate appropriate priors (cf. MATLAB codes from Daewook)
     self.ode_params1 = PyroSample(dist.Gamma(1, 1000))  # dG
     self.ode_params2 = PyroSample(dist.Gamma(1, 1000))  # dP
     self.ode_params3 = PyroSample(dist.Beta(0.5, 0.5))  # G0
     self.ode_params4 = PyroSample(dist.Beta(0.5, 0.5))  # P0
Example #21
def model(data, params):
    # initialize data
    N = data["N"]
    x = data["x"]
    t = data["t"]

    alpha = pyro.sample("alpha", dist.Exponential(1.0))
    beta = pyro.sample("beta", dist.Gamma(0.1, 1.0))
    with pyro.plate('data', N):
        theta = pyro.sample("theta", dist.Gamma(alpha, beta))
        x = pyro.sample("x", dist.Poisson(theta * t), obs=x)
Example #22
def test_gamma_poisson(sample_shape, batch_shape):
    concentration = torch.randn(batch_shape).exp()
    rate = torch.randn(batch_shape).exp()
    nobs = 5
    obs = dist.Poisson(10.).sample((nobs,) + sample_shape + batch_shape).sum(0)

    f = dist.Gamma(concentration, rate)
    g = dist.Gamma(1 + obs, nobs)
    fg, log_normalizer = f.conjugate_update(g)

    x = fg.sample(sample_shape)
    assert_close(f.log_prob(x) + g.log_prob(x), fg.log_prob(x) + log_normalizer)
Example #23
    def guide_hierarchical(self, models, items, obs):
        """Initialize a 1PL guide with hierarchical priors"""
        loc_mu_b_param = pyro.param("loc_mu_b",
                                    torch.tensor(0.0, device=self.device))
        scale_mu_b_param = pyro.param("scale_mu_b",
                                      torch.tensor(1.0e2, device=self.device),
                                      constraint=constraints.positive)
        loc_mu_theta_param = pyro.param("loc_mu_theta",
                                        torch.tensor(0.0, device=self.device))
        scale_mu_theta_param = pyro.param(
            "scale_mu_theta",
            torch.tensor(1.0e2, device=self.device),
            constraint=constraints.positive,
        )
        alpha_b_param = pyro.param("alpha_b",
                                   torch.tensor(1.0, device=self.device),
                                   constraint=constraints.positive)
        beta_b_param = pyro.param("beta_b",
                                  torch.tensor(1.0, device=self.device),
                                  constraint=constraints.positive)
        alpha_theta_param = pyro.param("alpha_theta",
                                       torch.tensor(1.0, device=self.device),
                                       constraint=constraints.positive)
        beta_theta_param = pyro.param("beta_theta",
                                      torch.tensor(1.0, device=self.device),
                                      constraint=constraints.positive)
        m_theta_param = pyro.param(
            "loc_ability", torch.zeros(self.num_subjects, device=self.device))
        s_theta_param = pyro.param(
            "scale_ability",
            torch.ones(self.num_subjects, device=self.device),
            constraint=constraints.positive,
        )
        m_b_param = pyro.param("loc_diff",
                               torch.zeros(self.num_items, device=self.device))
        s_b_param = pyro.param(
            "scale_diff",
            torch.ones(self.num_items, device=self.device),
            constraint=constraints.positive,
        )

        # sample statements
        pyro.sample("mu_b", dist.Normal(loc_mu_b_param, scale_mu_b_param))
        pyro.sample("u_b", dist.Gamma(alpha_b_param, beta_b_param))
        pyro.sample("mu_theta",
                    dist.Normal(loc_mu_theta_param, scale_mu_theta_param))
        pyro.sample("u_theta", dist.Gamma(alpha_theta_param, beta_theta_param))

        with pyro.plate("thetas", self.num_subjects, device=self.device):
            pyro.sample("theta", dist.Normal(m_theta_param, s_theta_param))
        with pyro.plate("bs", self.num_items, device=self.device):
            pyro.sample("b", dist.Normal(m_b_param, s_b_param))
Example #24
 def __init__(self, input_dim, hidden_dim, output_dim):
   super().__init__()
   self.N = output_dim
   self.D = input_dim
   self.M = hidden_dim
   self.aW = pyro.nn.PyroParam(torch.tensor(1.0), constraint=constraints.positive)
   self.bW = pyro.nn.PyroParam(torch.tensor(1.0), constraint=constraints.positive)
   self.aH = pyro.nn.PyroParam(torch.tensor(1.0), constraint=constraints.positive)
   self.bH = pyro.nn.PyroParam(torch.tensor(1.0), constraint=constraints.positive)
   self.W = pyro.nn.PyroSample(lambda self: dist.Gamma(self.aW, self.bW).expand([self.D, self.M]).to_event(2))
   self.H = pyro.nn.PyroSample(lambda self: dist.Gamma(self.aH, self.bH).expand([self.M, self.N]).to_event(2))
   self.d_axis = pyro.plate("d_axis", self.D, dim=-2)
   self.n_axis = pyro.plate("n_axis", self.N, dim=-1)
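Only the constructor of this Gamma-prior factorization module is shown. A hypothetical forward pass consistent with the declared shapes and plates (not necessarily the author's actual method) could be:

import pyro
import pyro.distributions as dist

def forward(self, X=None):
    # Assumed likelihood: Poisson matrix factorization, X ~ Poisson(W @ H).
    rate = self.W @ self.H          # (D, M) @ (M, N) -> (D, N)
    with self.d_axis, self.n_axis:  # row plate at dim -2, column plate at dim -1
        return pyro.sample("X", dist.Poisson(rate), obs=X)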
Example #25
def guide(y, J, sigma):
    arg_1 = pyro.param('arg_1', torch.ones(1), constraint=constraints.positive)
    arg_2 = pyro.param('arg_2', torch.ones(1), constraint=constraints.positive)
    mu = pyro.sample('mu', dist.Gamma(arg_1, arg_2))
    arg_3 = pyro.param('arg_3', torch.ones(1), constraint=constraints.positive)
    arg_4 = pyro.param('arg_4', torch.ones(1), constraint=constraints.positive)
    tau = pyro.sample('tau', dist.Gamma(arg_3, arg_4))
    arg_5 = pyro.param('arg_5', torch.ones(J), constraint=constraints.positive)
    arg_6 = pyro.param('arg_6', torch.ones(J), constraint=constraints.positive)
    with pyro.plate('theta_plate', J):
        theta = pyro.sample('theta', dist.Gamma(arg_5, arg_6))
Example #26
def model_multi_obs_grp(obsmat):
    # some parameters can be directly derived from the data passed
    # K = 2
    nparticipants = data.shape[0]
    nfeatures = data.shape[1]  # number of rows in each person's matrix
    ncol = data.shape[2]

    # Background probability of different groups
    if tm.stickbreak:
        # stick breaking process for assigning weights to groups
        with pyro.plate("beta_plate", K - 1):
            beta_mix = pyro.sample("weights", dist.Beta(1, 10))
        weights = tm.mix_weights(beta_mix)
    else:
        weights = pyro.sample('weights',
                              dist.Dirichlet(0.5 * torch.ones(tm.K)))
    # declare model parameters based on whether the data are row-normalized
    if tm.dtype == 'norm':
        pass
#         with pyro.plate('components', K):
#             # concentration parameters
#             concentration = pyro.sample('concentration',
#                                         dist.Gamma(2 * torch.ones(nfeatures,ncol), 1/3 * torch.ones(nfeatures,ncol)).to_event(2))

#         # implementation for the dirichlet based model is not complete!!!!
#         with pyro.plat('data',obsmat.shape[0]):
#             assignment = pyro.sample('assignment', dist.Categorical(weights))
#             #d = dist.Dirichlet(concentration[assignment,:,:].clone().detach()) # .detach() might interfere with backprop
#             d = dist.Dirichlet(concentration[assignment,i,:])
#             pyro.sample('obs', d.to_event(1), obs=obsmat)

    elif tm.dtype == 'raw':
        with pyro.plate('components', tm.K):
            alphas = pyro.sample(
                'alpha',
                dist.Gamma(2 * torch.ones(nfeatures, ncol),
                           1 / 3 * torch.ones(nfeatures, ncol)).to_event(2))
            betas = pyro.sample(
                'beta',
                dist.Gamma(2 * torch.ones(nfeatures, ncol),
                           1 / 3 * torch.ones(nfeatures, ncol)).to_event(2))

        assignment = pyro.sample('assignment', dist.Categorical(weights))
        # expand assignment to make dimensions match
        for r in np.arange(obsmat.shape[0]):
            rowind = obsmat[r, 1].type(torch.long)
            colind = obsmat[r, 2].type(torch.long)
            d = dist.Beta(alphas[assignment, rowind, colind],
                          betas[assignment, rowind, colind])
            pyro.sample('obs_{}'.format(r), d, obs=obsmat[r, 0])
Example #27
 def setUp(self):
     self.alpha = torch.tensor([2.4])
     self.batch_alpha = torch.tensor([[2.4], [3.2]])
     self.batch_beta = torch.tensor([[np.sqrt(2.4)], [np.sqrt(3.2)]])
     self.beta = torch.tensor([np.sqrt(2.4)])
     self.test_data = torch.tensor([5.5])
     self.batch_test_data = torch.tensor([[5.5], [4.4]])
     self.dist = dist.Gamma(self.alpha, self.beta)
     self.batch_dist = dist.Gamma(self.batch_alpha, self.batch_beta)
     self.analytic_mean = (self.alpha / self.beta).item()
     self.analytic_var = (self.alpha / self.beta.pow(2.0)).item()
     self.n_samples = 50000
Example #28
def model(data, params):
    # initialize data
    BATCHES = data["BATCHES"]
    SAMPLES = data["SAMPLES"]
    y = data["y"]
    # model block
    theta = pyro.sample("theta", dist.Normal(0.0, 100000.0))
    tau_between = pyro.sample("tau_between", dist.Gamma(0.001, 0.001))
    sigma_between = 1 / tau_between.sqrt()
    tau_within = pyro.sample("tau_within", dist.Gamma(0.001, 0.001))
    sigma_within = 1 / tau_within.sqrt()
    with pyro.plate('batches', BATCHES, dim=-2):
        mu =  pyro.sample("mu", dist.Normal(theta, sigma_between))
        with pyro.plate('data', SAMPLES, dim=-1):
            y = pyro.sample('y', dist.Normal(mu, sigma_within), obs=y)
Example #29
def guide(nyear, C, nsite, year):
    arg_1 = pyro.param('arg_1', torch.ones(1))
    arg_2 = pyro.param('arg_2',
                       torch.ones(1),
                       constraint=constraints.positive)
    sd_year = pyro.sample('sd_year', dist.Normal(arg_1, arg_2))
    arg_3 = pyro.param('arg_3',
                       torch.ones(1),
                       constraint=constraints.positive)
    arg_4 = pyro.param('arg_4',
                       torch.ones(1),
                       constraint=constraints.positive)
    sd_alpha = pyro.sample('sd_alpha', dist.Gamma(arg_3, arg_4))
    arg_5 = pyro.param('arg_5',
                       torch.ones(1),
                       constraint=constraints.positive)
    arg_6 = pyro.param('arg_6',
                       torch.ones(1),
                       constraint=constraints.positive)
    mu = pyro.sample('mu', dist.Gamma(arg_5, arg_6))
    arg_7 = pyro.param('arg_7', torch.ones(nsite))
    arg_8 = pyro.param('arg_8',
                       torch.ones(nsite),
                       constraint=constraints.positive)
    with pyro.plate('alpha_plate', nsite):
        alpha = pyro.sample('alpha', dist.Normal(arg_7, arg_8))
    arg_9 = pyro.param('arg_9',
                       torch.ones(3),
                       constraint=constraints.positive)
    arg_10 = pyro.param('arg_10',
                        torch.ones(3),
                        constraint=constraints.positive)
    with pyro.plate('beta_plate', 3):
        beta = pyro.sample('beta', dist.Gamma(arg_9, arg_10))
    arg_11 = pyro.param('arg_11',
                        torch.ones(nyear),
                        constraint=constraints.positive)
    arg_12 = pyro.param('arg_12',
                        torch.ones(nyear),
                        constraint=constraints.positive)
    with pyro.plate('eps_plate', nyear):
        eps = pyro.sample('eps', dist.Gamma(arg_11, arg_12))
Example #30
def model(y, x, N):
    w = pyro.sample('w', dist.Beta(26.072914040168385 * torch.ones(1),
                                   42.3120851154 * torch.ones(1)))
    with pyro.plate('b_plate', N):
        b = pyro.sample('b', dist.Gamma(5.63887222899 * torch.ones(N),
                                        40.1978121928 * torch.ones(N)))
    with pyro.plate('p_plate', N):
        p = pyro.sample('p', dist.Beta(52.1419233118 * torch.ones(N),
                                       83.6618285099 * torch.ones(N)))
    pyro.sample('obs', dist.Beta(w * x + b, p), obs=y)