def test_module_nn(nn_module):
    """Check that ``pyro.module`` registers only names present in the module.

    ``nn_module`` is a zero-argument callable that builds the module under test.
    """
    pyro.clear_param_store()
    module_instance = nn_module()
    store = pyro.get_param_store()
    # Constructing the module alone must leave the param store empty.
    assert store._params == {}
    pyro.module("module", module_instance)
    expected_names = module_instance.state_dict().keys()
    for registered in store.get_all_param_names():
        assert pyro.params.user_param_name(registered) in expected_names
def guide(self, x):
    """Variational guide q(z|x): encode ``x`` and sample the latent code."""
    # Expose the encoder's parameters to Pyro.
    pyro.module("encoder", self.encoder)
    # The encoder output unpacks into the two arguments of the normal.
    posterior_params = self.encoder.forward(x)
    loc, scale = posterior_params
    pyro.sample("latent", dist.normal, loc, scale)
def guide():
    """Guide for ``mu_latent`` built from the log-space values on ``pt_guide``."""
    pyro.module("mymodule", pt_guide)
    # The module stores logs; exponentiate to get the mean and the precision.
    mu_q = torch.exp(pt_guide.mu_q_log)
    tau_q = torch.exp(pt_guide.tau_q_log)
    # Standard deviation is the inverse square root of the precision.
    sigma = tau_q.pow(-0.5)
    posterior = dist.Normal(mu_q, sigma, reparameterized=reparameterized)
    pyro.sample("mu_latent", posterior,
                baseline={"use_decaying_avg_baseline": True})
def guide(self, x):
    """Variational guide q(z|x) over the mini-batch ``x``."""
    # Make the encoder's parameters visible to Pyro.
    pyro.module("encoder", self.encoder)
    batch_size = x.size(0)
    with pyro.iarange("data", batch_size):
        # The encoder output unpacks into location and scale of q(z|x).
        loc, scale = self.encoder.forward(x)
        # Fold the rightmost batch dimension into the event shape.
        latent_dist = dist.Normal(loc, scale).independent(1)
        pyro.sample("latent", latent_dist)
def model(self, data, _, **kwargs):
    """Generative model: draw from ``self.prior`` and score a data subsample.

    The second positional argument is ignored; ``kwargs`` are forwarded to
    ``self.prior``.
    """
    pyro.module("decode", self.decode)
    total = data.size(0)
    with pyro.iarange('data', total, use_cuda=self.use_cuda) as ix:
        # ``ix`` is whatever index set the iarange yields for this step.
        batch = data[ix]
        n = batch.size(0)
        (z_where, z_pres), x = self.prior(n, **kwargs)
        # Constant observation noise, one entry per flattened output element.
        noise_sd = self.likelihood_sd * self.prototype.new_ones(n, self.x_size ** 2)
        likelihood = dist.Normal(x.view(n, -1), noise_sd).independent(1)
        pyro.sample('obs', likelihood, obs=batch.view(n, -1))
def model(self, x):
    """Generative model p(x|z)p(z): unit-normal prior, Bernoulli likelihood."""
    # Expose the decoder's parameters to Pyro.
    pyro.module("decoder", self.decoder)
    # Prior hyperparameters; `type_as` makes them follow the tensor type of x
    # (so they are cuda tensors when x is on the gpu).
    prior_shape = [x.size(0), self.z_dim]
    prior_mu = ng_zeros(prior_shape, type_as=x.data)
    prior_sigma = ng_ones(prior_shape, type_as=x.data)
    # The guide supplies the value of this site when the ELBO is computed.
    z = pyro.sample("latent", dist.normal, prior_mu, prior_sigma)
    # Decode the latent code into per-pixel Bernoulli parameters.
    mu_img = self.decoder.forward(z)
    # Score the decoded image against the observed one, flattened to 784.
    pyro.observe("obs", dist.bernoulli, x.view(-1, 784), mu_img)
def model(self, x):
    """Generative model p(x|z)p(z); returns the decoded Bernoulli means."""
    # Expose the decoder's parameters to Pyro.
    pyro.module("decoder", self.decoder)
    batch_size = x.size(0)
    with pyro.iarange("data", batch_size):
        # Unit-normal prior created from x so it shares x's dtype and device.
        prior_shape = torch.Size((batch_size, self.z_dim))
        z_loc = x.new_zeros(prior_shape)
        z_scale = x.new_ones(prior_shape)
        # The guide supplies the value of this site when the ELBO is computed.
        prior = dist.Normal(z_loc, z_scale).independent(1)
        z = pyro.sample("latent", prior)
        # Decode the latent code.
        loc_img = self.decoder.forward(z)
        # Score against the observed images, flattened to 784 columns.
        likelihood = dist.Bernoulli(loc_img).independent(1)
        pyro.sample("obs", likelihood, obs=x.reshape(-1, 784))
        # Hand the decoded means back so the caller can visualise them.
        return loc_img
def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):
    """Generative model p(x_{1:T} | z_{1:T}) p(z_{1:T}) for one mini-batch.

    Only ``mini_batch``, ``mini_batch_mask`` and ``annealing_factor`` are read
    in this body; the remaining arguments are accepted but unused here.
    """
    # number of time steps to unroll
    T_max = mini_batch.size(1)

    # register all (sub)modules with pyro; the guide has to do the same
    pyro.module("dmm", self)

    # start the recursion p(z_t | z_{t-1}) from z_0, broadcast over the batch
    z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

    # the iarange marks the sequences in the mini-batch as conditionally
    # independent of one another
    with pyro.iarange("z_minibatch", len(mini_batch)):
        for step in range(T_max):
            t = step + 1
            # column `step` of the mask, sliced (not indexed) so it stays 2-D;
            # it switches off padded time steps of shorter sequences
            mask_t = mini_batch_mask[:, step:t]

            # parameters of the diagonal gaussian p(z_t | z_{t-1})
            z_loc, z_scale = self.trans(z_prev)

            # sample z_t; poutine.scale applies the KL annealing factor and
            # .independent(1) folds the latent dimension into the event shape
            transition = dist.Normal(z_loc, z_scale).mask(mask_t).independent(1)
            with poutine.scale(scale=annealing_factor):
                z_t = pyro.sample("z_%d" % t, transition)

            # bernoulli likelihood p(x_t | z_t)
            emission_probs_t = self.emitter(z_t)
            emission = dist.Bernoulli(emission_probs_t).mask(mask_t).independent(1)
            pyro.sample("obs_x_%d" % t, emission, obs=mini_batch[:, step, :])

            # z_t conditions the next time step
            z_prev = z_t
def guide():
    """Guide with a network baseline on ``mu_latent`` and on every ``z`` site.

    Returns the sampled ``mu_latent``.
    """
    # Variational parameters, initialised at an offset from the analytic values.
    mu_q = pyro.param(
        "mu_q",
        Variable(self.analytic_mu_n.data + 0.094 * torch.ones(2),
                 requires_grad=True))
    log_sig_q = pyro.param(
        "log_sig_q",
        Variable(self.analytic_log_sig_n.data - 0.11 * torch.ones(2),
                 requires_grad=True))
    sig_q = torch.exp(log_sig_q)

    # Baseline network for the top-level site, fed a constant input.
    trivial_baseline = pyro.module("mu_baseline", pt_mu_baseline, tags="baseline")
    mu_baseline_value = trivial_baseline(ng_ones(1))
    mu_latent = pyro.sample(
        "mu_latent",
        dist.Normal(mu_q, sig_q, reparameterized=False),
        baseline={"baseline_value": mu_baseline_value})

    def obs_inner(i, _i, _x):
        # One non-reparameterized site per superfluous latent; each gets its
        # own baseline network evaluated on the detached mu_latent.
        for k in range(n_superfluous_top + n_superfluous_bottom):
            site_suffix = (i, k)
            z_baseline = pyro.module("z_baseline_%d_%d" % site_suffix,
                                     pt_superfluous_baselines[3 * k + i],
                                     tags="baseline")
            z_baseline_value = z_baseline(mu_latent.detach()).unsqueeze(-1)
            mean_i = pyro.param("mean_%d_%d" % site_suffix,
                                Variable(0.5 * torch.ones(4 - i, 1),
                                         requires_grad=True))
            pyro.sample("z_%d_%d" % site_suffix,
                        dist.Normal(mean_i, ng_ones(4 - i, 1),
                                    reparameterized=False),
                        baseline={"baseline_value": z_baseline_value})

    def obs_outer(i, x):
        pyro.map_data("map_obs_inner_%d" % i, x,
                      lambda _i, _x: obs_inner(i, _i, _x),
                      batch_size=4 - i)

    # Three row-blocks of the data tensor: rows 0-3, 4-6 and 7-8.
    row_blocks = [self.data_tensor[lo:hi, :] for lo, hi in ((0, 4), (4, 7), (7, 9))]
    pyro.map_data("map_obs_outer", row_blocks,
                  lambda i, x: obs_outer(i, x), batch_size=3)

    return mu_latent
def model_classify(self, xs, ys=None):
    """
    Auxiliary (supervised) loss term, following "Semi-Supervised Learning with
    Deep Generative Models" (Kingma et al., NIPS 2014).

    When ``ys`` is given, the labels are scored against the categorical
    produced by ``self.encoder_y``; when it is ``None`` no site is created.
    """
    # register all pytorch (sub)modules with pyro
    pyro.module("ss_vae", self)

    # the entries of the batch (xs, ys) are conditionally independent
    with pyro.iarange("independent"):
        if ys is None:
            # nothing to supervise
            return
        # extra term that yields the auxiliary loss; the multiplier is
        # passed through as the log-pdf mask of the observe statement
        alpha = self.encoder_y.forward(xs)
        pyro.observe("y_aux", dist.categorical, ys, alpha,
                     log_pdf_mask=self.aux_loss_multiplier)
