def model(doc_word_data=None, category_data=None, args=None, batch_size=None):
    """Category-aware LDA variant, generated one document at a time.

    Global latents: per-topic weights and word distributions, plus per-category
    weights and topic distributions. Each document then draws a category and,
    for every word position, a topic and an (optionally observed) word.

    Args:
        doc_word_data: Optional observed words, indexed by document index;
            ``None`` means the words are sampled.
        category_data: Optional observed categories, indexed by document
            index; ``None`` means categories are sampled.
        args: Namespace with ``num_topics``, ``num_words``, ``num_categories``,
            ``num_docs`` and ``num_words_per_doc`` (indexed per document).
        batch_size: Accepted for interface compatibility; unused here.

    Returns:
        dict with the global latents plus per-document lists of categories
        and words.
    """
    # Globals.
    with pyro.plate("topics", args.num_topics):
        # topic_weights is not part of the usual LDA plate notation; it acts
        # as a per-topic importance / Dirichlet concentration (possibly taken
        # from the amortized LDA formulation).
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        uniform_words = torch.ones(args.num_words) / args.num_words
        topic_words = pyro.sample("topic_words", dist.Dirichlet(uniform_words))

    with pyro.plate("categories", args.num_categories):
        # TODO category weights might not be necessary in our model
        category_weights = pyro.sample("category_weights", dist.Gamma(1. / args.num_categories, 1.))
        category_topics = pyro.sample("category_topics", dist.Dirichlet(topic_weights))

    categories_per_doc = []
    words_per_doc = []

    # Locals: the documents plate is iterated sequentially, yielding 0..num_docs-1.
    for doc in pyro.plate("documents", args.num_docs):
        observed_words = None if doc_word_data is None else doc_word_data[doc]
        observed_category = None if category_data is None else category_data[doc]

        doc_category = pyro.sample("doc_categories_{}".format(doc),
                                   dist.Categorical(category_weights),
                                   obs=observed_category)
        categories_per_doc.append(doc_category)

        with pyro.plate("words_{}".format(doc), args.num_words_per_doc[doc]):
            topic_probs = category_topics[int(doc_category.item())]
            word_topics = pyro.sample("word_topics_{}".format(doc), dist.Categorical(topic_probs))
            # TODO Enum parallel/sequential optimizing?
            doc_words = pyro.sample("doc_words_{}".format(doc),
                                    dist.Categorical(topic_words[word_topics]),
                                    obs=observed_words)
            words_per_doc.append(doc_words)

    return {
        "topic_weights": topic_weights,
        "topic_words": topic_words,
        "doc_word_data": words_per_doc,
        "category_weights": category_weights,
        "category_topics": category_topics,
        "doc_category_data": categories_per_doc,
    }
def dp_sb_gmm(y, num_components):
    """Truncated Dirichlet-process Gaussian mixture via stick-breaking.

    Args:
        y: Observations; the first dimension indexes data points.
        num_components: Truncation level K (number of mixture components).
    """
    # Constants
    N = y.shape[0]
    K = num_components

    # Priors
    # NOTE: In Pyro the Gamma distribution is parameterized by (shape, rate),
    # so Gamma(shape, rate) has mean shape / rate.
    alpha = pyro.sample('alpha', dist.Gamma(1, 10))

    with pyro.plate('mixture_weights', K - 1):
        # The plate broadcasts this Beta to K - 1 sticks. Passing K - 1 as a
        # third positional argument would bind it to ``validate_args``.
        v = pyro.sample('v', dist.Beta(1, alpha))

    eta = stickbreak(v)

    with pyro.plate('components', K):
        mu = pyro.sample('mu', dist.Normal(0., 3.))
        sigma = pyro.sample('sigma', dist.Gamma(1, 10))

    with pyro.plate('data', N):
        # Mixture version: component assignments are marginalised analytically.
        pyro.sample('obs',
                    dist.MixtureSameFamily(dist.Categorical(eta), dist.Normal(mu, sigma)),
                    obs=y)
def logistic_regression_model(x, y, x_a0, x_a1, y_a0_vs_a1, M, beta, tau):
    """Bayesian logistic regression with direct and pairwise-preference observations.

    Args:
        x: Feature matrix for the direct observations.
        y: Binary direct observations; the likelihood is skipped when empty.
        x_a0, x_a1: Feature rows of the two alternatives in each comparison.
        y_a0_vs_a1: Binary preference outcomes; skipped when empty.
        M: Number of regression weights.
        beta: Preference-observation "inverse temperature": a fixed value or a
            ``(concentration, rate)`` tuple for a Gamma prior.
        tau: Prior scale of the weights: a fixed value or a
            ``(concentration, rate)`` tuple for a Gamma prior.

    Returns:
        Tuple ``(w, tau, beta)`` of the sampled or fixed values.
    """
    if isinstance(tau, tuple):
        tau_ = pyro.sample("tau", dist.Gamma(tau[0], tau[1]))
    else:
        tau_ = tau
    if isinstance(beta, tuple):
        beta_ = pyro.sample("beta", dist.Gamma(beta[0], beta[1]))
    else:
        beta_ = beta

    # ``to_event(1)`` is the supported replacement for the deprecated
    # ``independent(1)``: the M weights form a single event.
    w_ = pyro.sample(
        "w",
        dist.Normal(torch.zeros(M, dtype=torch.double), tau_).to_event(1))

    if y.size()[0] > 0:
        # direct observations
        logits = x @ w_
        pyro.sample("y", dist.Bernoulli(logits=logits), obs=y)

    if y_a0_vs_a1.size()[0] > 0:
        # pairwise preference observations
        logits_a0_vs_a1 = (x_a1 - x_a0) @ (beta_ * w_)
        pyro.sample("y_a0_vs_a1", dist.Bernoulli(logits=logits_a0_vs_a1), obs=y_a0_vs_a1)

    return w_, tau_, beta_
