def test_shape_augmented_gamma_elbo(alpha, beta):
    num_samples = 100000
    alphas = torch.full((num_samples, 1), alpha).requires_grad_()
    betas = torch.full((num_samples, 1), beta).requires_grad_()
    model = Gamma(torch.ones(num_samples, 1), torch.ones(num_samples, 1))
    guide1 = Gamma(alphas, betas)
    guide2 = ShapeAugmentedGamma(alphas, betas)  # implemented using Rejector
    grads = []
    for guide in [guide1, guide2]:
        grads.append(compute_elbo_grad(model, guide, [alphas, betas]))
    expected, actual = grads
    expected = [g.mean() for g in expected]
    actual = [g.mean() for g in actual]
    scale = [(1 + abs(g)) for g in expected]
    assert_equal(actual[0] / scale[0], expected[0] / scale[0], prec=0.05,
                 msg="bad grad for alpha")
    assert_equal(actual[1] / scale[1], expected[1] / scale[1], prec=0.05,
                 msg="bad grad for beta")
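# `compute_elbo_grad` is not shown in these snippets. Below is a minimal sketch
# of what it plausibly does -- an assumption, not the actual test helper: draw
# a batch of reparameterized samples from the guide and differentiate a
# single-sample ELBO estimate with respect to the guide parameters. This is
# unbiased for guides that support rsample(), which both the plain Gamma and
# the Rejector-based guides here do.
import torch

def compute_elbo_grad(model, guide, params):
    x = guide.rsample()  # pathwise gradients flow through the samples
    elbo = (model.log_prob(x) - guide.log_prob(x)).sum()
    return torch.autograd.grad(elbo, params, retain_graph=True)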
def init(self, state, N):
    state["λ"] = sample("λ", Gamma(1., 1. / 1.))
    state["μ"] = sample("μ", Gamma(1., 1. / 0.5))
    # accumulate f = log(2^(N-1) / N!), a combinatorial correction factor
    f = (N - 1) * log(tensor(2))
    for n in range(2, N + 1):
        f -= log(tensor(n))
    factor("factor_orient_labeled", f)
def test_standard_gamma_elbo(alpha):
    num_samples = 100000
    alphas = torch.tensor(torch.tensor(alpha).expand(num_samples, 1), requires_grad=True)
    betas = torch.ones(num_samples, 1)
    model = Gamma(torch.ones(num_samples, 1), betas)
    guide1 = Gamma(alphas, betas)
    guide2 = RejectionStandardGamma(alphas)  # implemented using Rejector
    grads = []
    for guide in [guide1, guide2]:
        grads.append(compute_elbo_grad(model, guide, [alphas])[0].data)
    expected, actual = grads
    assert_equal(actual.mean(), expected.mean(), prec=0.01, msg='bad grad for alpha')
def main(data=None, args=None, batch_size=None):
    data = torch.reshape(data, [64, 1000])

    # Globals.
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words", Dirichlet(torch.ones(1024) / 1024))

    # Locals.
    with pyro.plate("documents", 1000) as ind:
        # shape = [64, 32]
        data = data[:, ind]
        # shape = [32] + [8]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        with pyro.plate("words", 64):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            # shape = [64, 32] + []
            word_topics = pyro.sample("word_topics", Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            # shape = [64, 32] + []
            data = pyro.sample("doc_words", Categorical(topic_words[word_topics]),
                               obs=data)
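# A hedged usage sketch for the LDA model above: because word_topics is
# enumerated in parallel, inference must use TraceEnum_ELBO, with
# max_plate_nesting=2 covering the nested "documents" and "words" plates.
# `guide` here stands in for an amortized guide such as the ones shown later
# in this section.
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

elbo = TraceEnum_ELBO(max_plate_nesting=2)
svi = SVI(main, guide, Adam({"lr": 0.01}), loss=elbo)
loss = svi.step(data)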
def generate_model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", 8):
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        topic_words = pyro.sample("topic_words", Dirichlet(torch.ones(1024) / 1024))

    # Locals.
    with pyro.plate("documents", 1000) as ind:
        if data is not None:
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        with pyro.plate("words", 64):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words", Categorical(topic_words[word_topics]),
                               obs=data)

    return topic_weights, topic_words, data
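# Because a pyro.sample(..., obs=None) site draws a sample instead of
# conditioning, calling generate_model with no data yields a fully synthetic
# corpus from the prior (a usage sketch; `args` as accepted above):
topic_weights, topic_words, data = generate_model(args=args)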
def sample_ws(name, width):
    alpha_w_q = pyro.param("alpha_w_q_%s" % name,
                           lambda: rand_tensor((width), self.alpha_init, self.sigma_init))
    mean_w_q = pyro.param("mean_w_q_%s" % name,
                          lambda: rand_tensor((width), self.mean_init, self.sigma_init))
    alpha_w_q, mean_w_q = softplus(alpha_w_q), softplus(mean_w_q)
    pyro.sample("w_%s" % name, Gamma(alpha_w_q, alpha_w_q / mean_w_q))
def sample_zs(name, width):
    alpha_z_q = pyro.param("alpha_z_q_%s" % name,
                           lambda: rand_tensor((x_size, width), self.alpha_init, self.sigma_init))
    mean_z_q = pyro.param("mean_z_q_%s" % name,
                          lambda: rand_tensor((x_size, width), self.mean_init, self.sigma_init))
    alpha_z_q, mean_z_q = softplus(alpha_z_q), softplus(mean_z_q)
    pyro.sample("z_%s" % name, Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
def model():
    lambda_latent = pyro.sample("lambda_latent", Gamma(self.alpha0, self.beta0))
    x_dist = Exponential(lambda_latent)
    # pyro.observe is the legacy spelling of pyro.sample(..., obs=...)
    pyro.observe("obs0", x_dist, self.data[0])
    pyro.observe("obs1", x_dist, self.data[1])
    return lambda_latent
def model(n_samples=None):
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))
        loc = (thickness - 2.5) * 20
        slant = pyro.sample('slant', Normal(loc, 1.))
    return slant, thickness
def guide(): alpha_q_log = pyro.param( "alpha_q_log", Variable(self.alpha_q_log_0, requires_grad=True)) beta_q_log = pyro.param( "beta_q_log", Variable(self.beta_q_log_0, requires_grad=True)) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("lambda_latent", Gamma(alpha_q, beta_q))
def model():
    lambda_latent = pyro.sample("lambda_latent", Gamma(self.alpha0, self.beta0))
    x_dist = Poisson(lambda_latent)
    # x0 = pyro.observe("obs0", x_dist, self.data[0])
    pyro.map_data(self.data, lambda i, x: pyro.observe("obs", x_dist, x),
                  batch_size=3)
    return lambda_latent
def test_gamma_elbo(alpha, beta):
    num_samples = 100000
    alphas = torch.tensor(torch.tensor(alpha).expand(num_samples, 1), requires_grad=True)
    betas = torch.tensor(torch.tensor(beta).expand(num_samples, 1), requires_grad=True)
    model = Gamma(torch.ones(num_samples, 1), torch.ones(num_samples, 1))
    guide1 = Gamma(alphas, betas)
    guide2 = RejectionGamma(alphas, betas)  # implemented using Rejector
    grads = []
    for guide in [guide1, guide2]:
        grads.append(compute_elbo_grad(model, guide, [alphas, betas]))
    expected, actual = grads
    expected = [g.mean() for g in expected]
    actual = [g.mean() for g in actual]
    scale = [(1 + abs(g)) for g in expected]
    assert_equal(actual[0] / scale[0], expected[0] / scale[0], prec=0.01,
                 msg='bad grad for alpha')
    assert_equal(actual[1] / scale[1], expected[1] / scale[1], prec=0.01,
                 msg='bad grad for beta')
def guide(): alpha_q_log = pyro.param( "alpha_q_log", Variable(self.log_alpha_n.data + 0.17, requires_grad=True)) beta_q_log = pyro.param( "beta_q_log", Variable(self.log_beta_n.data - 0.143, requires_grad=True)) alpha_q, beta_q = torch.exp(alpha_q_log), torch.exp(beta_q_log) pyro.sample("lambda_latent", Gamma(alpha_q, beta_q)) pyro.map_data(self.data, lambda i, x: None, batch_size=3)
def init(self, state, λ, s0, i0, r0):
    obs = lambda value: full((state._num_particles,), value)
    state["λ"] = sample("λ", Gamma(2., 1. / 5.0), obs=obs(λ))
    state["δ"] = sample("δ", Beta(2., 2.))
    state["γ"] = sample("γ", Beta(2., 2.))
    state["s0"] = obs(s0)
    state["i0"] = obs(i0)
    state["r0"] = obs(r0)
def model(n_samples=None, scale=2.):
    with pyro.plate('observations', n_samples):
        thickness = pyro.sample('thickness', Gamma(10., 5.))
        loc = (thickness - 2.5) * 2
        # sigmoid squashes to (0, 1); the affine map then rescales to (10, 25)
        transforms = ComposeTransform([SigmoidTransform(), AffineTransform(10, 15)])
        width = pyro.sample('width',
                            TransformedDistribution(Normal(loc, scale), transforms))
    return thickness, width
def model():
    alpha_p_log = pyro.param(
        "alpha_p_log", Variable(self.alpha_p_log_0, requires_grad=True))
    beta_p_log = pyro.param(
        "beta_p_log", Variable(self.beta_p_log_0, requires_grad=True))
    alpha_p, beta_p = torch.exp(alpha_p_log), torch.exp(beta_p_log)
    lambda_latent = pyro.sample("lambda_latent", Gamma(alpha_p, beta_p))
    x_dist = Poisson(lambda_latent)
    pyro.observe("obs", x_dist, self.data)
    return lambda_latent
def model(self, x):
    x_size = x.size(0)

    # sample the global weights
    with pyro.plate("w_top_plate", self.top_width * self.mid_width):
        w_top = pyro.sample("w_top", Gamma(self.alpha_w, self.beta_w))
    with pyro.plate("w_mid_plate", self.mid_width * self.bottom_width):
        w_mid = pyro.sample("w_mid", Gamma(self.alpha_w, self.beta_w))
    with pyro.plate("w_bottom_plate", self.bottom_width * self.image_size):
        w_bottom = pyro.sample("w_bottom", Gamma(self.alpha_w, self.beta_w))

    # sample the local latent random variables
    # (the plate encodes the fact that the z's for different datapoints are
    # conditionally independent)
    with pyro.plate("data", x_size):
        z_top = pyro.sample(
            "z_top",
            Gamma(self.alpha_z, self.beta_z).expand([self.top_width]).to_event(1))
        # note that we need to use matmul (batch matrix multiplication) as well
        # as appropriate reshaping to make sure our code is fully vectorized
        w_top = w_top.reshape(self.top_width, self.mid_width) if w_top.dim() == 1 else \
            w_top.reshape(-1, self.top_width, self.mid_width)
        mean_mid = torch.matmul(z_top, w_top)
        z_mid = pyro.sample(
            "z_mid", Gamma(self.alpha_z, self.beta_z / mean_mid).to_event(1))
        w_mid = w_mid.reshape(self.mid_width, self.bottom_width) if w_mid.dim() == 1 else \
            w_mid.reshape(-1, self.mid_width, self.bottom_width)
        mean_bottom = torch.matmul(z_mid, w_mid)
        z_bottom = pyro.sample(
            "z_bottom", Gamma(self.alpha_z, self.beta_z / mean_bottom).to_event(1))
        w_bottom = w_bottom.reshape(self.bottom_width, self.image_size) if w_bottom.dim() == 1 else \
            w_bottom.reshape(-1, self.bottom_width, self.image_size)
        mean_obs = torch.matmul(z_bottom, w_bottom)

        # observe the data using a poisson likelihood
        pyro.sample('obs', Poisson(mean_obs).to_event(1), obs=x)
def model(x):
    x = torch.reshape(x, [320, 4096])
    with pyro.plate("w_top_plate", 4000):
        w_top = pyro.sample("w_top", Gamma(alpha_w, beta_w))
    with pyro.plate("w_mid_plate", 600):
        w_mid = pyro.sample("w_mid", Gamma(alpha_w, beta_w))
    with pyro.plate("w_bottom_plate", 61440):
        w_bottom = pyro.sample("w_bottom", Gamma(alpha_w, beta_w))
    with pyro.plate("data", 320):
        z_top = pyro.sample(
            "z_top", Gamma(alpha_z, beta_z).expand_by([100]).to_event(1))
        w_top = torch.reshape(w_top, [100, 40])
        mean_mid = torch.matmul(z_top, w_top)        # [320, 100] @ [100, 40] -> [320, 40]
        z_mid = pyro.sample(
            "z_mid", Gamma(alpha_z, beta_z / mean_mid).to_event(1))
        w_mid = torch.reshape(w_mid, [40, 15])
        mean_bottom = torch.matmul(z_mid, w_mid)     # [320, 40] @ [40, 15] -> [320, 15]
        z_bottom = pyro.sample(
            "z_bottom", Gamma(alpha_z, beta_z / mean_bottom).to_event(1))
        w_bottom = torch.reshape(w_bottom, [15, 4096])
        mean_obs = torch.matmul(z_bottom, w_bottom)  # [320, 15] @ [15, 4096] -> [320, 4096]
        pyro.sample('obs', Poisson(mean_obs).to_event(1), obs=x)
def model(self, x):
    # NOTE: pyro.iarange and .independent() are older spellings of
    # pyro.plate and .to_event().
    x_size = x.size(0)

    # sample the global weights
    with pyro.iarange("w_top_iarange", self.top_width * self.mid_width):
        w_top = pyro.sample("w_top", Gamma(self.alpha_w, self.beta_w))
    with pyro.iarange("w_mid_iarange", self.mid_width * self.bottom_width):
        w_mid = pyro.sample("w_mid", Gamma(self.alpha_w, self.beta_w))
    with pyro.iarange("w_bottom_iarange", self.bottom_width * self.image_size):
        w_bottom = pyro.sample("w_bottom", Gamma(self.alpha_w, self.beta_w))

    # sample the local latent random variables
    # (the iarange encodes the fact that the z's for different datapoints are
    # conditionally independent)
    with pyro.iarange("data", x_size):
        z_top = pyro.sample(
            "z_top",
            Gamma(self.alpha_z, self.beta_z).expand([self.top_width]).independent(1))
        mean_mid = torch.mm(z_top, w_top.reshape(self.top_width, self.mid_width))
        z_mid = pyro.sample(
            "z_mid", Gamma(self.alpha_z, self.beta_z / mean_mid).independent(1))
        mean_bottom = torch.mm(z_mid, w_mid.view(self.mid_width, self.bottom_width))
        z_bottom = pyro.sample(
            "z_bottom", Gamma(self.alpha_z, self.beta_z / mean_bottom).independent(1))
        mean_obs = torch.mm(z_bottom, w_bottom.view(self.bottom_width, self.image_size))

        # observe the data using a poisson likelihood
        pyro.sample('obs', Poisson(mean_obs).independent(1), obs=x)
def model(data=None, args=None, batch_size=None):
    if debug:
        print("model:")
    data = torch.reshape(data, [64, 1000])

    # Globals.
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights", Gamma(1. / 8, 1.))
        # shape = [8] + [1024]
        topic_words = pyro.sample("topic_words", Dirichlet(torch.ones(1024) / 1024))
    if debug:
        print("topic_weights\t: shape={}, sum={}".format(
            topic_weights.shape, torch.sum(topic_weights)))
        print("topic_words\t: shape={}".format(topic_words.shape))

    # Locals.
    # with pyro.plate("documents", 1000) as ind:
    with pyro.plate("documents", 1000, 32, dim=-1) as ind:
        # if data is not None:
        #     data = data[:, ind]
        # shape = [64, 32]
        data = data[:, ind]
        # shape = [32] + [8]
        doc_topics = pyro.sample("doc_topics", Dirichlet(topic_weights))
        if debug:
            print("data\t\t: shape={}".format(data.shape))
            print("doc_topics\t: shape={}, [0].sum={}".format(
                doc_topics.shape, torch.sum(doc_topics[0])))
        # with pyro.plate("words", 64):
        with pyro.plate("words", 64, dim=-2):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            # shape = [64, 32] + []
            word_topics = \
                pyro.sample("word_topics", Categorical(doc_topics),
                            infer={"enumerate": "parallel"})
            # pyro.sample("word_topics", Categorical(doc_topics))
            # shape = [64, 32] + []
            data = \
                pyro.sample("doc_words", Categorical(topic_words[word_topics]),
                            obs=data)
            if debug:
                print("word_topics\t: shape={}".format(word_topics.shape))
                print("data\t\t: shape={}".format(data.shape))

    return topic_weights, topic_words, data
def main(data, args, batch_size=None):
    data = torch.reshape(data, [64, 1000])
    pyro.module("layer1", layer1)
    pyro.module("layer2", layer2)
    pyro.module("layer3", layer3)

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        torch.ones(8),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        # WL: edited. =====
        # torch.ones(8, 1024),
        # constraint=constraints.greater_than(0.5)
        torch.ones(8, 1024) * 0.5,
        constraint=constraints.positive
        # =================
    )
    with pyro.plate("topics", 8):
        # shape = [8] + []
        topic_weights = pyro.sample("topic_weights",
                                    Gamma(topic_weights_posterior, 1.))
        # shape = [8] + [1024]
        # WL: edited. =====
        # topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior))
        topic_words = pyro.sample("topic_words",
                                  Dirichlet(topic_words_posterior + 0.5))
        # =================

    # Use an amortized guide for local variables.
    with pyro.plate("documents", 1000, 32) as ind:
        # shape = [64, 32]
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = torch.zeros(1024, 32)
        counts = torch.Tensor.scatter_add_(
            counts, 0, data, torch.Tensor.expand(torch.tensor(1.), [1024, 32]))
        h1 = sigmoid(layer1(torch.transpose(counts, 0, 1)))
        h2 = sigmoid(layer2(h1))
        # shape = [32, 8]
        doc_topics_w = sigmoid(layer3(h2))
        # shape = [32] + [8]
        doc_topics = softmax(doc_topics_w)
        pyro.sample("doc_topics", Delta(doc_topics, event_dim=1))
def model(n_samples=None, scale=0.5, invert=False):
    with pyro.plate('observations', n_samples):
        thickness = 0.5 + pyro.sample('thickness', Gamma(10., 5.))
        if invert:
            loc = (thickness - 2) * -2
        else:
            loc = (thickness - 2.5) * 2
        transforms = ComposeTransform(
            [SigmoidTransform(), AffineTransform(64, 191)])
        intensity = pyro.sample(
            'intensity', TransformedDistribution(Normal(loc, scale), transforms))
    return thickness, intensity
def _gamma(self):
    concentration = self.theta
    rate = self.theta / self.mu
    # Important remark: Gamma is parametrized by the rate = 1/scale!
    gamma_d = Gamma(concentration=concentration, rate=rate)
    return gamma_d
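# The (mu, theta) parameterization used in _gamma above can be checked against
# the moments of the resulting distribution: Gamma(concentration=theta,
# rate=theta/mu) has mean mu and variance mu**2/theta, the usual
# mean/dispersion form from negative-binomial models. A minimal check:
import torch
from torch.distributions import Gamma

mu, theta = torch.tensor(3.0), torch.tensor(2.0)
g = Gamma(concentration=theta, rate=theta / mu)
assert torch.allclose(g.mean, mu)                   # mean = concentration / rate
assert torch.allclose(g.variance, mu ** 2 / theta)  # var = concentration / rate**2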
def guide(x):
    x = torch.reshape(x, [320, 4096])

    with pyro.plate("w_top_plate", 4000):
        # ============ sample_ws
        alpha_w_q = pyro.param("log_alpha_w_q_top",
                               alpha_init * torch.ones(4000) +
                               sigma_init * torch.randn(4000))
        mean_w_q = pyro.param("log_mean_w_q_top",
                              mean_init * torch.ones(4000) +
                              sigma_init * torch.randn(4000))
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q = softplus(mean_w_q)
        pyro.sample("w_top", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        # ============ sample_ws

    with pyro.plate("w_mid_plate", 600):
        # ============ sample_ws
        alpha_w_q = pyro.param("log_alpha_w_q_mid",
                               alpha_init * torch.ones(600) +
                               sigma_init * torch.randn(600))
        mean_w_q = pyro.param("log_mean_w_q_mid",
                              mean_init * torch.ones(600) +
                              sigma_init * torch.randn(600))
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q = softplus(mean_w_q)
        pyro.sample("w_mid", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        # ============ sample_ws

    with pyro.plate("w_bottom_plate", 61440):
        # ============ sample_ws
        alpha_w_q = pyro.param("log_alpha_w_q_bottom",
                               alpha_init * torch.ones(61440) +
                               sigma_init * torch.randn(61440))
        mean_w_q = pyro.param("log_mean_w_q_bottom",
                              mean_init * torch.ones(61440) +
                              sigma_init * torch.randn(61440))
        alpha_w_q = softplus(alpha_w_q)
        mean_w_q = softplus(mean_w_q)
        pyro.sample("w_bottom", Gamma(alpha_w_q, alpha_w_q / mean_w_q))
        # ============ sample_ws

    with pyro.plate("data", 320):
        # ============ sample_zs
        alpha_z_q = pyro.param("log_alpha_z_q_top",
                               alpha_init * torch.ones(320, 100) +
                               sigma_init * torch.randn(320, 100))
        mean_z_q = pyro.param("log_mean_z_q_top",
                              mean_init * torch.ones(320, 100) +
                              sigma_init * torch.randn(320, 100))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q = softplus(mean_z_q)
        pyro.sample("z_top", Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        # ============ sample_zs

        # ============ sample_zs
        alpha_z_q = pyro.param("log_alpha_z_q_mid",
                               alpha_init * torch.ones(320, 40) +
                               sigma_init * torch.randn(320, 40))
        mean_z_q = pyro.param("log_mean_z_q_mid",
                              mean_init * torch.ones(320, 40) +
                              sigma_init * torch.randn(320, 40))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q = softplus(mean_z_q)
        pyro.sample("z_mid", Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        # ============ sample_zs

        # ============ sample_zs
        alpha_z_q = pyro.param("log_alpha_z_q_bottom",
                               alpha_init * torch.ones(320, 15) +
                               sigma_init * torch.randn(320, 15))
        mean_z_q = pyro.param("log_mean_z_q_bottom",
                              mean_init * torch.ones(320, 15) +
                              sigma_init * torch.randn(320, 15))
        alpha_z_q = softplus(alpha_z_q)
        mean_z_q = softplus(mean_z_q)
        pyro.sample("z_bottom", Gamma(alpha_z_q, alpha_z_q / mean_z_q).to_event(1))
        # ============ sample_zs
def guide(data, args, batch_size=None): if debug: print("guide:") data = torch.reshape(data, [64, 1000]) pyro.module("layer1", layer1) pyro.module("layer2", layer2) pyro.module("layer3", layer3) # Use a conjugate guide for global variables. topic_weights_posterior = pyro.param("topic_weights_posterior", torch.ones(8) / 8, constraint=constraints.positive) topic_words_posterior = pyro.param("topic_words_posterior", torch.ones(8, 1024) / 1024, constraint=constraints.positive) with pyro.plate("topics", 8): # shape = [8] + [] topic_weights = pyro.sample("topic_weights", Gamma(topic_weights_posterior, 1.)) # shape = [8] + [1024] topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior)) if debug: print("topic_weights\t: shape={}, sum={}".format( topic_weights.shape, torch.sum(topic_weights))) print("topic_words\t: shape={}".format(topic_words.shape)) # Use an amortized guide for local variables. with pyro.plate("documents", 1000, 32, dim=-1) as ind: # shape = [64, 32] data = data[:, ind] # The neural network will operate on histograms rather than word # index vectors, so we'll convert the raw data to a histogram. counts = torch.zeros(1024, 32) counts = torch.Tensor.scatter_add_\ (counts, 0, data, torch.Tensor.expand(torch.tensor(1.), [1024, 32])) if debug: print("counts.shape={}, counts_trans.shape={}".format( counts.shape, torch.transpose(counts, 0, 1).shape)) h1 = sigmoid(layer1(torch.transpose(counts, 0, 1))) h2 = sigmoid(layer2(h1)) # shape = [32, 8] doc_topics = sigmoid(layer3(h2)) if debug: print("doc_topics(nn result)\t: shape={}, [0].sum={}".format( doc_topics.shape, torch.sum(doc_topics[0]))) d = Dirichlet(doc_topics) print("Dirichlet(doc_topics)\t: batch_shape={}, event_shape={}". format(d.batch_shape, d.event_shape)) """ NOTE: There are three options we can take: 1. sample doc_topics (from *DELTA*) *ONLY*. (original code) 2. sample doc_topics (from *DIRICHLET*) *ONLY*. 3. sample doc_topics (from *DIRICHLET*), word_topics (from Categorical) *BOTH*. Expected results: | Trace_ELBO | TraceEnum_ELBO ------------------------------- 1 | FAIL | FAIL 2 | FAIL | PASS 3 | PASS | FAIL """ # shape = [32] + [8] # doc_topics = pyro.sample("doc_topics", Delta(doc_topics, event_dim=1)) # doc_topics = pyro.sample("doc_topics", Delta(doc_topics).to_event(1)) doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics)) if debug: print("doc_topics(sampled)\t: shape={}, [0].sum = {}".format( doc_topics.shape, torch.sum(doc_topics[0]))) with pyro.plate("words", 64, dim=-2): # shape = [64, 32, 8] word_topics_posterior = doc_topics * topic_words.transpose( 0, 1)[data, :] word_topics_posterior = word_topics_posterior / ( word_topics_posterior.sum(dim=-1, keepdim=True)) word_topics =\ pyro.sample("word_topics", Categorical(word_topics_posterior)) if debug: print("word_tpics_posterior\t: shape={}".format( word_topics_posterior.shape)) print("word_topics(sampled)\t: shape={}".format( word_topics.shape)) d = Categorical(word_topics_posterior) print( "Categorical(word_topics_posterior)\t: batch_shape={}, event_shape={}" .format(d.batch_shape, d.event_shape))
def predict(self, x) -> Independent:
    rate, conc = self(x)
    event_ndim = len(rate.shape[1:])  # keep only batch dimension
    # NOTE: torch's Gamma signature is Gamma(concentration, rate), so the
    # variable names here suggest the arguments are being passed in swapped
    # order; verify against what self(x) actually returns.
    return Gamma(rate, conc).to_event(event_ndim)
def test_log_prob(concentration, rate, value):
    value = torch.tensor(value)
    log_prob = InverseGamma(concentration, rate).log_prob(value)
    expected_log_prob = (Gamma(concentration, rate).log_prob(1.0 / value)
                         - 2.0 * value.log())
    assert_equal(log_prob, expected_log_prob, prec=1e-6)
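# Why the "- 2.0 * value.log()" term: change of variables for y = 1/x.
# If X ~ Gamma(concentration, rate) and Y = 1/X ~ InverseGamma(concentration, rate),
# then p_Y(y) = p_X(1/y) * |d(1/y)/dy| = p_X(1/y) * y**(-2),
# so log p_Y(y) = log p_X(1/y) - 2*log(y), which is the identity the test asserts.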
def guide(data, args, batch_size=None): if debug: print("guide:") data = torch.reshape(data, [64, 1000]) pyro.module("layer1", layer1) pyro.module("layer2", layer2) pyro.module("layer3", layer3) # Use a conjugate guide for global variables. topic_weights_posterior = pyro.param( "topic_weights_posterior", # lambda: torch.ones(8) / 8, torch.ones(8) / 8, constraint=constraints.positive) topic_words_posterior = pyro.param( "topic_words_posterior", # lambda: torch.ones(8, 1024) / 1024, torch.ones(8, 1024) / 1024, constraint=constraints.positive) """ # wy: dummy param for word_topics word_topics_posterior = pyro.param( "word_topics_posterior", torch.ones(64, 1024, 8) / 8, constraint=constraints.positive) """ with pyro.plate("topics", 8): # shape = [8] + [] topic_weights = pyro.sample("topic_weights", Gamma(topic_weights_posterior, 1.)) # shape = [8] + [1024] topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior)) if debug: print("topic_weights\t: shape={}, sum={}".format( topic_weights.shape, torch.sum(topic_weights))) print("topic_words\t: shape={}".format(topic_words.shape)) # Use an amortized guide for local variables. with pyro.plate("documents", 1000, 32) as ind: # shape = [64, 32] data = data[:, ind] # The neural network will operate on histograms rather than word # index vectors, so we'll convert the raw data to a histogram. counts = torch.zeros(1024, 32) counts = torch.Tensor.scatter_add_\ (counts, 0, data, torch.Tensor.expand(torch.tensor(1.), [1024, 32])) if debug: print("counts.shape={}, counts_trans.shape={}".format( counts.shape, torch.transpose(counts, 0, 1).shape)) h1 = sigmoid(layer1(torch.transpose(counts, 0, 1))) h2 = sigmoid(layer2(h1)) # shape = [32, 8] doc_topics_w = sigmoid(layer3(h2)) if debug: print("doc_topics_w(nn result)\t: shape={}, [0].sum={}".format( doc_topics_w.shape, torch.sum(doc_topics_w[0]))) # shape = [32] + [8] # # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w, event_dim=1)) # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w).to_event(1)) doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics_w)) """
def guide(data, args, batch_size=None): if debug: print("guide:") data = torch.reshape(data, [64, 1000]) pyro.module("layer1", layer1) pyro.module("layer2", layer2) pyro.module("layer3", layer3) # Use a conjugate guide for global variables. topic_weights_posterior = pyro.param( "topic_weights_posterior", # lambda: torch.ones(8) / 8, torch.ones(8) / 8, constraint=constraints.positive) topic_words_posterior = pyro.param( "topic_words_posterior", # lambda: torch.ones(8, 1024) / 1024, torch.ones(8, 1024) / 1024, constraint=constraints.positive) """ # wy: dummy param for word_topics word_topics_posterior = pyro.param( "word_topics_posterior", torch.ones(64, 1024, 8) / 8, constraint=constraints.positive) """ with pyro.plate("topics", 8): # shape = [8] + [] topic_weights = pyro.sample("topic_weights", Gamma(topic_weights_posterior, 1.)) # shape = [8] + [1024] topic_words = pyro.sample("topic_words", Dirichlet(topic_words_posterior)) if debug: print("topic_weights\t: shape={}, sum={}". format(topic_weights.shape, torch.sum(topic_weights))) print("topic_words\t: shape={}".format(topic_words.shape)) # Use an amortized guide for local variables. with pyro.plate("documents", 1000, 32) as ind: # shape = [64, 32] data = data[:, ind] # The neural network will operate on histograms rather than word # index vectors, so we'll convert the raw data to a histogram. counts = torch.zeros(1024, 32) counts = torch.Tensor.scatter_add_\ (counts, 0, data, torch.Tensor.expand(torch.tensor(1.), [1024, 32])) if debug: print("counts.shape={}, counts_trans.shape={}". format(counts.shape, torch.transpose(counts, 0, 1).shape)) h1 = sigmoid(layer1(torch.transpose(counts, 0, 1))) h2 = sigmoid(layer2(h1)) # shape = [32, 8] doc_topics_w = sigmoid(layer3(h2)) if debug: print("doc_topics_w(nn result)\t: shape={}, [0].sum={}". format(doc_topics_w.shape, torch.sum(doc_topics_w[0]))) d = Dirichlet(doc_topics_w) print("Dirichlet(doc_topics_w)\t: batch_shape={}, event_shape={}". format(d.batch_shape, d.event_shape)) # shape = [32] + [8] # # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w, event_dim=1)) # doc_topics = pyro.sample("doc_topics", Delta(doc_topics_w).to_event(1)) doc_topics = pyro.sample("doc_topics", Dirichlet(doc_topics_w)) # wy: sample from exact posterior of word_topics with pyro.plate("words", 64): # ks : [K, D] = [8, 32] # ks = torch.arange(0,8).expand(32,8).transpose(0,1) ks = torch.arange(0, 8) ks = torch.Tensor.expand(ks, 32, 8) ks = torch.Tensor.transpose(ks, 0, 1) # logprob1 : [N, D, K] = [32, 8] # logprob1 = Categorical(doc_topics).log_prob(ks).transpose(0,1).expand(64,32,8) logprob1 = Categorical.log_prob(Categorical(doc_topics), ks) logprob1 = torch.Tensor.transpose(logprob1, 0, 1) logprob1 = torch.Tensor.expand(logprob1, 64, 32, 8) # data2 : [N, D, K] = [64, 32, 8] # data2 = data.expand(8,64,32).transpose(0,1).transpose(1,2) data2 = torch.Tensor.expand(data, 8, 64, 32) data2 = torch.Tensor.transpose(data2, 0, 1) data2 = torch.Tensor.transpose(data2, 1, 2) # logprob2 : [N, D, K] = [64, 32, 8] # logprob2 = Categorical(topic_words).log_prob(data2) logprob2 = Categorical.log_prob(Categorical(topic_words), data2) # prob : [N, D, K] = [64, 32, 8] prob = torch.exp(logprob1 + logprob2) # word_topics : [N, D] = [64, 32] word_topics = pyro.sample("word_topics", Categorical(prob)) """