def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):
    """Variational guide q(z_{1:T} | x_{1:T}) for one mini-batch."""
    batch_size = mini_batch.size(0)
    # number of time steps to unroll
    T_max = mini_batch.size(1)

    # register all (sub)modules with pyro
    pyro.module("dmm", self)

    # broadcast the rnn initial state over the batch and make it contiguous
    # (needed on gpu)
    h_0_contig = self.h_0.expand(1, batch_size, self.rnn.hidden_size).contiguous()

    # run the rnn over the reversed observations, then undo the reversal and
    # un-pack so rnn_output is indexed by forward time
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)

    # start the recursion q(z_t | ...) from z_q_0, broadcast over the batch
    latent_dim = self.z_q_0.size(0)
    z_prev = self.z_q_0.expand(batch_size, latent_dim)

    # the iarange marks the sequences as conditionally independent
    with pyro.iarange("z_minibatch", len(mini_batch)):
        for step in range(T_max):
            t = step + 1
            # q(z_t | z_{t-1}, x_{t:T}) from the previous latent and the
            # rnn hidden state at this time step
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, step, :])
            base_dist = dist.Normal(z_loc, z_scale)

            # with normalizing flows, push the base distribution through the
            # transformations in self.iafs; otherwise use it as is
            if len(self.iafs) > 0:
                z_dist = TransformedDistribution(base_dist, self.iafs)
            else:
                z_dist = base_dist

            # shape sanity checks before the latent dim becomes an event dim
            assert z_dist.event_shape == ()
            assert z_dist.batch_shape == (len(mini_batch), latent_dim)

            # sample z_t, annealed and masked for padded time steps
            masked_dist = z_dist.mask(mini_batch_mask[:, step:t]).independent(1)
            with pyro.poutine.scale(scale=annealing_factor):
                z_t = pyro.sample("z_%d" % t, masked_dist)

            # z_t conditions the next time step
            z_prev = z_t
def model(self, data):
    """Generative model p(x|z)p(z) with a 20-dimensional unit-normal prior.

    :param data: mini-batch of images; it is reshaped to ``(-1, 784)`` when
        scored against the Bernoulli likelihood.
    """
    decoder = pyro.module('decoder', self.vae_decoder)
    batch_size = data.size(0)
    # Build the prior on the same device as the data so the model also runs
    # when the batch lives on an accelerator.
    z_mean = torch.zeros([batch_size, 20], device=data.device)
    z_std = torch.ones([batch_size, 20], device=data.device)
    with pyro.iarange('data', batch_size):
        # The guide supplies the value of this site when the ELBO is computed.
        z = pyro.sample('latent', Normal(z_mean, z_std).independent(1))
        img = decoder.forward(z)
        pyro.sample('obs', Bernoulli(img).independent(1),
                    obs=data.reshape(-1, 784))
def obs_inner(i, _i, _x):
    """Create one baselined, non-reparameterized ``z`` site per superfluous latent.

    ``_i`` and ``_x`` are accepted but not used.
    """
    for k in range(n_superfluous_top + n_superfluous_bottom):
        site_suffix = (i, k)
        # Per-site baseline network, evaluated on the detached mu_latent.
        z_baseline = pyro.module("z_baseline_%d_%d" % site_suffix,
                                 pt_superfluous_baselines[3 * k + i],
                                 tags="baseline")
        z_baseline_value = z_baseline(mu_latent.detach()).unsqueeze(-1)
        mean_i = pyro.param("mean_%d_%d" % site_suffix,
                            Variable(0.5 * torch.ones(4 - i, 1),
                                     requires_grad=True))
        pyro.sample("z_%d_%d" % site_suffix,
                    dist.Normal(mean_i, ng_ones(4 - i, 1),
                                reparameterized=False),
                    baseline={"baseline_value": z_baseline_value})
def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):
    """Generative model p(x_{1:T} | z_{1:T}) p(z_{1:T}) for one mini-batch.

    Only ``mini_batch``, ``mini_batch_mask`` and ``annealing_factor`` are read
    in this body; the remaining arguments are accepted but unused here.
    """
    # number of time steps to unroll
    T_max = mini_batch.size(1)

    # register all (sub)modules with pyro; the guide has to do the same
    pyro.module("dmm", self)

    # start the recursion p(z_t | z_{t-1}) from z_0, broadcast over the batch
    z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

    for step in range(T_max):
        t = step + 1
        # column `step` of the mask, sliced so it stays 2-D; it handles the
        # raggedness of the mini-batch (sequences of different lengths)
        mask_t = mini_batch_mask[:, step:t]

        # parameters of the diagonal gaussian p(z_t | z_{t-1})
        z_mu, z_sigma = self.trans(z_prev)
        # sample z_t; log_pdf_mask carries both the KL annealing factor and
        # the sequence mask
        z_t = pyro.sample("z_%d" % t, dist.normal, z_mu, z_sigma,
                          log_pdf_mask=annealing_factor * mask_t)

        # observe x_t under the bernoulli likelihood p(x_t | z_t)
        emission_probs_t = self.emitter(z_t)
        pyro.observe("obs_x_%d" % t, dist.bernoulli,
                     mini_batch[:, step, :], emission_probs_t,
                     log_pdf_mask=mask_t)

        # z_t conditions the next time step
        z_prev = z_t
def guide(self, x):
    """
    1. Pass x as the input data to the LSTM encoder
    2. Pass the LSTM encoder's final hidden and cell state into the two MLPs
       to obtain f(x) (location) and g(x) (scale)
    3. Sample z from the diagonal normal N(f(x), g(x))
    """
    pyro.module("encoder_lstm", self.encoder_lstm)
    pyro.module("loc_mlp", self.loc_mlp)
    pyro.module("scale_mlp", self.scale_mlp)

    batch_size, x_tensor = self._preprocess_input(x)
    with pyro.plate("data"):
        # Feed the sequence one position at a time, carrying the LSTM state.
        encoder_hidden_state = self.encoder_lstm.init_hidden(
            batch_size=batch_size)
        for i in range(MAX_STRING_LEN):
            _, encoder_hidden_state = self.encoder_lstm.forward(
                x_tensor[i].unsqueeze(0), encoder_hidden_state)

        # Flatten hidden and cell state per batch element and concatenate.
        encoder_hidden = encoder_hidden_state[0].view(batch_size, -1)
        encoder_cell = encoder_hidden_state[1].view(batch_size, -1)
        flattened_lstm_output = torch.cat((encoder_hidden, encoder_cell), dim=1)

        loc = self.loc_mlp.forward(flattened_lstm_output)
        # The scale must come from its own network. Using loc_mlp here made
        # scale identical to loc and left scale_mlp registered but unused.
        # NOTE(review): assumes scale_mlp emits strictly positive values - confirm.
        scale = self.scale_mlp.forward(flattened_lstm_output)
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))
def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lengths, annealing_factor=1.0):
    """Variational guide q(z_{1:T} | x_{1:T}) for one mini-batch."""
    # number of time steps to unroll
    T_max = mini_batch.size(1)

    # register all (sub)modules with pyro
    pyro.module("dmm", self)

    # on gpu the rnn initial state has to be fully broadcast and contiguous;
    # otherwise it is passed through untouched
    if self.use_cuda:
        h_0_contig = self.h_0.expand(
            1, mini_batch.size(0), self.rnn.hidden_size).contiguous()
    else:
        h_0_contig = self.h_0

    # run the rnn over the reversed observations, then undo the reversal and
    # un-pack so rnn_output is indexed by forward time
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)

    # start the recursion q(z_t | ...) from z_q_0
    z_prev = self.z_q_0

    for step in range(T_max):
        t = step + 1
        # parameters of q(z_t | z_{t-1}, x_{t:T})
        z_mu, z_sigma = self.combiner(z_prev, rnn_output[:, step, :])

        # with normalizing flows, wrap the base distribution in the sequence
        # of transformations held in self.iafs
        z_dist = dist.normal
        if len(self.iafs) > 0:
            z_dist = TransformedDistribution(z_dist, self.iafs)

        # sample z_t; log_pdf_mask carries the annealing factor and the mask
        step_mask = annealing_factor * mini_batch_mask[:, step:t]
        z_t = pyro.sample("z_%d" % t, z_dist, z_mu, z_sigma,
                          log_pdf_mask=step_mask)

        # z_t conditions the next time step
        z_prev = z_t