def guide(y, BATCHES, SAMPLES):
    """Mean-field guide for theta, tau_between, tau_within and per-batch mu.

    Every latent gets its own independent variational distribution with
    learnable parameters ``arg_1`` .. ``arg_8``. ``y`` and ``SAMPLES`` are not
    used. ``amb`` is a project helper, presumably mapping a size to a tensor
    shape -- confirm.
    """
    arg_1 = pyro.param('arg_1', torch.ones((amb(1))))
    arg_2 = pyro.param('arg_2', torch.ones((amb(1))), constraint=constraints.positive)
    pyro.sample('theta', dist.Normal(arg_1, arg_2))

    arg_3 = pyro.param('arg_3', torch.ones((amb(1))), constraint=constraints.positive)
    arg_4 = pyro.param('arg_4', torch.ones((amb(1))), constraint=constraints.positive)
    pyro.sample('tau_between', dist.Gamma(arg_3, arg_4))

    arg_5 = pyro.param('arg_5', torch.ones((amb(1))), constraint=constraints.positive)
    arg_6 = pyro.param('arg_6', torch.ones((amb(1))), constraint=constraints.positive)
    pyro.sample('tau_within', dist.Gamma(arg_5, arg_6))

    arg_7 = pyro.param('arg_7', torch.ones((amb(BATCHES))), constraint=constraints.positive)
    arg_8 = pyro.param('arg_8', torch.ones((amb(BATCHES))), constraint=constraints.positive)
    # NOTE(review): pyro.iarange is the legacy name of pyro.plate; a Gamma guide
    # for mu only covers positive values -- confirm this matches the model.
    with pyro.iarange('mu_prange'):
        pyro.sample('mu', dist.Gamma(arg_7, arg_8))
def model_hierarchical(self, models, items, obs):
    """2PL IRT model with hierarchical priors on ability, difficulty and slope.

    Args:
        models: Index of the subject (model) for each observation.
        items: Index of the item for each observation.
        obs: Binary responses, one per observation.
    """
    device = self.device

    def _const(value):
        return torch.tensor(value, device=device)

    def _wide_normal():
        # Very diffuse prior on the group-level means.
        return dist.Normal(_const(0.), _const(1.e6))

    def _unit_gamma():
        return dist.Gamma(_const(1.), _const(1.))

    mu_b = pyro.sample('mu_b', _wide_normal())
    u_b = pyro.sample('u_b', _unit_gamma())
    mu_theta = pyro.sample('mu_theta', _wide_normal())
    u_theta = pyro.sample('u_theta', _unit_gamma())
    mu_a = pyro.sample('mu_a', _wide_normal())
    u_a = pyro.sample('u_a', _unit_gamma())

    # The u_* values enter as 1 / u, i.e. they act as inverse scales.
    with pyro.plate('thetas', self.num_models, device=device):
        ability = pyro.sample('theta', dist.Normal(mu_theta, 1. / u_theta))

    with pyro.plate('bs', self.num_items, device=device):
        diff = pyro.sample('b', dist.Normal(mu_b, 1. / u_b))
        slope = pyro.sample('a', dist.Normal(mu_a, 1. / u_a))

    logits = slope[items] * (ability[models] - diff[items])
    with pyro.plate('observe_data', obs.size(0)):
        pyro.sample("obs", dist.Bernoulli(logits=logits), obs=obs)
def logistic_regression_mixture_obs_model_mla(x, y, x_arms, P, M, beta, tau, N_lookahead):
    """Logistic regression with mixture observations and multistep lookahead.

    Assumes ``x`` and ``y`` have been augmented with ``N_lookahead`` trailing
    samples for the lookahead, which are always direct observations.

    Args:
        x: Feature rows for all observations.
        y: Observations; ``N = y.numel()``.
        x_arms: Feature rows of the candidate arms.
        P: Sequence of matrices; rows of ``P[i]`` are combined with ``x_arms``
            to build the two compared alternatives for observation ``i``.
        M: Number of regression weights.
        beta: Preference "inverse temperature": a fixed value or a Gamma
            ``(concentration, rate)`` tuple.
        tau: Prior scale of the weights: a fixed value or a Gamma
            ``(concentration, rate)`` tuple.
        N_lookahead: Number of trailing lookahead (direct) observations.

    Returns:
        Tuple ``(w, tau, beta)``.
    """
    if isinstance(tau, tuple):
        tau_ = pyro.sample("tau", dist.Gamma(tau[0], tau[1]))
    else:
        tau_ = tau
    if isinstance(beta, tuple):
        beta_ = pyro.sample("beta", dist.Gamma(beta[0], beta[1]))
    else:
        beta_ = beta

    # ``to_event(1)`` replaces the deprecated ``independent(1)``.
    w_ = pyro.sample(
        "w",
        dist.Normal(torch.zeros(M, dtype=torch.double), tau_).to_event(1))
    # Mixing weight between the direct and pairwise observation models.
    a_ = pyro.sample(
        "alpha",
        dist.Beta(torch.tensor(1.0, dtype=torch.double),
                  torch.tensor(1.0, dtype=torch.double)),
    )

    logits = x @ w_
    N = y.numel()
    N_not_la = N - N_lookahead
    if N_not_la > 0:
        # multistep lookahead observations: the arm selection is not differentiated.
        with torch.no_grad():
            x_a0 = x_arms.new_zeros(N_not_la, M)
            x_a1 = x_arms.new_zeros(N_not_la, M)
            inds = list(range(P[0].size()[0]))
            n_branching = len(inds) // 2
            inds_0 = inds[0:n_branching]
            inds_1 = inds[n_branching:len(inds)]
            logits_tmp = x_arms @ w_
            # NOTE(review): assumes len(P) <= N_not_la -- confirm with callers.
            for i in range(len(P)):
                i0 = torch.argmax(P[i][inds_0, ] @ logits_tmp)
                i1 = torch.argmax(P[i][inds_1, ] @ logits_tmp)
                x_a0[i, ] = P[i][inds_0[i0], ] @ x_arms
                x_a1[i, ] = P[i][inds_1[i1], ] @ x_arms
        logits_a0_vs_a1 = (x_a1 - x_a0) @ (beta_ * w_)
        pyro.sample(
            "y",
            MixtureObsDistribution(a_, logits[0:N_not_la], logits_a0_vs_a1),
            obs=y[0:N_not_la],
        )
    if N_lookahead > 0:
        pyro.sample("y_lookahead", dist.Bernoulli(logits=logits[N_not_la:N]), obs=y[N_not_la:N])
    return w_, tau_, beta_
def parametrized_guide(doc_word_data, category_data, args, batch_size=None):
    """Conjugate mean-field guide for the category-aware LDA model.

    Args:
        doc_word_data: Unused by the guide.
        category_data: Observed document categories, or ``None``. When given,
            the model observes ``doc_categories``, so the guide does not sample it.
        args: Namespace with ``num_topics``, ``num_words``,
            ``num_categories`` and ``num_docs``.
        batch_size: Optional document subsample size for the "documents" plate.
    """
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    category_weights_posterior = pyro.param(
        "category_weights_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    category_topics_posterior = pyro.param(
        "category_topics_posterior",
        lambda: torch.ones(args.num_categories, args.num_topics),
        constraint=constraints.greater_than(0.5))
    with pyro.plate("categories", args.num_categories):
        pyro.sample("category_weights", dist.Gamma(category_weights_posterior, 1.))
        pyro.sample("category_topics", dist.Dirichlet(category_topics_posterior))

    # doc_categories ranges over categories, so the probabilities need one
    # entry per category (previously sized by num_topics).
    doc_category_posterior = pyro.param(
        "doc_category_posterior",
        lambda: torch.ones(args.num_categories),
        constraint=constraints.positive)
    with pyro.plate("documents", args.num_docs, batch_size):
        # An observed site must not be sampled by the guide.
        if category_data is None:
            pyro.sample("doc_categories", dist.Categorical(doc_category_posterior))
def model(self, X, Y=None):
    """Sigmoidal Emax dose-response model with priors from ``self.get_priors()``.

    Args:
        X: Doses; ``X.shape[0]`` is the number of observations.
        Y: Optional observed responses conditioned on the ``"obs"`` site.

    Returns:
        The mean response ``E0 + (Emax - E0) / (1 + (EC50 / X) ** H)``.
    """
    priors = self.get_priors()
    (E0_mean, E0_std,
     alpha_emax, beta_emax,
     alpha_H, beta_H,
     log10_ec50_mean, log10_ec50_std,
     alpha_obs, beta_obs) = priors

    E0 = pyro.sample('E0', dist.Normal(E0_mean, E0_std))
    Emax = pyro.sample('Emax', dist.Beta(alpha_emax, beta_emax))
    H = pyro.sample('H', dist.Gamma(alpha_H, beta_H))
    # EC50 is sampled on the log10 scale.
    log_ec50 = pyro.sample('log_EC50', dist.Normal(log10_ec50_mean, log10_ec50_std))
    EC50 = 10 ** log_ec50
    obs_sigma = pyro.sample("obs_sigma", dist.Gamma(alpha_obs, beta_obs))

    hill_term = (EC50 / X) ** H
    obs_mean = E0 + (Emax - E0) / (1 + hill_term)

    with pyro.plate("data", X.shape[0]):
        pyro.sample("obs", dist.Normal(obs_mean.squeeze(-1), obs_sigma), obs=Y)
    return obs_mean
def model_hierarchical(self, models, items, obs):
    """1PL IRT model: logit = ability[subject] - difficulty[item], with hierarchical priors.

    Args:
        models: Subject index for each observation.
        items: Item index for each observation.
        obs: Binary responses, one per observation.
    """
    dev = self.device
    zero, one, huge = (torch.tensor(v, device=dev) for v in (0.0, 1.0, 1.0e6))

    mu_b = pyro.sample("mu_b", dist.Normal(zero, huge))
    u_b = pyro.sample("u_b", dist.Gamma(one, one))
    mu_theta = pyro.sample("mu_theta", dist.Normal(zero, huge))
    u_theta = pyro.sample("u_theta", dist.Gamma(one, one))

    # u_* enter as 1 / u, i.e. they act as inverse scales.
    with pyro.plate("thetas", self.num_subjects, device=dev):
        ability = pyro.sample("theta", dist.Normal(mu_theta, 1.0 / u_theta))

    with pyro.plate("bs", self.num_items, device=dev):
        diff = pyro.sample("b", dist.Normal(mu_b, 1.0 / u_b))

    with pyro.plate("observe_data", obs.size(0)):
        pyro.sample("obs", dist.Bernoulli(logits=ability[models] - diff[items]), obs=obs)