def model(self, response, mask, annealing_factor=1):
    """Generative model over person abilities, item features and responses.

    With ``self.generative_model == 'irt'`` the work is delegated to the
    module-level ``irt_model_{self.irt_num}pl`` function. Otherwise abilities
    and item features get unit-normal priors and ``self.decoder`` maps them to
    Bernoulli response probabilities.

    :param response: observed responses, or ``None`` to sample them
        (decoder branch only).
    :param mask: optional mask applied to the response distribution.
    :param annealing_factor: scale applied to the ability prior term.
    :return: the delegate's return value in the IRT branch, otherwise the
        tuple ``(response, ability, item_feat)``.
    :raises ValueError: in the decoder branch when both ``response`` and
        ``mask`` are ``None``, because the number of persons and the device
        cannot be inferred.
    """
    if self.generative_model == 'irt':
        irt_model_fn = globals()[f'irt_model_{self.irt_num}pl']
        return irt_model_fn(
            self.ability_dim,
            response.size(0),
            self.num_item,
            response.device,
            response=response,
            mask=mask,
            annealing_factor=annealing_factor,
        )

    pyro.module("decoder", self.decoder)

    # The sizes and the device were previously read from names that are not
    # bound in this method; derive them the same way the IRT branch does.
    reference = response if response is not None else mask
    if reference is None:
        raise ValueError(
            "model() needs `response` or `mask` to infer the number of "
            "persons and the device")
    num_person = reference.size(0)
    device = reference.device
    ability_dim = self.ability_dim
    num_item = self.num_item

    ability_prior = dist.Normal(
        torch.zeros((num_person, ability_dim), device=device),
        torch.ones((num_person, ability_dim), device=device),
    )
    with poutine.scale(scale=annealing_factor):
        ability = pyro.sample("ability", ability_prior)

    # NOTE(review): the item-feature site is kept outside the annealing
    # scale; confirm this matches the intended nesting.
    item_feat_prior = dist.Normal(
        torch.zeros((num_item, self.item_feat_dim), device=device),
        torch.ones((num_item, self.item_feat_dim), device=device),
    )
    item_feat = pyro.sample("item_feat", item_feat_prior)

    response_mu = self.decoder(ability, item_feat)
    response_dist = dist.Bernoulli(response_mu)
    if mask is not None:
        response_dist = response_dist.mask(mask)

    if response is not None:
        pyro.sample("response", response_dist, obs=response)
    else:
        response = pyro.sample("response", response_dist)

    return response, ability, item_feat
def model(self, xs, ys=None):
    """
    The model corresponds to the following generative process:
    p(z) = normal(0,I)              # handwriting style (latent)
    p(y) = categorical(I/10.)       # which digit (semi-supervised)
    p(x|y,z) = bernoulli(mu(y,z))   # an image
    mu is given by a neural network `decoder`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels i.e.
               the digit corresponding to the image(s)
    :return: None
    """
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("ss_vae", self)

    batch_size = xs.size(0)
    with pyro.iarange("independent"):
        # constant unit-normal prior over the handwriting style
        style_shape = [batch_size, self.z_dim]
        prior_mu = Variable(torch.zeros(style_shape))
        prior_sigma = Variable(torch.ones(style_shape))
        zs = pyro.sample("z", dist.normal, prior_mu, prior_sigma)

        # constant uniform prior over the label: sample from it when the
        # label is missing, otherwise score the given label against it
        uniform = torch.ones([batch_size, self.output_size]) / (1.0 * self.output_size)
        alpha_prior = Variable(uniform)
        if ys is not None:
            pyro.observe("y", dist.categorical, ys, alpha_prior)
        else:
            ys = pyro.sample("y", dist.categorical, alpha_prior)

        # score the image against p(x|y,z) = bernoulli(decoder(y,z))
        mu = self.decoder.forward([zs, ys])
        pyro.observe("x", dist.bernoulli, xs, mu)
def model(self, xs, ys=None):
    """
    The model corresponds to the following generative process:
    p(z) = normal(0,I)              # handwriting style (latent)
    p(y) = categorical(I/10.)       # which digit (semi-supervised)
    p(x|y,z) = bernoulli(loc(y,z))  # an image
    loc is given by a neural network `decoder`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels i.e.
               the digit corresponding to the image(s)
    :return: the Bernoulli means produced by the decoder
    """
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("ss_vae", self)

    batch_size = xs.size(0)
    # new tensors follow the dtype and device of the inputs
    options = {"dtype": xs.dtype, "device": xs.device}
    with pyro.plate("data"):
        # constant unit-normal prior over the handwriting style
        prior_loc = torch.zeros(batch_size, self.z_dim, **options)
        prior_scale = torch.ones(batch_size, self.z_dim, **options)
        style_prior = dist.Normal(prior_loc, prior_scale).to_event(1)
        zs = pyro.sample("z", style_prior)

        # constant uniform prior over the label; with obs=None the site is
        # sampled, otherwise the given label is scored against the prior
        alpha_prior = torch.ones(batch_size, self.output_size, **options) / (
            1.0 * self.output_size)
        ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

        # score the image against p(x|y,z) = bernoulli(decoder(y,z))
        loc = self.decoder.forward([zs, ys])
        pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)
        # return the loc so we can visualize it later
        return loc
def model(self, xs, ys=None):
    """Conditional generative model p(y | x, z) p(z | x).

    With ``ys`` given, only the entries where ``xs == -1`` are scored. Without
    it, the decoder output is recorded as a deterministic site. Returns the
    decoder output in both cases.
    """
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("generation_net", self)
    batch_size = xs.shape[0]
    with pyro.plate("data"):
        # The prior network takes the baseline prediction as an initial
        # guess. no_grad keeps training from changing the baseline network.
        with torch.no_grad():
            y_hat = self.baseline_net(xs).view(xs.shape)

        # latent prior modulated by the input xs
        prior_loc, prior_scale = self.prior_net(xs, y_hat)
        latent_prior = dist.Normal(prior_loc, prior_scale).to_event(1)
        zs = pyro.sample("z", latent_prior)

        # the output y is generated from p(y | x, z)
        loc = self.generation_net(zs)

        if ys is None:
            # Testing: keep the probabilities themselves instead of sampling,
            # which would force every pixel to 0 or 1.
            pyro.deterministic("y", loc.detach())
        else:
            # Training: score only the masked entries (marked by -1 in xs).
            masked = xs == -1
            mask_loc = loc[masked.view(-1, 784)].view(batch_size, -1)
            mask_ys = ys[masked].view(batch_size, -1)
            pyro.sample(
                "y",
                dist.Bernoulli(mask_loc, validate_args=False).to_event(1),
                obs=mask_ys,
            )
        # return the loc so we can visualize it later
        return loc
def model(self, imgs=None):
    """
    1. sample the latent from the prior
    2. run the generative model (the decoder)
    3. score the decoder output against ``imgs``

    When ``imgs`` is None a batch of 8 all-zero images stands in for the data.
    Returns the decoder output.
    """
    # ---- Trick: substitute placeholder images when none are given ----
    observed = imgs is not None
    if not observed:
        imgs = torch.zeros(8, self.ch, self.height, self.width)
        if self.use_cuda:
            imgs = imgs.cuda()
    # NOTE(review): `observed` is never consulted below, so the placeholder
    # zeros are scored as observations too - confirm this is intended.
    # ---- End of trick ----

    # the unpacking also requires imgs to be 4-dimensional
    batch_size, _, _, _ = imgs.shape
    zero = imgs.new_zeros(1)
    one = imgs.new_ones(1)

    pyro.module("decoder", self.decoder)
    # Registered in the param store, but the likelihood below uses a fixed
    # standard deviation of 0.1 instead of this parameter.
    pyro.param("sigma_obs", 0.1 * one,
               constraint=constraints.interval(0.01, 0.5))

    with pyro.plate('batch_size', batch_size, dim=-1):
        latent_prior = dist.Normal(zero, one).expand([self.dim_z]).to_event(1)
        with poutine.scale(scale=self.scale):
            z = pyro.sample('z_latent', latent_prior)
        # NOTE(review): decoding and the 'obs' site are placed outside the
        # scale context - confirm against the intended nesting.
        x = self.decoder(z)
        likelihood = dist.Normal(x.x_mu, 0.1 * one).to_event(1)
        pyro.sample('obs', likelihood, obs=imgs.view(batch_size, -1))
    return x
def model(self, x, y):
    """Conditional generative model p(x | z, y) p(z); returns the decoded means."""
    pyro.module("decoder", self.decoder)
    batch_size = x.shape[0]
    # new tensors follow the dtype and device of x
    tensor_opts = {"dtype": x.dtype, "device": x.device}
    with pyro.plate("data", batch_size):
        # unit-normal prior over the latent code
        loc_z = torch.zeros(batch_size, self.z_dim, **tensor_opts)
        scale_z = torch.ones(batch_size, self.z_dim, **tensor_opts)
        z = pyro.sample("latent", Normal(loc_z, scale_z).to_event(1))
        # decode the latent code together with y
        loc_img = self.decoder(z, y)
        # score the observed images, flattened to 784 columns
        pyro.sample("obs", Bernoulli(loc_img).to_event(1),
                    obs=x.reshape(-1, 784))
    return loc_img
def model(self, docs=None, doc_categories=None):
    """Topic model: log-normal topic proportions decoded into word counts.

    NOTE(review): both arguments default to None but are dereferenced
    unconditionally, so callers must always pass them.
    """
    # Register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)

    num_docs = docs.shape[0]
    # NOTE(review): the first plate is sized by doc_categories while theta is
    # shaped by docs - presumably both have the same leading size; verify.
    with pyro.plate("document_categories", doc_categories.shape[0]):
        # A log-normal distribution stands in for the Dirichlet prior p(theta|alpha)
        prior_shape = (num_docs, self.num_topics)
        theta_loc = docs.new_zeros(prior_shape)
        theta_scale = docs.new_ones(prior_shape)
        theta = pyro.sample(
            "theta", dist.LogNormal(theta_loc, theta_scale).to_event(1))
        # normalise so each row of theta sums to one
        theta = theta / theta.sum(-1, keepdim=True)

        # the decoder turns topic proportions into the multinomial parameter
        count_param = self.decoder(theta)

    with pyro.plate("documents", num_docs):
        word_dist = dist.Multinomial(docs.shape[1], count_param).to_event(1)
        pyro.sample('obs', word_dist, obs=docs)