def model(self, *args, **kwargs):
    """Mixture model over copy-number components for count data in ``self.params``.

    Reads ``self.params`` keys: 'data' (an I x N count tensor, unpacked below),
    'mixture', 'theta_scale', 'theta_rate', 'K', 'cnv_mean', 'cnv_var' and
    'batch_size'. Each of the N data columns gets a component ``assignment``
    (enumerated in parallel) and a normalisation factor ``theta``. Row i is
    then Poisson with rate ``cc[assignment, i] * theta * mu[i]``.
    """
    I, N = self.params['data'].shape
    weights = pyro.sample('mixture_weights', dist.Dirichlet(self.params['mixture']))
    with pyro.plate('segments', I):
        # Per-segment basal level.
        mu = pyro.sample(
            'gene_basal',
            dist.Gamma(self.params['theta_scale'], self.params['theta_rate']))
        with pyro.plate('components', self.params['K']):
            # Presumably shaped (K, I) given the Vindex indexing below -- confirm.
            # NOTE(review): 'cnv_var' is used as the LogNormal scale (a std), not a variance.
            cc = pyro.sample(
                'cnv_probs',
                dist.LogNormal(np.log(self.params['cnv_mean']), self.params['cnv_var']))
    with pyro.plate('data', N, self.params['batch_size']):
        # Discrete assignment is marginalised by parallel enumeration.
        assignment = pyro.sample('assignment', dist.Categorical(weights), infer={"enumerate": "parallel"})
        theta = pyro.sample(
            'norm_factor',
            dist.Gamma(self.params['theta_scale'], self.params['theta_rate']))
        for i in pyro.plate('segments2', I):
            # The 1e-8 offset keeps the Poisson rate strictly positive.
            # NOTE(review): obs uses all N columns; if batch_size < N the data plate
            # subsamples and this shape will not match -- confirm batch_size == N.
            pyro.sample(
                'obs_{}'.format(i),
                dist.Poisson((Vindex(cc)[assignment, i] * theta * mu[i]) + 1e-8),
                obs=self.params['data'][i, :])
def model_multi_obs_dim(obsmat):
    """Topic-mixture model over sparse Beta-distributed matrix entries.

    Reads the module-level ``tm`` (for ``tm.K``) and ``data`` (whose dims 1 and
    2 give the per-participant matrix shape).

    Args:
        obsmat: 2-D tensor whose rows are ``(value, row_index, col_index)``.
            Column 0 is observed through a Beta likelihood; columns 1 and 2
            index the per-topic parameter matrices.
    """
    num_topics = tm.K
    nfeatures = data.shape[1]  # number of rows in each person's matrix
    ncol = data.shape[2]

    # This is a reasonable prior for dirichlet concentrations
    gamma_prior = dist.Gamma(2 * torch.ones(nfeatures, ncol),
                             1 / 3 * torch.ones(nfeatures, ncol)).to_event(2)

    with pyro.plate('topic', num_topics):
        # sample a weight and Beta parameter matrices for each topic
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / num_topics, 1.))
        topic_a = pyro.sample("topic_a", gamma_prior)
        topic_b = pyro.sample("topic_b", gamma_prior)

    # sample new participant's idiosyncratic topic mixture
    participant_topics = pyro.sample("new_participant_topic", dist.Dirichlet(topic_weights))

    # parallel enumeration marginalises over topics, weighting each by its probability
    transition_topics = pyro.sample("new_transition_topic",
                                    dist.Categorical(participant_topics),
                                    infer={"enumerate": "parallel"})

    # Removed a leftover debug ``print`` that fired for every observation.
    for r in range(obsmat.shape[0]):
        rowind = obsmat[r, 1].type(torch.long)
        colind = obsmat[r, 2].type(torch.long)
        d = dist.Beta(topic_a[transition_topics, rowind, colind],
                      topic_b[transition_topics, rowind, colind])
        pyro.sample('obs_{}'.format(r), d, obs=obsmat[r, 0])
def guide_hierarchical(self, models, items, obs):
    """Mean-field guide for the hierarchical IRT model's abilities and difficulties.

    Normal guides for mu_b / mu_theta / theta / b, Gamma guides for u_b / u_theta.
    """
    device = self.device

    def _scalar(name, init, positive=False):
        value = torch.tensor(init, device=device)
        if positive:
            return pyro.param(name, value, constraint=constraints.positive)
        return pyro.param(name, value)

    # Group-level parameters (registration order matches the original).
    loc_mu_b = _scalar('loc_mu_b', 0.)
    scale_mu_b = _scalar('scale_mu_b', 1.e2, positive=True)
    loc_mu_theta = _scalar('loc_mu_theta', 0.)
    scale_mu_theta = _scalar('scale_mu_theta', 1.e2, positive=True)
    alpha_b = _scalar('alpha_b', 1., positive=True)
    beta_b = _scalar('beta_b', 1., positive=True)
    alpha_theta = _scalar('alpha_theta', 1., positive=True)
    beta_theta = _scalar('beta_theta', 1., positive=True)

    # Per-subject and per-item parameters.
    loc_ability = pyro.param('loc_ability', torch.zeros(self.num_models, device=device))
    scale_ability = pyro.param('scale_ability', torch.ones(self.num_models, device=device),
                               constraint=constraints.positive)
    loc_diff = pyro.param('loc_diff', torch.zeros(self.num_items, device=device))
    scale_diff = pyro.param('scale_diff', torch.ones(self.num_items, device=device),
                            constraint=constraints.positive)

    # sample statements
    pyro.sample('mu_b', dist.Normal(loc_mu_b, scale_mu_b))
    pyro.sample('u_b', dist.Gamma(alpha_b, beta_b))
    pyro.sample('mu_theta', dist.Normal(loc_mu_theta, scale_mu_theta))
    pyro.sample('u_theta', dist.Gamma(alpha_theta, beta_theta))

    with pyro.plate('thetas', self.num_models, device=device):
        pyro.sample('theta', dist.Normal(loc_ability, scale_ability))
    with pyro.plate('bs', self.num_items, device=device):
        pyro.sample('b', dist.Normal(loc_diff, scale_diff))