def guide(self, x_sent, y_sent, decode=False, kl_annealing=1.0):
    """Variational guide q(z | x, y) built from mean-pooled encoder states.

    ``decode`` is accepted but not used in this body.
    """
    pyro.module("vnmt", self)

    # transform sentences into embeddings
    x_embeds, x_len, x_mask, y_sent = self.x_embed(x_sent, y=y_sent)
    y_embeds, y_len, y_mask, y_indices = self.y_embed(
        y_sent, y=range(len(y_sent)), pack_seq=True)

    # Run the source through the rnn. The first output holds ALL hidden
    # states; the second (last hidden state only) is discarded.
    x_out, _ = self.encoder(x_embeds)
    X, x_len = pad_packed_sequence(x_out, batch_first=self.b_f)
    # TODO: sequences have varying lengths, so even if the padded positions
    # are zero vectors, the strength of the existing entries may be lessened
    # depending on how long the longest sequence is.
    if self.use_cuda:
        x_len = x_len.cuda()
    # mean pool over the hidden states: sum over dim 1, divide by the length
    X = X.sum(dim=1) / x_len.unsqueeze(1).float()

    # same treatment for the target side
    y_out, _ = self.encoder(y_embeds)
    Y, y_len = pad_packed_sequence(y_out, batch_first=self.b_f)
    if self.use_cuda:
        y_len = y_len.cuda()
    Y = Y.sum(dim=1) / y_len.unsqueeze(1).float()

    # reorder the pooled target rows by ascending y_indices
    rows_with_index = sorted(zip(list(Y), y_indices),
                             key=lambda pair: pair[1])
    Y = torch.stack([row for row, _ in rows_with_index])

    z_input = torch.cat([X, Y], dim=1)
    z_mean, z_sig = self.posterior(z_input)

    # sample from the variational distribution, with the KL annealing scale
    with pyro.plate('z_minibatch'):
        with poutine.scale(scale=kl_annealing):
            pyro.sample('z_semantics', dist.Normal(z_mean, z_sig))
def guide(self, x: torch.Tensor, **kwargs: float) -> None:
    """
    Defines the guide q(z,c|x)

    The optional ``scale_factor`` keyword gives the KLD scale for the
    continuous and the discrete latent. A scalar is applied to both.
    """
    # register PyTorch module `encoder_z` with Pyro
    pyro.module("encoder_z", self.encoder_z)

    # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
    beta = kwargs.get("scale_factor", [1., 1.])
    if isinstance(beta, (float, int, list)):
        beta = torch.tensor(beta)
    # a 0-dimensional value is duplicated so both latents get the same scale
    if beta.ndim == 0:
        beta = torch.tensor([beta, beta])
    cont_scale, disc_scale = beta[0], beta[1]

    with pyro.plate("data"):
        # the encoder output unpacks into the parameters of q(z,c|x)
        z_loc, z_scale, alpha = self.encoder_z(x)
        # continuous part of the latent code
        with pyro.poutine.scale(scale=cont_scale):
            continuous = dist.Normal(z_loc, z_scale).to_event(1)
            pyro.sample("latent_cont", continuous)
        # discrete part of the latent code
        with pyro.poutine.scale(scale=disc_scale):
            pyro.sample("latent_disc", dist.OneHotCategorical(alpha))
def model(self, data):
    """Score a batch of trajectories against the rolled-out dynamics.

    Column ``t`` of ``data`` is observed at site ``y_t``. Also stores the
    number of time steps on ``self.T``.
    """
    pyro.module('VIN', self)

    # initial position and finite-difference velocity from the first two points
    q_prev = self.q2
    dq_prev = (self.q2 - self.q1) / self.h
    self.T = data.shape[1]

    def observe(index, loc):
        # gaussian observation of column `index`, with a trailing unit dim
        pyro.sample('y_{}'.format(index),
                    dist.Normal(loc=loc, scale=self.sigma).to_event(1),
                    obs=data[:, index].unsqueeze(-1))

    with pyro.plate('mini_batch', len(data)):
        observe(0, self.q1)
        observe(1, self.q2)
        for t in range(2, self.T):
            # advance the state by one step of size h
            q, dq = self.trans(q_prev, dq_prev, self.h)
            observe(t, q)
            q_prev, dq_prev = q, dq
def guide(self, x, y=None):
    """Variational guide over ``l``, ``z2``, ``y`` (when unlabeled) and ``z1``."""
    pyro.module("scanvi", self)
    with pyro.plate("batch", len(x)), poutine.scale(scale=self.scale_factor):
        z2_loc, z2_scale, l_loc, l_scale = self.z2l_encoder(x)
        pyro.sample("l", dist.LogNormal(l_loc, l_scale).to_event(1))
        z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

        # q(y|z2)
        y_dist = dist.OneHotCategorical(logits=self.classifier(z2))

        if y is not None:
            # x is labeled: add a classification loss term, so q(y|z2)
            # learns from both labeled and unlabeled data.
            # The sign is negative because the term is added in the guide,
            # whose log_prob enters the ELBO as -log q.
            classification_loss = y_dist.log_prob(y)
            pyro.factor("classification_loss",
                        -self.alpha * classification_loss)
        else:
            # x is unlabeled: sample y using q(y|z2)
            y = pyro.sample("y", y_dist)

        z1_loc, z1_scale = self.z1_encoder(z2, y)
        pyro.sample("z1", dist.Normal(z1_loc, z1_scale).to_event(1))
def model(self, x, y):
    """Clustered-output model; returns the value of the ``.y`` observation site."""
    prefix = self.name_prefix
    pyro.module(prefix + ".gp", self)

    # distribution over function values, obtained from self.pyro_model
    function_dist = self.pyro_model(x, name_prefix=prefix)

    # one one-hot cluster assignment per task, drawn from zero logits
    assignment_logits = torch.zeros(self.num_tasks, self.num_functions)
    assignment_dist = pyro.distributions.OneHotCategorical(
        logits=assignment_logits).to_event(1)
    cluster_assignment_samples = pyro.sample(
        prefix + ".cluster_logits", assignment_dist)

    # Sample from observation distribution
    # NOTE(review): the observation site is kept inside the plate - confirm
    # against the intended nesting.
    num_outputs = function_dist.batch_shape[-1]
    with pyro.plate(prefix + ".output_values_plate", num_outputs, dim=-1):
        function_samples = pyro.sample(prefix + ".f", function_dist)
        # pick, per task, the function selected by its one-hot assignment
        selected = (function_samples.unsqueeze(-2)
                    * cluster_assignment_samples).sum(-1)
        obs_dist = pyro.distributions.Normal(
            loc=selected, scale=self.noise.sqrt()).to_event(1)
        # rescale the likelihood by num_data over the size of y's dim -2
        minibatch_scale = self.num_data / y.size(-2)
        with pyro.poutine.scale(scale=minibatch_scale):
            return pyro.sample(prefix + ".y", obs_dist, obs=y)
def model(data):
    # NOTE(review): this reads like a scratch sketch of alternative
    # `pyro.module` call styles rather than runnable code; see notes below.
    rr = pyro.param("l1", torch.Tensor(5, 10))
    # NOTE(review): a positional argument after a keyword argument is a
    # SyntaxError, so this line (and the `p2` line below) cannot be parsed.
    p_net_first = pyro.module(pyro_name="p1", Net, my_args)
    p_net_first = pyro.module(pyro_name="p1", Net(my_args))
    p_net_first = pyro.module(pyro_name="p1", net)
    p_net_second = pyro.module(pyro_name="p2", Net, my_args)
    # NOTE(review): `p_net_model` is not assigned anywhere in this function.
    p_net_model.behave_different = flip()
    # pyro.module("dope_name", MNet())
    # else:
    #     p_net = pyro.module("dope_other", SuperNet())
    # do something with rr
    # model_forward = torch.mm(rr,data)
    # NOTE(review): `p_net` is only assigned in the commented-out code above.
    model_forward = p_net.forward(data)
    # net.parameters()
    return model_forward
def loss_fn(design, num_particles, evaluation=False, **kwargs):
    """Loss built from the critic ``h`` on conditional and marginal traces.

    In evaluation mode only the conditional trace is scored. Otherwise the
    result combines softplus terms of ``h`` on both traces. Extra keyword
    arguments are accepted and ignored.
    """
    try:
        pyro.module("h", h)
    except AssertionError:
        # deliberately ignored, as in the original best-effort registration
        pass

    expanded_design = lexpand(design, num_particles)
    model_conditional_trace = poutine.trace(model_conditional).get_trace(expanded_design)

    if evaluation:
        h_joint = h(expanded_design, model_conditional_trace,
                    observation_labels, target_labels)
        return _safe_mean_terms(h_joint)

    model_marginal_trace = poutine.trace(model_marginal).get_trace(expanded_design)

    h_joint = h(expanded_design, model_conditional_trace,
                observation_labels, target_labels)
    h_independent = h(expanded_design, model_marginal_trace,
                      observation_labels, target_labels)

    softplus = torch.nn.functional.softplus
    terms = softplus(-h_joint) + softplus(h_independent)
    return _safe_mean_terms(terms)