def model(y, BATCHES, SAMPLES):
    """Hierarchical normal model over BATCHES groups.

    theta ~ Normal(0, 1e5); tau_between, tau_within ~ Gamma(0.001, 0.001);
    mu[b] ~ Normal(theta, 1/sqrt(tau_between));
    y[b] ~ Normal(mu[b], 1/sqrt(tau_within)).
    ``SAMPLES`` is unused. ``amb`` is a project helper, presumably mapping a
    size to a tensor shape -- confirm.
    """
    theta = pyro.sample(
        'theta',
        dist.Normal(torch.tensor(0.0) * torch.ones([amb(1)]),
                    torch.tensor(100000.0) * torch.ones([amb(1)])))
    tau_between = pyro.sample(
        'tau_between',
        dist.Gamma(torch.tensor(0.001) * torch.ones([amb(1)]),
                   torch.tensor(0.001) * torch.ones([amb(1)])))
    tau_within = pyro.sample(
        'tau_within',
        dist.Gamma(torch.tensor(0.001) * torch.ones([amb(1)]),
                   torch.tensor(0.001) * torch.ones([amb(1)])))

    # Precisions -> standard deviations (the dead zero-initialisations are gone).
    sigma_between = 1 / torch.sqrt(tau_between)
    sigma_within = 1 / torch.sqrt(tau_within)

    with pyro.iarange('mu_range_', BATCHES):
        mu = pyro.sample(
            'mu',
            dist.Normal(theta * torch.ones([amb(BATCHES)]),
                        sigma_between * torch.ones([amb(BATCHES)])))

    for n in range(1, BATCHES + 1):
        pyro.sample('obs_{0}_100'.format(n), dist.Normal(mu[n - 1], sigma_within), obs=y[n - 1])
def _param_map_estimates( self, data: torch.Tensor, chi_ambient: torch.Tensor) -> Dict[str, torch.Tensor]: """Calculate MAP estimates of mu, the mean of the true count matrix, and lambda, the rate parameter of the Poisson background counts. Args: data: Dense tensor minibatch of cell by gene count data. chi_ambient: Point estimate of inferred ambient gene expression. Returns: mu_map: Dense tensor of Negative Binomial means for true counts. lambda_map: Dense tensor of Poisson rate params for noise counts. alpha_map: Dense tensor of Dirichlet concentration params that inform the overdispersion of the Negative Binomial. """ # Encode latents. enc = self.vi_model.encoder.forward(x=data, chi_ambient=chi_ambient) z_map = enc['z']['loc'] chi_map = self.vi_model.decoder.forward(z_map) phi_loc = pyro.param('phi_loc') phi_scale = pyro.param('phi_scale') phi_conc = phi_loc.pow(2) / phi_scale.pow(2) phi_rate = phi_loc / phi_scale.pow(2) alpha_map = 1. / dist.Gamma(phi_conc, phi_rate).mean y = (enc['p_y'] > 0).float() d_empty = dist.LogNormal(loc=pyro.param('d_empty_loc'), scale=pyro.param('d_empty_scale')).mean d_cell = dist.LogNormal(loc=enc['d_loc'], scale=pyro.param('d_cell_scale')).mean epsilon = dist.Gamma(enc['epsilon'] * self.vi_model.epsilon_prior, self.vi_model.epsilon_prior).mean if self.vi_model.include_rho: rho = pyro.param("rho_alpha") / (pyro.param("rho_alpha") + pyro.param("rho_beta")) else: rho = None # Calculate MAP estimates of mu and lambda. mu_map = self.vi_model.calculate_mu(epsilon=epsilon, d_cell=d_cell, chi=chi_map, y=y, rho=rho) lambda_map = self.vi_model.calculate_lambda( epsilon=epsilon, chi_ambient=chi_ambient, d_empty=d_empty, y=y, d_cell=d_cell, rho=rho, chi_bar=self.vi_model.avg_gene_expression) return {'mu': mu_map, 'lam': lambda_map, 'alpha': alpha_map}
def model(data):
    """Beta likelihood with independent Gamma(1, 1) priors on both concentrations.

    Args:
        data: Observations conditioned on the ``"x"`` site.
    """
    def _unit_gamma():
        return dist.Gamma(concentration=1.0, rate=1.0)

    conc1 = pyro.sample("alpha", _unit_gamma())
    conc0 = pyro.sample("beta", _unit_gamma())
    likelihood = dist.Beta(concentration1=conc1, concentration0=conc0)
    pyro.sample("x", likelihood, obs=data)
def model(data):
    """Observe ``data`` under Beta(alpha, beta) with alpha, beta ~ Gamma(1, 1)."""
    concentrations = {
        site: pyro.sample(site, dist.Gamma(concentration=1., rate=1.))
        for site in ('alpha', 'beta')
    }
    pyro.sample('x',
                dist.Beta(concentration1=concentrations['alpha'],
                          concentration0=concentrations['beta']),
                obs=data)
def __init__(self, ode_op, ode_model):
    """Store the ODE operator/model and attach priors to the three ODE parameters.

    Args:
        ode_op: ODE operator, kept as ``self._ode_op``.
        ode_model: ODE model, kept as ``self._ode_model``.
    """
    super(SIRGenModel, self).__init__()
    self._ode_op, self._ode_model = ode_op, ode_model
    priors = (
        ("ode_params1", dist.Gamma(2, 1)),
        ("ode_params2", dist.Gamma(2, 1)),
        ("ode_params3", dist.Beta(0.5, 0.5)),
    )
    for attr_name, prior in priors:
        setattr(self, attr_name, PyroSample(prior))
def model():
    """Log-normal model of an average price with Gamma(7.5, 1) priors on its parameters."""
    def _gamma_prior():
        return dist.Gamma(torch.tensor(7.5), torch.tensor(1.))

    # Learn the MAP of these, starting from this prior distribution.
    log_loc = pyro.sample('average_price_log_mean', _gamma_prior())
    # NOTE: despite the site name "log_var", this value is passed as the
    # LogNormal scale (standard deviation of the log), not a variance.
    log_scale = pyro.sample('average_price_log_var', _gamma_prior())
    pyro.sample('average_price', dist.LogNormal(log_loc, log_scale))