def guide(self, input_variable, target_variable, step):
    """Variational guide q(z|x): encode the sequence with the rnn, then sample.

    ``target_variable`` and ``step`` are accepted but not used in this body.
    """
    # register the encoder modules with Pyro
    pyro.module("encoder_dense", self.encoder_dense)
    pyro.module("encoder_rnn", self.encoder_rnn)

    # buffer for the per-position outputs
    input_length = input_variable.shape[0]
    encoder_outputs = Variable(
        torch.zeros(input_length, self.encoder_rnn.hidden_size))
    if USE_CUDA:
        encoder_outputs = encoder_outputs.cuda()

    # feed the sequence one position at a time, carrying the hidden state
    encoder_hidden = self.encoder_rnn.init_hidden_gru()
    for position in range(input_length):
        step_output, encoder_hidden = self.encoder_rnn(
            input_variable[position], encoder_hidden)
        encoder_outputs[position] = step_output[0][0]

    # the dense encoder maps the final hidden state to the parameters of q(z|x)
    z_mu, z_sigma = self.encoder_dense(encoder_hidden)
    # sample the latent code z
    pyro.sample("latent", dist.normal, z_mu, z_sigma)
def model(self, raw_expr, encoded_expr, read_depth):
    """Topic model with a zero-inflated negative-binomial likelihood on ``raw_expr``."""
    pyro.module("decoder", self.decoder)

    # per-gene dispersion, initialised to 5 and constrained to be positive
    dispersion_init = (torch.tensor(5.).to(self.device)
                       * torch.ones(self.num_genes).to(self.device))
    dispersion = pyro.param("dispersion", dispersion_init,
                            constraint=constraints.positive)

    num_cells = encoded_expr.shape[0]
    with pyro.plate("cells", num_cells):
        # A log-normal distribution stands in for the Dirichlet prior p(theta|alpha)
        unit = (num_cells, self.num_topics)
        theta_loc = self.prior_mu * encoded_expr.new_ones(unit)
        theta_scale = self.prior_std * encoded_expr.new_ones(unit)
        theta = pyro.sample(
            "theta", dist.LogNormal(theta_loc, theta_scale).to_event(1))
        # normalise so each row of theta sums to one
        theta = theta / theta.sum(-1, keepdim=True)

        # latent scale centred (in log space) on the given read depth
        read_scale = pyro.sample(
            'read_depth',
            dist.LogNormal(torch.log(read_depth), 1.).to_event(1))

        # the decoder yields the expression rate and the dropout logits
        expr_rate, dropout = self.decoder(theta)

        # mean of the count distribution, turned into a success probability
        # that is capped at self.max_prob
        mu = read_scale * expr_rate
        p = torch.minimum(mu / (mu + dispersion), self.max_prob)

        likelihood = dist.ZeroInflatedNegativeBinomial(
            total_count=dispersion, probs=p, gate_logits=dropout).to_event(1)
        pyro.sample('obs', likelihood, obs=raw_expr)
def model(self,
          x: torch.Tensor,
          y: Optional[torch.Tensor] = None,
          **kwargs: float) -> None:
    """
    Defines the model p(x|z)p(z)

    ``y``, when given, is concatenated to the latent code before decoding.
    The optional ``scale_factor`` keyword scales the prior term.
    """
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
    beta = kwargs.get("scale_factor", 1.)
    # number of elements per sample, i.e. the product of all non-batch dims
    reshape_ = torch.prod(torch.tensor(x.shape[1:])).item()
    batch_size = x.shape[0]
    with pyro.plate("data", batch_size):
        # unit-normal prior p(z), created from x
        prior_shape = torch.Size((batch_size, self.z_dim))
        z_loc = x.new_zeros(prior_shape)
        z_scale = x.new_ones(prior_shape)
        # the guide supplies the value of this site when computing the ELBO
        with pyro.poutine.scale(scale=beta):
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

        if self.coord > 0:
            # rotationally- and/or translationally-invariant mode: split the
            # latent variable into rotation, translation, scale and content
            phi, dx, sc, z = self.split_latent(z)
            if 't' in self.invariances:
                dx = (dx * self.t_prior).unsqueeze(1)
            # transform the coordinate grid, expanded over the batch
            grid = self.grid.expand(batch_size, *self.grid.shape)
            x_coord_prime = transform_coordinates(grid, phi, dx, sc)

        # add the class label (if any)
        if y is not None:
            z = torch.cat([z, y], dim=-1)

        # decode z, together with the transformed coordinates when present
        if self.coord:
            loc = self.decoder(x_coord_prime, z)
        else:
            loc = self.decoder(z)

        # score against the actual images, flattened per sample
        likelihood = self.sampler_d(loc.view(-1, reshape_)).to_event(1)
        pyro.sample("obs", likelihood, obs=x.view(-1, reshape_))
def model_ncls(self, xs, ys=None):
    """Two-level generative model: z2 and y generate z1, which generates x."""
    pyro.module("ss_vae", self)

    batch_size = xs.size(0)
    # new tensors follow the dtype and device of xs
    options = {
        "out": None,
        "dtype": xs.dtype,
        "layout": torch.strided,
        "device": xs.device,
        "requires_grad": False,
    }
    with pyro.plate("data"):
        # unit-normal prior over the 50-dimensional top-level latent
        prior_loc = torch.zeros(batch_size, 50, **options)
        prior_scale = torch.ones(batch_size, 50, **options)
        top_prior = dist.Normal(prior_loc, prior_scale).to_event(1)
        zs2 = pyro.sample("z2", top_prior)

        # uniform prior over 4 classes; sampled when ys is None, else observed
        alpha_prior = torch.ones(batch_size, 4, **options) / 4.0
        ys_ = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

        # p(z1 | z2, y)
        z1_mean, z1_std = self.decoder_z1(zs2, ys_)
        zs1 = pyro.sample("z1", dist.Normal(z1_mean, z1_std).to_event(1))

        # p(x | z1), with the last three dimensions treated as the event
        x_mean, x_std = self.decoder_x(zs1)
        pyro.sample('x', dist.Normal(x_mean, x_std).to_event(3), obs=xs)
def guide(xs, ys=None):
    """
    Variational guide:

        q(y|x)   = categorical(alpha(x))
        q(z|x,y) = normal(loc(x,y), scale(x,y))

    ``alpha`` is produced by ``encoder_y_fst``/``encoder_y_snd`` and
    ``loc``/``scale`` by ``encoder_z_fst``/``encoder_z_out1``/``encoder_z_out2``.

    :param xs: a batch of inputs
    :param ys: (optional) a batch of class labels; sampled from q(y|x) when
        omitted
    :return: None
    """
    # register every encoder network with Pyro
    networks = (
        ("encoder_y_fst", encoder_y_fst),
        ("encoder_y_snd", encoder_y_snd),
        ("encoder_z_fst", encoder_z_fst),
        ("encoder_z_out1", encoder_z_out1),
        ("encoder_z_out2", encoder_z_out2),
    )
    for net_name, net in networks:
        pyro.module(net_name, net)

    # entries of the batch are conditionally independent
    with pyro.plate("data"):
        if ys is None:
            # unsupervised case: sample (and score) the label from q(y|x)
            hidden = softplus(encoder_y_fst(xs))
            alpha = softmax(encoder_y_snd(hidden))
            ys = pyro.sample("y", OneHotCategorical(alpha))

        # expand xs and ys to the leading dimensions of ys before concatenating
        shape = ys.shape[:-1] + (-1, )
        encoder_input = torch.cat([xs.expand(shape), ys.expand(shape)], -1)
        hidden_z = softplus(encoder_z_fst(encoder_input))
        loc = encoder_z_out1(hidden_z)
        # the network output is exponentiated to obtain a positive scale
        scale = torch.exp(encoder_z_out2(hidden_z))
        pyro.sample("z", Normal(loc, scale).to_event(1))
def model(self, input, output):
    '''
    Likelihood model: sample from prior, then decode to video.

    param input: video of size (batch_size, self.n_frames_input, C, H, W)
    param output: video of size (batch_size, self.n_frames_output, C, H, W)

    Helper tensors (crop mask, observation noise) are created on the same
    device as the decoder output instead of being forced onto CUDA, so the
    model also runs on CPU or a non-default GPU.
    '''
    # Register networks
    for name, net in self.model_modules.items():
        pyro.module(name, net)

    observation = output
    corrupted_observation = input

    # Sample from prior
    latent = self.sample_latent_prior(input)

    # Decode
    decoded_output, masked_decoded_output, _components = self.decode(latent)

    if self.use_crop_size and self.gamma_steps > self.gamma_switch_step and self.is_train:
        # zero out the leading columns along the last axis so only the
        # remaining region contributes to the prediction loss
        mask_out = torch.ones_like(decoded_output)
        mask_out[:, :, :, :(self.image_size - self.crop_size[1])] = 0
        decoded_output = decoded_output * mask_out
        if self.gamma_steps == self.gamma_switch_step + 1:
            print("Loss evaluated in visible frame.")

    # Observe - prediction
    decoded_output = decoded_output[:, self.n_frames_input:].view(
        *observation.size())
    sd = torch.full_like(decoded_output, 0.3)
    pyro.sample('obs', dist.Normal(decoded_output, sd), obs=observation)

    # Observe - corrupted reconstruction
    masked_decoded_output = masked_decoded_output.view(
        *corrupted_observation.size())
    sd = torch.full_like(masked_decoded_output, 0.3)
    pyro.sample('masked_obs', dist.Normal(masked_decoded_output, sd),
                obs=corrupted_observation)