def model(doc_word_data=None, category_data=None, args=None, batch_size=None):
    """Vectorized category-aware LDA model.

    Args:
        doc_word_data: Optional observed words, shape
            ``(args.num_words_per_doc, args.num_docs)``.
        category_data: Optional observed per-document categories.
        args: Namespace with ``num_topics``, ``num_words``, ``num_categories``,
            ``num_docs`` and ``num_words_per_doc``.
        batch_size: Optional document subsample size. It is now forwarded to
            the "documents" plate (``None`` keeps all documents), consistent
            with the guide's "documents" plate.

    Returns:
        dict of global latents plus the (minibatch) words and categories.
    """
    # Globals.
    with pyro.plate("topics", args.num_topics):
        # topic_weights does not come from the usual LDA plate notation; it
        # indicates topic importance (possibly from the amortized LDA paper).
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    with pyro.plate("categories", args.num_categories):
        category_weights = pyro.sample(
            "category_weights", dist.Gamma(1. / args.num_categories, 1.))
        category_topics = pyro.sample("category_topics", dist.Dirichlet(topic_weights))

    # Locals. ``ind`` holds the (sub)sampled document indices.
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        if doc_word_data is not None:
            with pyro.util.ignore_jit_warnings():
                # Full word matrix: one column per document.
                assert doc_word_data.shape == (args.num_words_per_doc, args.num_docs)
            doc_word_data = doc_word_data[:, ind]
        if category_data is not None:
            category_data = category_data[ind]
        category_data = pyro.sample("doc_categories",
                                    dist.Categorical(category_weights),
                                    obs=category_data)
        with pyro.plate("words", args.num_words_per_doc):
            # word_topics is marginalised out via parallel enumeration with
            # TraceEnum_ELBO, so the guide does not need to sample it.
            word_topics = pyro.sample("word_topics",
                                      dist.Categorical(category_topics[category_data]),
                                      infer={"enumerate": "parallel"})
            doc_word_data = pyro.sample("doc_words",
                                        dist.Categorical(topic_words[word_topics]),
                                        obs=doc_word_data)

    results = {
        "topic_weights": topic_weights,
        "topic_words": topic_words,
        "doc_word_data": doc_word_data,
        "category_weights": category_weights,
        "category_topics": category_topics,
        "category_data": category_data,
    }
    return results
def __init__(self, ode_op, ode_model):
    """Store the ODE operator/model and attach priors to the four ODE parameters.

    Args:
        ode_op: ODE operator, kept as ``self._ode_op``.
        ode_model: ODE model, kept as ``self._ode_model``.
    """
    super(PlantModel, self).__init__()
    self._ode_op, self._ode_model = ode_op, ode_model
    # TODO: Incorporate appropriate priors (cf. MATLAB codes from Daewook)
    priors = (
        ("ode_params1", dist.Gamma(1, 1000)),   # dG
        ("ode_params2", dist.Gamma(1, 1000)),   # dP
        ("ode_params3", dist.Beta(0.5, 0.5)),   # G0
        ("ode_params4", dist.Beta(0.5, 0.5)),   # P0
    )
    for attr_name, prior in priors:
        setattr(self, attr_name, PyroSample(prior))
def model(data, params):
    """Gamma-Poisson hierarchical model with exposures.

    Args:
        data: Mapping with ``"N"`` (number of units), ``"x"`` (observed counts)
            and ``"t"`` (multiplier applied to each unit's rate).
        params: Unused.
    """
    n_units, counts, exposure = data["N"], data["x"], data["t"]

    alpha = pyro.sample("alpha", dist.Exponential(1.0))
    beta = pyro.sample("beta", dist.Gamma(0.1, 1.0))

    with pyro.plate('data', n_units):
        rate = pyro.sample("theta", dist.Gamma(alpha, beta))
        pyro.sample("x", dist.Poisson(rate * exposure), obs=counts)
def test_gamma_poisson(sample_shape, batch_shape):
    """Gamma.conjugate_update must reproduce the product of both densities up to its normalizer."""
    concentration = torch.randn(batch_shape).exp()
    rate = torch.randn(batch_shape).exp()
    nobs = 5
    obs = dist.Poisson(10.).sample((nobs,) + sample_shape + batch_shape).sum(0)

    prior = dist.Gamma(concentration, rate)
    likelihood = dist.Gamma(1 + obs, nobs)
    posterior, log_normalizer = prior.conjugate_update(likelihood)

    x = posterior.sample(sample_shape)
    joint = prior.log_prob(x) + likelihood.log_prob(x)
    expected = posterior.log_prob(x) + log_normalizer
    assert_close(joint, expected)
def guide_hierarchical(self, models, items, obs):
    """Mean-field guide for the hierarchical 1PL model.

    Normal guides for mu_b, mu_theta and the per-subject/per-item theta and b;
    Gamma guides for u_b and u_theta.
    """
    dev = self.device
    positive = constraints.positive

    def _pos(name, init):
        return pyro.param(name, torch.tensor(init, device=dev), constraint=positive)

    loc_mu_b = pyro.param("loc_mu_b", torch.tensor(0.0, device=dev))
    scale_mu_b = _pos("scale_mu_b", 1.0e2)
    loc_mu_theta = pyro.param("loc_mu_theta", torch.tensor(0.0, device=dev))
    scale_mu_theta = _pos("scale_mu_theta", 1.0e2)
    alpha_b = _pos("alpha_b", 1.0)
    beta_b = _pos("beta_b", 1.0)
    alpha_theta = _pos("alpha_theta", 1.0)
    beta_theta = _pos("beta_theta", 1.0)

    n_subjects, n_items = self.num_subjects, self.num_items
    loc_ability = pyro.param("loc_ability", torch.zeros(n_subjects, device=dev))
    scale_ability = pyro.param("scale_ability", torch.ones(n_subjects, device=dev), constraint=positive)
    loc_diff = pyro.param("loc_diff", torch.zeros(n_items, device=dev))
    scale_diff = pyro.param("scale_diff", torch.ones(n_items, device=dev), constraint=positive)

    # sample statements
    pyro.sample("mu_b", dist.Normal(loc_mu_b, scale_mu_b))
    pyro.sample("u_b", dist.Gamma(alpha_b, beta_b))
    pyro.sample("mu_theta", dist.Normal(loc_mu_theta, scale_mu_theta))
    pyro.sample("u_theta", dist.Gamma(alpha_theta, beta_theta))

    with pyro.plate("thetas", n_subjects, device=dev):
        pyro.sample("theta", dist.Normal(loc_ability, scale_ability))
    with pyro.plate("bs", n_items, device=dev):
        pyro.sample("b", dist.Normal(loc_diff, scale_diff))
def __init__(self, input_dim, hidden_dim, output_dim):
    """Gamma-prior factorisation with W of shape (D, M) and H of shape (M, N).

    Args:
        input_dim: D, number of rows of W.
        hidden_dim: M, shared inner dimension.
        output_dim: N, number of columns of H.
    """
    super().__init__()
    self.N, self.D, self.M = output_dim, input_dim, hidden_dim

    # Learnable positive Gamma hyperparameters, all initialised to 1.
    for hyper_name in ("aW", "bW", "aH", "bH"):
        setattr(self, hyper_name,
                pyro.nn.PyroParam(torch.tensor(1.0), constraint=constraints.positive))

    # Priors evaluated lazily against the module, so they see the current hyperparameters.
    self.W = pyro.nn.PyroSample(
        lambda module: dist.Gamma(module.aW, module.bW).expand([module.D, module.M]).to_event(2))
    self.H = pyro.nn.PyroSample(
        lambda module: dist.Gamma(module.aH, module.bH).expand([module.M, module.N]).to_event(2))

    self.d_axis = pyro.plate("d_axis", self.D, dim=-2)
    self.n_axis = pyro.plate("n_axis", self.N, dim=-1)
def guide(y, J, sigma):
    """Mean-field guide with Gamma families for mu, tau and the J thetas.

    ``y`` and ``sigma`` are not used. ``amb`` is a project helper, presumably
    mapping a size to a tensor shape -- confirm.
    NOTE(review): Gamma guides cover only positive values; confirm the model's
    mu/theta supports match.
    """
    arg_1 = pyro.param('arg_1', torch.ones((amb(1))), constraint=constraints.positive)
    arg_2 = pyro.param('arg_2', torch.ones((amb(1))), constraint=constraints.positive)
    pyro.sample('mu', dist.Gamma(arg_1, arg_2))

    arg_3 = pyro.param('arg_3', torch.ones((amb(1))), constraint=constraints.positive)
    arg_4 = pyro.param('arg_4', torch.ones((amb(1))), constraint=constraints.positive)
    pyro.sample('tau', dist.Gamma(arg_3, arg_4))

    arg_5 = pyro.param('arg_5', torch.ones((amb(J))), constraint=constraints.positive)
    arg_6 = pyro.param('arg_6', torch.ones((amb(J))), constraint=constraints.positive)
    with pyro.iarange('theta_prange'):
        pyro.sample('theta', dist.Gamma(arg_5, arg_6))
def model_multi_obs_grp(obsmat):
    """Group-mixture model over sparse Beta-distributed matrix entries.

    Reads the module-level ``tm`` (``K``, ``stickbreak``, ``mix_weights``,
    ``dtype``) and ``data`` (dims 1 and 2 give the per-participant matrix shape).

    Args:
        obsmat: 2-D tensor whose rows are ``(value, row_index, col_index)``.

    Raises:
        NotImplementedError: if ``tm.dtype == 'norm'`` (the Dirichlet variant
            was never implemented and previously produced a model with no
            likelihood).
        ValueError: for any other unknown ``tm.dtype``.
    """
    # Validate up front so no sample sites are recorded for an unusable model.
    if tm.dtype == 'norm':
        raise NotImplementedError("Dirichlet-based model for row-normalized data is not implemented")
    if tm.dtype != 'raw':
        raise ValueError("unknown tm.dtype: {!r}".format(tm.dtype))

    nfeatures = data.shape[1]  # number of rows in each person's matrix
    ncol = data.shape[2]

    # Background probability of the different groups.
    if tm.stickbreak:
        # Stick-breaking weights; uses tm.K (the bare ``K`` was undefined).
        with pyro.plate("beta_plate", tm.K - 1):
            beta_mix = pyro.sample("weights", dist.Beta(1, 10))
        weights = tm.mix_weights(beta_mix)
    else:
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(tm.K)))

    with pyro.plate('components', tm.K):
        alphas = pyro.sample(
            'alpha',
            dist.Gamma(2 * torch.ones(nfeatures, ncol),
                       1 / 3 * torch.ones(nfeatures, ncol)).to_event(2))
        betas = pyro.sample(
            'beta',
            dist.Gamma(2 * torch.ones(nfeatures, ncol),
                       1 / 3 * torch.ones(nfeatures, ncol)).to_event(2))

    # A single group assignment shared by all observed entries.
    assignment = pyro.sample('assignment', dist.Categorical(weights))

    for r in range(obsmat.shape[0]):
        rowind = obsmat[r, 1].type(torch.long)
        colind = obsmat[r, 2].type(torch.long)
        d = dist.Beta(alphas[assignment, rowind, colind],
                      betas[assignment, rowind, colind])
        pyro.sample('obs_{}'.format(r), d, obs=obsmat[r, 0])