def _parametrized_guide(self, predictor, data, labels, batch_size=None):
    """
    Guide with conjugate-style posteriors for the global topic variables and
    an amortized (neural) posterior for the per-document topic mixture.

    :param predictor: network mapping word counts concatenated with labels
        to per-document topic values
    :param data: word indices, indexed as ``data[:, doc]``
    :param labels: per-document labels, indexed as ``labels[doc]``
    :param batch_size: optional subsample size for the "documents" plate;
        ``None`` uses every document
    """
    args = self.args

    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
        "topic_weights_posterior",
        lambda: torch.ones(args.num_topics),
        constraint=constraints.positive)
    topic_words_posterior = pyro.param(
        "topic_words_posterior",
        lambda: torch.ones(args.num_topics, args.num_words),
        constraint=constraints.greater_than(0.5))
    # Is this needed?
    label_posterior = pyro.param(
        "label_posterior",
        lambda: torch.ones(2, args.num_topics),
        constraint=constraints.positive)
    with pyro.plate("topics", args.num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))
        pyro.sample("label_prior", dist.Beta(*label_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        data = data[:, ind]
        labels = labels[ind]
        # Size every buffer from the plate's index tensor rather than from
        # `batch_size`, which is None when no subsampling is requested.
        num_selected = ind.size(0)
        counts = (torch.zeros(args.num_words, num_selected).scatter_add(
            0, data, torch.ones(data.shape)))
        augmented_input = torch.zeros(num_selected,
                                      args.num_words + args.num_topics)
        augmented_input[:, :args.num_words] = counts.transpose(0, 1)
        augmented_input[:, args.num_words:] = labels
        doc_topics = predictor(augmented_input)
        pyro.sample("doc_topics", dist.Delta(doc_topics).to_event(1))
def model(self, xs, ys=None):
    """
    Generative process of the semi-supervised model:

        p(z)     = normal(0, I)
        p(y)     = categorical(uniform over ``self.output_size`` classes)
        p(x|y,z) = bernoulli(mu(y, z)), with mu given by ``self.decoder``

    :param xs: a batch of inputs
    :param ys: (optional) a batch of class labels
    :return: None
    """
    # register this module (and its sub-modules) with Pyro
    pyro.module("ss_vae", self)

    n = xs.size(0)
    with pyro.iarange("independent"):
        # latent code drawn from the constant unit-normal prior
        z_mu = Variable(torch.zeros([n, self.z_dim]))
        z_sigma = Variable(torch.ones([n, self.z_dim]))
        zs = pyro.sample("z", dist.normal, z_mu, z_sigma)

        # uniform prior over classes: the label is sampled when it is
        # missing and observed (scored against the prior) when it is given
        alpha_prior = Variable(torch.ones([n, self.output_size]) / (1.0 * self.output_size))
        if ys is not None:
            pyro.observe("y", dist.categorical, ys, alpha_prior)
        else:
            ys = pyro.sample("y", dist.categorical, alpha_prior)

        # score the input against p(x|y,z) = bernoulli(decoder(z, y))
        mu = self.decoder.forward([zs, ys])
        pyro.observe("x", dist.bernoulli, xs, mu)
def model(self, xs: torch.Tensor, ys: Optional[torch.Tensor] = None, **kwargs: float) -> None:
    """
    Model of the generative process p(x|z,y)p(y)p(z).

    :param xs: batch of inputs; ``xs.size(0)`` is the batch size
    :param ys: optional targets; observed at site "y" when given
    :param kwargs: ``scale_factor`` (default 1.) scales the "z" sample site
    """
    pyro.module("ss_vae", self)
    n = xs.size(0)
    tensor_kw = {"dtype": xs.dtype, "device": xs.device}
    kl_weight = kwargs.get("scale_factor", 1.)
    # pyro.plate enforces independence between variables in batches xs, ys
    with pyro.plate("data"):
        # latent vector from the constant unit-normal prior
        z_prior = dist.Normal(torch.zeros(n, self.z_dim, **tensor_kw),
                              torch.ones(n, self.z_dim, **tensor_kw))
        with pyro.poutine.scale(scale=kl_weight):
            zs = pyro.sample("z", z_prior.to_event(1))
        if self.coord > 0:
            # separate rotation / translation / scale from image content
            phi, dx, sc, zs = self.split_latent(zs)
            if 't' in self.invariances:
                dx = (dx * self.t_prior).unsqueeze(1)
            # transform coordinate grid
            grid = self.grid.expand(zs.shape[0], *self.grid.shape)
            x_coord_prime = transform_coordinates(grid, phi, dx, sc)
        # sample the target from a zero-mean normal prior, or observe it
        y_loc = torch.zeros(n, self.reg_dim, **tensor_kw)
        y_prior = dist.Normal(y_loc, self.reg_sig).to_event(1)
        ys = pyro.sample("y", y_prior, obs=ys)
        # decode (z, y), with the transformed grid in invariant mode
        if self.coord:
            loc = self.decoder(x_coord_prime, [zs, ys])
        else:
            loc = self.decoder([zs, ys])
        loc = loc.view(*ys.shape[:-1], -1)
        # score against the data
        pyro.sample("x", self.sampler_d(loc).to_event(1), obs=xs)
def model(self, x):
    """
    Structural model with exogenous noise U_c, U_y and priors on Age and Sex.

    Class is decoded from (U_c, Age, Sex) and Outcome from
    (U_y, Age, Sex, Class).

    :param x: data tensor with columns (Outcome, Class, Age, Sex); column 0
        is observed as "Outcome" and column 1 as "Class"
    """
    pyro.module("decoder_c", self.decoder_c)
    pyro.module("decoder_y", self.decoder_y)
    n = x.shape[0]
    column = torch.Size((n, 1))
    with pyro.plate("data", n):
        # unit-normal prior on U_c
        uc_shape = torch.Size((n, self.U_c_dim))
        U_c = pyro.sample(
            "U_c",
            dist.Normal(x.new_zeros(uc_shape), x.new_ones(uc_shape)).to_event(1))

        # unit-normal prior on U_y
        uy_shape = torch.Size((n, self.U_y_dim))
        U_y = pyro.sample(
            "U_y",
            dist.Normal(x.new_zeros(uy_shape), x.new_ones(uy_shape)).to_event(1))

        # prior on Age
        age_loc = 29.7 * x.new_ones(column)
        age_scale = 14.5 * x.new_ones(column)
        A = pyro.sample("Age", dist.Normal(age_loc, age_scale).to_event(1))

        # prior on Sex
        sex_prob = 0.6476 * x.new_ones(column)
        S = pyro.sample("Sex", dist.Bernoulli(sex_prob).to_event(1))

        # class probabilities from the exogenous noise and covariates
        C_probs = self.decoder_c(U_c, A, S)
        class_obs = to_one_hot(x[:, 1], self.num_classes)
        C = pyro.sample("Class",
                        dist.Multinomial(probs=C_probs).to_event(1),
                        obs=class_obs)
        C = one_hot_to_idx(C)

        # score against actual outcome
        Y = self.decoder_y(U_y, A, S, C)
        outcome_obs = x[:, 0].reshape(-1, 1)
        pyro.sample("Outcome", dist.Bernoulli(probs=Y).to_event(1),
                    obs=outcome_obs)
def guide(self, batch, subsample, full_size):
    """
    Guide: a MAP estimate for "drift" plus a low-rank multivariate normal over
    the grouped ``state_*`` sites, whose location is predicted from the batch
    by a linear layer.

    :param batch: 2-D tensor whose first dimension is the number of time steps
    :param subsample: subsample indices forwarded to the "data" plate
    :param full_size: full dataset size forwarded to the "data" plate
    """
    num_time_steps, _ = batch.shape
    self.map_estimate("drift")

    group = self.group(match="state_[0-9]*")
    event_shape = group.event_shape
    cov_diag = pyro.param(
        "state_cov_diag",
        lambda: torch.full(event_shape, 0.01),
        constraint=constraints.positive)
    cov_factor = pyro.param(
        "state_cov_factor",
        lambda: torch.randn(event_shape + (rank, )) * 0.01)

    # build the linear layer once and cache it on the guide instance
    if not hasattr(self, "nn"):
        width = event_shape.numel()
        self.nn = torch.nn.Linear(width, event_shape.numel())
        self.nn.weight.data.fill_(1.0 / num_time_steps)
        self.nn.bias.data.fill_(-0.5)
    pyro.module("state_nn", self.nn)

    with self.plate("data", full_size, subsample=subsample):
        loc = self.nn(batch.t())
        posterior = dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
        group.sample("states", posterior)
def main(observations={"x1": 0, "x2": 0}):
    """
    Run the five registered layers on the two observations and sample "z"
    from a normal whose mean and log-std are the two network outputs.

    :param observations: mapping with numeric entries "x1" and "x2"
    """
    layers = (("first", first), ("second", second), ("third", third),
              ("fourth", fourth), ("fifth", fifth))
    for layer_name, layer in layers:
        pyro.module(layer_name, layer)

    obs = torch.tensor([float(observations["x1"]), float(observations["x2"])])
    x1, x2 = obs[0], obs[1]
    # assemble a (1, 2) input row from the two scalars
    v = torch.cat((x1.view([1, 1]), x2.view([1, 1])), 1)

    # four ReLU hidden layers followed by a linear output layer
    hidden = v
    for layer in (first, second, third, fourth):
        hidden = relu(layer(hidden))
    out = fifth(hidden)

    mean = out[0, 0]
    # the second output is exponentiated to obtain a positive std
    std = torch.exp(out[0, 1])
    pyro.sample("z", Normal(mean, std))
def model(self, x):
    """
    Generative model p(x|z)p(z) with a unit-normal prior and a Bernoulli
    likelihood over 784-dimensional flattened inputs.

    :param x: batch of inputs; reshaped to ``(-1, 784)`` for scoring
    :return: the decoder output, so callers can visualise it
    """
    # make the decoder's parameters visible to Pyro
    pyro.module("decoder", self.decoder)
    n = x.shape[0]
    with pyro.plate("data", n):
        # unit-normal prior p(z) on the dtype/device of x
        tensor_kw = {"dtype": x.dtype, "device": x.device}
        z_loc = torch.zeros(n, self.z_dim, **tensor_kw)
        z_scale = torch.ones(n, self.z_dim, **tensor_kw)
        # the guide supplies the actual value when the ELBO is computed
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        # decode the latent code
        loc_img = self.decoder.forward(z)
        # score against the actual images
        flat_x = x.reshape(-1, 784)
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=flat_x)
        return loc_img
def guide(self, max_time_step):
    """
    Roll the policy out in the environment, sampling one Bernoulli action per
    step at sites ``action_<t>``, until the episode ends or ``MAXTIME`` steps
    have been taken.

    :param max_time_step: value returned (and printed when ``self.echo`` is
        set) unless the episode ends early with ``self.echo`` set
    :return: the step index ``t`` at which the episode ended when
        ``self.echo`` is set; otherwise ``max_time_step``
    """
    pyro.module("agentmodel", self)
    observation = reset_env()
    for t in range(MAXTIME):
        observation = torch.from_numpy(observation).float()
        prob_action = self.policy(observation)
        action = pyro.sample("action_{}".format(t), dist.Bernoulli(prob_action))
        # the environment is stepped with a plain Python int
        action = round(action.item())
        observation, _reward, done, _ = env.step(action)
        if done:
            # NOTE(review): the return value depends on the echo flag (t vs
            # max_time_step); kept as in the original - confirm intent.
            if self.echo:
                print("guide exit at", t)
                return t
            break
    if self.echo:
        print("guide solve the problem at t:", max_time_step)
    return max_time_step
def test_save_and_load(self):
    """
    Round-trip the param store through save/load and check that module and
    free parameters are restored, and that ``update_module_params`` controls
    whether a freshly registered module adopts the stored parameters.
    """
    # populate the store: two modules and one free parameter
    lin = pyro.module("mymodule", self.linear_module)
    pyro.module("mymodule2", self.linear_module2)
    x = torch.randn(1, 3)
    myparam = pyro.param("myparam", torch.tensor(1.234 * torch.ones(1), requires_grad=True))

    # take one optimizer step so the stored values differ from their
    # initial values
    cost = torch.sum(torch.pow(lin(x), 2.0)) * torch.pow(myparam, 4.0)
    cost.backward()
    params = list(self.linear_module.parameters()) + [myparam]
    optim = torch.optim.Adam(params, lr=.01)
    myparam_copy_stale = copy(pyro.param("myparam").detach().cpu().numpy())

    optim.step()

    # snapshot the post-step state for comparison after reloading
    myparam_copy = copy(pyro.param("myparam").detach().cpu().numpy())
    param_store_params = copy(pyro.get_param_store()._params)
    param_store_param_to_name = copy(pyro.get_param_store()._param_to_name)
    # presumably weight + bias for each of the two modules, plus myparam
    assert len(list(param_store_params.keys())) == 5
    assert len(list(param_store_param_to_name.values())) == 5

    # save, wipe, and reload the store
    pyro.get_param_store().save('paramstore.unittest.out')
    pyro.clear_param_store()
    assert len(list(pyro.get_param_store()._params)) == 0
    assert len(list(pyro.get_param_store()._param_to_name)) == 0
    pyro.get_param_store().load('paramstore.unittest.out')

    def modules_are_equal():
        # exact elementwise equality of weights and biases between
        # linear_module3 and linear_module
        weights_equal = np.sum(np.fabs(self.linear_module3.weight.detach().cpu().numpy() -
                               self.linear_module.weight.detach().cpu().numpy())) == 0.0
        bias_equal = np.sum(np.fabs(self.linear_module3.bias.detach().cpu().numpy() -
                            self.linear_module.bias.detach().cpu().numpy())) == 0.0
        return (weights_equal and bias_equal)

    assert not modules_are_equal()
    # with update_module_params=False the module keeps its own parameters
    pyro.module("mymodule", self.linear_module3, update_module_params=False)
    assert id(self.linear_module3.weight) != id(pyro.param('mymodule$$$weight'))
    assert not modules_are_equal()
    # with update_module_params=True the module's parameters are replaced
    # by the ones held in the store
    pyro.module("mymodule", self.linear_module3, update_module_params=True)
    assert id(self.linear_module3.weight) == id(pyro.param('mymodule$$$weight'))
    assert modules_are_equal()

    # the reloaded free parameter matches the post-step value, not the
    # pre-step one, and the store's bookkeeping is unchanged
    myparam = pyro.param("myparam")
    store = pyro.get_param_store()
    assert myparam_copy_stale != myparam.detach().cpu().numpy()
    assert myparam_copy == myparam.detach().cpu().numpy()
    assert sorted(param_store_params.keys()) == sorted(store._params.keys())
    assert sorted(param_store_param_to_name.values()) == sorted(store._param_to_name.values())
    assert sorted(store._params.keys()) == sorted(store._param_to_name.values())
def prior(self, n, **kwargs):
    """
    Register the decoder with Pyro and run the local model for ``n`` items
    without data.

    :param n: forwarded to ``self.local_model`` as its first argument
    :param kwargs: forwarded unchanged to ``self.local_model``
    :return: whatever ``self.local_model`` returns
    """
    pyro.module("decode", self.decode)
    result = self.local_model(n, **kwargs)
    return result
def guide(self, data):
    """
    Variational posterior q(z|x): the encoder yields a mean and a variance
    for the "latent" site.

    :param data: batch of inputs; ``data.size(0)`` sizes the "data" iarange
    """
    encoder = pyro.module('encoder', self.vae_encoder)
    num_items = data.size(0)
    with pyro.iarange('data', num_items):
        z_mean, z_var = encoder.forward(data)
        # Normal takes a standard deviation, hence the square root
        posterior = Normal(z_mean, z_var.sqrt()).independent(1)
        pyro.sample('latent', posterior)
def model(self, data, batch_size, **kwargs):
    """
    Register the decoder with Pyro and run the local model on the rows of
    ``data`` selected by the "data" iarange.

    :param data: full dataset; ``data.size(0)`` sizes the iarange
    :param batch_size: forwarded to ``self.local_model`` as its first argument
    :param kwargs: forwarded unchanged to ``self.local_model``
    :return: whatever ``self.local_model`` returns
    """
    pyro.module("decode", self.decode)
    num_items = data.size(0)
    with pyro.iarange('data', num_items, use_cuda=self.use_cuda) as ix:
        selected = data[ix]
        return self.local_model(batch_size, selected, **kwargs)
def guide(self, data, batch_size, **kwargs):
    """
    Register all guide networks and initial-state parameters with Pyro, then
    run the local guide on a subsampled mini-batch.

    Baseline networks and parameters are registered with ``tags='baseline'``.

    :param data: full dataset; ``data.size(0)`` sizes the "data" iarange
    :param batch_size: subsample size for the iarange
    :param kwargs: accepted but not used here
    :return: whatever ``self.local_guide`` returns
    """
    # inference networks
    main_nets = (('rnn', self.rnn), ('predict', self.predict),
                 ('encode', self.encode), ('embed', self.embed))
    for net_name, net in main_nets:
        pyro.module(net_name, net)
    # baseline networks
    baseline_nets = (('bl_rnn', self.bl_rnn), ('bl_predict', self.bl_predict),
                     ('bl_embed', self.bl_embed))
    for net_name, net in baseline_nets:
        pyro.module(net_name, net, tags='baseline')

    # learnable initial states
    main_inits = (('h_init', self.h_init), ('c_init', self.c_init),
                  ('z_where_init', self.z_where_init),
                  ('z_what_init', self.z_what_init))
    for param_name, value in main_inits:
        pyro.param(param_name, value)
    baseline_inits = (('bl_h_init', self.bl_h_init), ('bl_c_init', self.bl_c_init))
    for param_name, value in baseline_inits:
        pyro.param(param_name, value, tags='baseline')

    with pyro.iarange('data', data.size(0), subsample_size=batch_size,
                      use_cuda=self.use_cuda) as ix:
        batch = data[ix]
        return self.local_guide(batch.size(0), batch)
def cnn_fn(x):
    """Register ``cnn`` with Pyro under the name "CNN" and apply it to ``x``."""
    registered_cnn = pyro.module("CNN", cnn)
    return registered_cnn(x)
def guide():
    """
    Normal guide for "mu_latent" whose mean and precision are stored in
    log-space on ``pt_guide``.
    """
    pyro.module("mymodule", pt_guide)
    # parameters live in log-space; exponentiate to recover them
    mu_q = torch.exp(pt_guide.mu_q_log)
    tau_q = torch.exp(pt_guide.tau_q_log)
    # precision -> standard deviation
    sigma = torch.pow(tau_q, -0.5)
    posterior = dist.Normal(mu_q, sigma, reparameterized=reparameterized)
    pyro.sample("mu_latent", posterior)
def do_elbo_test(self, repa1, repa2, n_steps, prec, lr, use_nn_baseline, use_decaying_avg_baseline):
    """
    Fit a two-level normal-normal-normal model with SVI (trace-graph ELBO)
    and check that the guide parameters match the expected values.

    :param repa1: ``reparameterized`` flag for the "mu_latent_prime" site
    :param repa2: ``reparameterized`` flag for the "mu_latent" site
    :param n_steps: number of SVI steps to run
    :param prec: tolerance passed to ``assertEqual`` for the final errors
    :param lr: accepted but not used by the active code; the Adam learning
        rate below is hard-coded
    :param use_nn_baseline: if true, a small network is used as a baseline
        for "mu_latent_prime"
    :param use_decaying_avg_baseline: forwarded to the baseline dict of both
        guide sites
    """
    if self.verbose:
        print(" - - - - - DO NORMALNORMALNORMAL ELBO TEST - - - - - -")
        print("[reparameterized = %s, %s; nn_baseline = %s, decaying_baseline = %s]" %
              (repa1, repa2, use_nn_baseline, use_decaying_avg_baseline))
    pyro.clear_param_store()

    if use_nn_baseline:

        # one-hidden-layer network mapping dim_input features to a scalar
        class VanillaBaselineNN(nn.Module):
            def __init__(self, dim_input, dim_h):
                super(VanillaBaselineNN, self).__init__()
                self.lin1 = nn.Linear(dim_input, dim_h)
                self.lin2 = nn.Linear(dim_h, 1)
                self.sigmoid = nn.Sigmoid()

            def forward(self, x):
                h = self.sigmoid(self.lin1(x))
                return self.lin2(h)

        mu_prime_baseline = pyro.module("mu_prime_baseline", VanillaBaselineNN(2, 5), tags="baseline")
    else:
        mu_prime_baseline = None

    def model():
        # chain: mu_latent_prime -> mu_latent -> observations
        mu_latent_prime = pyro.sample(
            "mu_latent_prime",
            dist.Normal(self.mu0, torch.pow(self.lam0, -0.5), reparameterized=repa1))
        mu_latent = pyro.sample(
            "mu_latent",
            dist.Normal(mu_latent_prime, torch.pow(self.lam0, -0.5), reparameterized=repa2))
        for i, x in enumerate(self.data):
            pyro.observe("obs_%d" % i, dist.normal, x, mu_latent,
                         torch.pow(self.lam, -0.5))
        return mu_latent

    # note that the exact posterior is not mean field!
    def guide():
        # parameters are initialised away from the analytic values
        mu_q = pyro.param("mu_q", Variable(self.analytic_mu_n.data + 0.334 * torch.ones(2),
                                           requires_grad=True))
        log_sig_q = pyro.param("log_sig_q", Variable(
                               self.analytic_log_sig_n.data - 0.29 * torch.ones(2),
                               requires_grad=True))
        mu_q_prime = pyro.param("mu_q_prime", Variable(torch.Tensor([-0.34, 0.52]),
                                requires_grad=True))
        kappa_q = pyro.param("kappa_q", Variable(torch.Tensor([0.74]),
                             requires_grad=True))
        log_sig_q_prime = pyro.param("log_sig_q_prime",
                                     Variable(-0.5 * torch.log(1.2 * self.lam0.data),
                                              requires_grad=True))
        sig_q, sig_q_prime = torch.exp(log_sig_q), torch.exp(log_sig_q_prime)
        mu_latent_dist = dist.Normal(mu_q, sig_q, reparameterized=repa2)
        mu_latent = pyro.sample("mu_latent", mu_latent_dist,
                                baseline=dict(use_decaying_avg_baseline=use_decaying_avg_baseline))
        # the guide for mu_latent_prime depends linearly on the sampled
        # mu_latent through kappa_q
        mu_latent_prime_dist = dist.Normal(kappa_q.expand_as(mu_latent) * mu_latent + mu_q_prime,
                                           sig_q_prime,
                                           reparameterized=repa1)
        pyro.sample("mu_latent_prime",
                    mu_latent_prime_dist,
                    baseline=dict(nn_baseline=mu_prime_baseline,
                                  nn_baseline_input=mu_latent,
                                  use_decaying_avg_baseline=use_decaying_avg_baseline))

        return mu_latent

    # optim = Optimize(model, guide,
    #                 torch.optim.Adam, {"lr": lr, "betas": (0.97, 0.999)},
    #                 loss="ELBO", trace_graph=True,
    #                 auxiliary_optim_constructor=torch.optim.Adam,
    #                 auxiliary_optim_args={"lr": 5.0 * lr, "betas": (0.90, 0.999)})

    adam = optim.Adam({"lr": .0015, "betas": (0.97, 0.999)})
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)

    for k in range(n_steps):
        svi.step()

        # errors of the current parameters against their target values
        mu_error = param_mse("mu_q", self.analytic_mu_n)
        log_sig_error = param_mse("log_sig_q", self.analytic_log_sig_n)
        mu_prime_error = param_mse("mu_q_prime", 0.5 * self.mu0)
        kappa_error = param_mse("kappa_q", 0.5 * ng_ones(1))
        log_sig_prime_error = param_mse("log_sig_q_prime", -0.5 * torch.log(2.0 * self.lam0))

        if k % 500 == 0 and self.verbose:
            print("errors:  %.4f, %.4f" % (mu_error, log_sig_error), end='')
            print(", %.4f, %.4f" % (mu_prime_error, log_sig_prime_error), end='')
            print(", %.4f" % kappa_error)

    # only the errors from the final step are asserted
    self.assertEqual(0.0, mu_error, prec=prec)
    self.assertEqual(0.0, log_sig_error, prec=prec)
    self.assertEqual(0.0, mu_prime_error, prec=prec)
    self.assertEqual(0.0, log_sig_prime_error, prec=prec)
    self.assertEqual(0.0, kappa_error, prec=prec)
def guide(self, data):
    """
    Variational posterior q(z|x): the encoder yields a mean and a variance
    for the "latent" site.

    :param data: batch of inputs passed to the encoder
    """
    encoder = pyro.module('encoder', self.vae_encoder)
    z_mean, z_var = encoder.forward(data)
    # Normal takes a standard deviation, hence the square root
    z_std = z_var.sqrt()
    pyro.sample('latent', Normal(z_mean, z_std))
def guide(self, data, batch_size, **kwargs):
    """
    Register all guide networks and initial-state parameters with Pyro, then
    unroll ``self.guide_step`` for ``self.num_steps`` steps on a subsampled
    mini-batch.

    :param data: full dataset; ``data.size(0)`` sizes the "data" iarange
    :param batch_size: subsample size for the iarange
    :param kwargs: accepted but not used here
    :return: ``(z_where, z_pres)``, two lists with one entry per step
    """
    networks = (('rnn', self.rnn), ('predict', self.predict),
                ('encode', self.encode), ('embed', self.embed),
                ('bl_rnn', self.bl_rnn), ('bl_predict', self.bl_predict),
                ('bl_embed', self.bl_embed))
    for net_name, net in networks:
        pyro.module(net_name, net)

    initial_params = (('h_init', self.h_init), ('c_init', self.c_init),
                      ('z_where_init', self.z_where_init),
                      ('z_what_init', self.z_what_init),
                      ('bl_h_init', self.bl_h_init),
                      ('bl_c_init', self.bl_c_init))
    for param_name, value in initial_params:
        pyro.param(param_name, value)

    with pyro.iarange('data', data.size(0), subsample_size=batch_size,
                      use_cuda=self.use_cuda) as ix:
        batch = data[ix]
        n = batch.size(0)

        # Embed inputs.
        flat = batch.view(n, -1)
        inputs = dict(raw=batch,
                      embed=self.embed(flat),
                      bl_embed=self.bl_embed(flat))

        # Initial state: learned initial values expanded over the batch.
        state = GuideState(
            h=batch_expand(self.h_init, n),
            c=batch_expand(self.c_init, n),
            bl_h=batch_expand(self.bl_h_init, n),
            bl_c=batch_expand(self.bl_c_init, n),
            z_pres=self.prototype.new_ones(n, self.z_pres_size),
            z_where=batch_expand(self.z_where_init, n),
            z_what=batch_expand(self.z_what_init, n))

        where_steps, pres_steps = [], []
        for step in range(self.num_steps):
            state = self.guide_step(step, n, state, inputs)
            where_steps.append(state.z_where)
            pres_steps.append(state.z_pres)

        return where_steps, pres_steps
def model(self, data):
    """
    Generative model: a 20-dimensional unit-normal latent is decoded and
    scored with a Bernoulli likelihood against the flattened input.

    :param data: batch of inputs; viewed as ``(-1, 784)`` for scoring
    """
    decoder = pyro.module('decoder', self.vae_decoder)
    prior_shape = [data.size(0), 20]
    z_mean = ng_zeros(prior_shape)
    z_std = ng_ones([data.size(0), 20])
    z = pyro.sample('latent', Normal(z_mean, z_std))
    img = decoder.forward(z)
    flat_data = data.view(-1, 784)
    pyro.observe('obs', Bernoulli(img), flat_data)