def setUp(self):
    """Build scalar and batched Gamma(alpha, beta) fixtures and their analytic moments."""
    # Plain tensors: torch.autograd.Variable is a deprecated no-op wrapper.
    self.alpha = torch.Tensor([2.4])
    self.batch_alpha = torch.Tensor([[2.4], [3.2]])
    self.batch_beta = torch.Tensor([[np.sqrt(2.4)], [np.sqrt(3.2)]])
    self.beta = torch.Tensor([np.sqrt(2.4)])
    self.test_data = torch.Tensor([5.5])
    self.batch_test_data = torch.Tensor([[5.5], [4.4]])
    self.dist = dist.Gamma(self.alpha, self.beta)
    self.batch_dist = dist.Gamma(self.batch_alpha, self.batch_beta)
    # Gamma(concentration, rate): mean = alpha / beta, variance = alpha / beta**2.
    self.analytic_mean = (self.alpha / self.beta).data.cpu().numpy()[0]
    self.analytic_var = (self.alpha / torch.pow(self.beta, 2.0)).data.cpu().numpy()[0]
    self.n_samples = 50000
def model(data, params):
    """Two-level normal model: batch means around theta, samples around batch means.

    Args:
        data: Mapping with ``"BATCHES"``, ``"SAMPLES"`` and ``"y"``; ``y`` is
            presumably laid out (BATCHES, SAMPLES) given the plate dims -- confirm.
        params: Unused.
    """
    n_batches = data["BATCHES"]
    n_samples = data["SAMPLES"]
    observations = data["y"]

    theta = pyro.sample("theta", dist.Normal(0.0, 100000.0))
    tau_between = pyro.sample("tau_between", dist.Gamma(0.001, 0.001))
    tau_within = pyro.sample("tau_within", dist.Gamma(0.001, 0.001))

    # Precisions -> standard deviations.
    sigma_between = 1 / tau_between.sqrt()
    sigma_within = 1 / tau_within.sqrt()

    with pyro.plate('batches', n_batches, dim=-2):
        batch_mean = pyro.sample("mu", dist.Normal(theta, sigma_between))
        with pyro.plate('data', n_samples, dim=-1):
            pyro.sample('y', dist.Normal(batch_mean, sigma_within), obs=observations)
def guide(nyear, C, nsite, year):
    """Mean-field guide for sd_year, sd_alpha, mu, alpha, beta and eps.

    ``C`` and ``year`` are not used. ``amb`` is a project helper, presumably
    mapping a size to a tensor shape -- confirm.
    NOTE(review): sd_year uses a Normal guide (unconstrained) while other scales
    use Gamma -- confirm this matches the model's supports.
    """
    arg_1 = pyro.param('arg_1', torch.ones((amb(1))))
    arg_2 = pyro.param('arg_2', torch.ones((amb(1))), constraint=constraints.positive)
    pyro.sample('sd_year', dist.Normal(arg_1, arg_2))

    arg_3 = pyro.param('arg_3', torch.ones((amb(1))), constraint=constraints.positive)
    arg_4 = pyro.param('arg_4', torch.ones((amb(1))), constraint=constraints.positive)
    pyro.sample('sd_alpha', dist.Gamma(arg_3, arg_4))

    arg_5 = pyro.param('arg_5', torch.ones((amb(1))), constraint=constraints.positive)
    arg_6 = pyro.param('arg_6', torch.ones((amb(1))), constraint=constraints.positive)
    pyro.sample('mu', dist.Gamma(arg_5, arg_6))

    arg_7 = pyro.param('arg_7', torch.ones((amb(nsite))))
    arg_8 = pyro.param('arg_8', torch.ones((amb(nsite))), constraint=constraints.positive)
    with pyro.iarange('alpha_prange'):
        pyro.sample('alpha', dist.Normal(arg_7, arg_8))

    arg_9 = pyro.param('arg_9', torch.ones((amb(3))), constraint=constraints.positive)
    arg_10 = pyro.param('arg_10', torch.ones((amb(3))), constraint=constraints.positive)
    with pyro.iarange('beta_prange'):
        pyro.sample('beta', dist.Gamma(arg_9, arg_10))

    arg_11 = pyro.param('arg_11', torch.ones((amb(nyear))), constraint=constraints.positive)
    arg_12 = pyro.param('arg_12', torch.ones((amb(nyear))), constraint=constraints.positive)
    with pyro.iarange('eps_prange'):
        pyro.sample('eps', dist.Gamma(arg_11, arg_12))
def model(y, x, N):
    """Beta regression: y ~ Beta(w * x + b, p), with per-item b and p over N items.

    ``amb`` is a project helper, presumably mapping a size to a tensor shape --
    confirm. NOTE(review): the first Beta concentration must be positive;
    this assumes ``x`` is non-negative.
    """
    # Plain tensors: torch.autograd.Variable is a deprecated no-op wrapper.
    w = pyro.sample(
        'w',
        dist.Beta(26.072914040168385 * torch.ones([amb(1)]),
                  42.3120851154 * torch.ones([amb(1)])))
    with pyro.iarange('b_range_', N):
        b = pyro.sample(
            'b',
            dist.Gamma(5.63887222899 * torch.ones([amb(N)]),
                       40.1978121928 * torch.ones([amb(N)])))
    with pyro.iarange('p_range_', N):
        p = pyro.sample(
            'p',
            dist.Beta(52.1419233118 * torch.ones([amb(N)]),
                      83.6618285099 * torch.ones([amb(N)])))
    pyro.sample('obs__100', dist.Beta(w * x + b, p), obs=y)