def model():
    p2 = torch.ones(2) / 2
    p3 = torch.ones(3) / 3
    x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
    x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
    assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
    assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape
def model(self, xs, ys=None, aux_scale=46):
    pyro.module("ss_vae", self)
    batch_size = xs.size(0)
    options = dict(out=None, dtype=xs.dtype, layout=torch.strided,
                   device=xs.device, requires_grad=False)
    with pyro.plate("data"):
        prior_loc = torch.zeros(batch_size, 50, **options)
        prior_scale = torch.ones(batch_size, 50, **options)
        zs2 = pyro.sample("z2", dist.Normal(prior_loc, prior_scale).to_event(1))

        alpha_prior = torch.ones(batch_size, 4, **options) / 4.0
        ys_ = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

        z1_mean, z1_std = self.decoder_z1(zs2, ys_)
        zs1 = pyro.sample("z1", dist.Normal(z1_mean, z1_std).to_event(1))

        x_mean, x_std = self.decoder_x(zs1)
        pyro.sample("x", dist.Normal(x_mean, x_std).to_event(3), obs=xs)

        if ys is not None:
            alpha = self.encoder_y(zs1)
            with pyro.poutine.scale(scale=aux_scale):
                pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
def test_batch_log_dims(dim, probs):
    batch_pdf_shape = (3,) + (1,) * (dim - 1)
    expected_log_prob_sum = np.array(
        wrap_nested(list(np.log(probs)), dim - 1)
    ).reshape(*batch_pdf_shape)
    probs = modify_params_using_dims(probs, dim)
    support = dist.OneHotCategorical(probs).enumerate_support()
    log_prob = dist.OneHotCategorical(probs).log_prob(support)
    assert_equal(log_prob.detach().cpu().numpy(), expected_log_prob_sum)
def generate_name(self, lstm: Decoder, address: str, batch_size: int,
                  hidd_cell_states: tuple = None, sample: bool = True):
    """
    lstm: Decoder associated with the name being generated
    address: address used to name the Pyro sample sites for the latent variables
    batch_size: number of names to generate in parallel
    hidd_cell_states: previous LSTM hidden/cell state, or None for a fresh state
    sample: if True, feed the sampled character back into the LSTM;
        otherwise feed back the most probable character
    """
    # If no hidden state is provided, initialize it with all 0s
    if hidd_cell_states is None:
        hidd_cell_states = lstm.init_hidden(batch_size=batch_size)

    input_tensor = strings_to_tensor([SOS] * batch_size, 1, letter_to_index)
    names = [''] * batch_size

    for index in range(MAX_NAME_LENGTH):
        char_dist, hidd_cell_states = lstm.forward(input_tensor, hidd_cell_states)
        if sample:
            # Next LSTM input is the sampled character
            input_tensor = pyro.sample(f"unsup_{address}_{index}",
                                       dist.OneHotCategorical(char_dist))
            chars_at_indexes = list(map(lambda idx: MODEL_CHARS[int(idx.item())],
                                        torch.argmax(input_tensor, dim=2).squeeze(0)))
        else:
            # Next LSTM input is the character with the highest probability of occurring
            pyro.sample(f"unsup_{address}_{index}", dist.OneHotCategorical(char_dist))
            chars_at_indexes = list(map(lambda idx: MODEL_CHARS[int(idx.item())],
                                        torch.argmax(char_dist, dim=2).squeeze(0)))
            input_tensor = strings_to_tensor(chars_at_indexes, 1, letter_to_index)
        # Add sampled characters to names
        for i, char in enumerate(chars_at_indexes):
            names[i] += char

    # Discard everything after EOS character
    # names = list(map(lambda name: name[:name.find(EOS)] if name.find(EOS) > -1 else name, names))
    return hidd_cell_states, names
def model_classify(self, xs, lengths, ys=None):
    """
    This model is used to add an auxiliary (supervised) loss as described in
    Kingma et al., "Semi-Supervised Learning with Deep Generative Models".
    """
    xs = pad_sequence(xs, batch_first=True, padding_value=self.padding_idx)
    if ys is not None:
        ys = torch.stack(ys)
    if self.use_cuda:
        xs = xs.cuda()
        lengths = lengths.cuda()
        if ys is not None:
            ys = ys.cuda()
    # get embeddings
    xs = self.embeddings(xs)
    # register all pytorch (sub)modules with pyro
    pyro.module("text_ss_vae", self)
    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):
        # the extra term to yield an auxiliary loss that we do gradient descent on
        if ys is not None:
            alpha = self.encoder_y.forward(xs, lengths)
            with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
def one_hot_model(pseudocounts, classes=None):
    probs_prior = dist.Dirichlet(pseudocounts)
    probs = pyro.sample("probs", probs_prior)
    with pyro.plate("classes", classes.size(0) if classes is not None else 1, dim=-1):
        return pyro.sample("obs", dist.OneHotCategorical(probs), obs=classes)
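A minimal smoke test for `one_hot_model` above, assuming 3 classes; the synthetic label tensor and its probabilities are illustrative, not from the original source. Drawing one-hot labels from a fixed OneHotCategorical and running the model forward checks that the plate and observation shapes line up:

import torch
import pyro
import pyro.distributions as dist

pseudocounts = torch.ones(3)                                  # flat Dirichlet prior
fake_probs = torch.tensor([0.2, 0.5, 0.3])                    # illustrative class probabilities
classes = dist.OneHotCategorical(fake_probs).sample((100,))   # 100 one-hot labels, shape (100, 3)
obs = one_hot_model(pseudocounts, classes=classes)            # obs sites return the observed value
assert obs.shape == (100, 3)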
def model(self, data, label=None):
    pyro.module("decoder", self)
    batch_size = data.size(0)
    with pyro.plate("data"):
        deformation_loc = torch.zeros([batch_size, 1])
        deformation_scale = torch.ones([batch_size, 1])
        label_prior = data.new_ones([batch_size, self.num_classes]) / (1.0 * self.num_classes)
        deformation = pyro.sample(
            "deformation",
            dist.Normal(deformation_loc, deformation_scale).to_event(1))
        label = pyro.sample("label", dist.OneHotCategorical(label_prior), obs=label)
        final_image = self.decoder(deformation, label, data)
        pyro.sample("image", dist.Bernoulli(final_image).to_event(1), obs=data)
def test_one_hot_categorical_shape():
    ps = ng_ones(3, 2) / 2
    d = dist.OneHotCategorical(ps)
    assert d.batch_shape() == (3,)
    assert d.event_shape() == (2,)
    assert d.shape() == (3, 2)
    assert d.sample().size() == d.shape()
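The test above targets an old Pyro API in which shapes were methods and `ng_ones` was a test helper. For reference, a sketch of the same check against recent Pyro/PyTorch versions, where `batch_shape` and `event_shape` are properties; this is an assumed modern equivalent, not part of the original test:

import torch
import pyro.distributions as dist

ps = torch.ones(3, 2) / 2
d = dist.OneHotCategorical(ps)
assert d.batch_shape == torch.Size([3])   # one batch dim from the rows of ps
assert d.event_shape == torch.Size([2])   # the one-hot vector is a single event
assert d.sample().shape == torch.Size([3, 2])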
def model(self, x, y=None):
    # Register various nn.Modules with Pyro
    pyro.module("scanvi", self)

    # This gene-level parameter modulates the variance of the observation distribution
    theta = pyro.param("inverse_dispersion", 10.0 * x.new_ones(self.num_genes),
                       constraint=constraints.positive)

    # We scale all sample statements by scale_factor so that the ELBO is normalized
    # wrt the number of datapoints and genes
    with pyro.plate("batch", len(x)), poutine.scale(scale=self.scale_factor):
        z1 = pyro.sample("z1", dist.Normal(0, x.new_ones(self.latent_dim)).to_event(1))
        # Note that if y is None (i.e. y is unobserved) then y will be sampled;
        # otherwise y will be treated as observed.
        y = pyro.sample("y", dist.OneHotCategorical(logits=x.new_zeros(self.num_labels)),
                        obs=y)

        z2_loc, z2_scale = self.z2_decoder(z1, y)
        z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

        l_scale = self.l_scale * x.new_ones(1)
        l = pyro.sample("l", dist.LogNormal(self.l_loc, l_scale).to_event(1))

        # Note that by construction mu is normalized (i.e. mu.sum(-1) == 1) and the
        # total scale of counts for each cell is determined by `l`
        gate_logits, mu = self.x_decoder(z2)
        # TODO revisit this parameterization if torch.distributions.NegativeBinomial changes
        # from failure to success parametrization;
        # see https://github.com/pytorch/pytorch/issues/42449
        nb_logits = (l * mu + self.epsilon).log() - (theta + self.epsilon).log()
        x_dist = dist.ZeroInflatedNegativeBinomial(gate_logits=gate_logits,
                                                   total_count=theta,
                                                   logits=nb_logits)
        # Observe the datapoint x using the observation distribution x_dist
        pyro.sample("x", x_dist.to_event(1), obs=x)
def guide(self, x, y=None):
    pyro.module("scanvi", self)
    with pyro.plate("batch", len(x)), poutine.scale(scale=self.scale_factor):
        z2_loc, z2_scale, l_loc, l_scale = self.z2l_encoder(x)
        pyro.sample("l", dist.LogNormal(l_loc, l_scale).to_event(1))
        z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

        y_logits = self.classifier(z2)
        y_dist = dist.OneHotCategorical(logits=y_logits)
        if y is None:
            # x is unlabeled so sample y using q(y|z2)
            y = pyro.sample("y", y_dist)
        else:
            # x is labeled so add a classification loss term
            # (this way q(y|z2) learns from both labeled and unlabeled data)
            classification_loss = y_dist.log_prob(y)
            # Note that the negative sign appears because we're adding this term in the guide
            # and the guide log_prob appears in the ELBO as -log q
            pyro.factor(
                "classification_loss",
                -self.alpha * classification_loss,
                has_rsample=False,
            )

        z1_loc, z1_scale = self.z1_encoder(z2, y)
        pyro.sample("z1", dist.Normal(z1_loc, z1_scale).to_event(1))
def model(self, xs, ys):
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("ss_vae", self)
    batch_size = xs.size(0)

    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):
        # sample the handwriting style from the constant prior distribution
        prior_loc = xs.new_zeros([batch_size, self.z_dim])
        prior_scale = xs.new_ones([batch_size, self.z_dim])
        zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))

        # if the label y (which digit to write) is unsupervised, sample from the
        # constant prior; otherwise, observe the value (i.e. score it against the constant prior)
        alpha_prior = xs.new_ones([batch_size, self.output_size]) / (1.0 * self.output_size)
        ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

        # finally, score the image (x) using the handwriting style (z) and
        # the class label y (which digit to write) against the
        # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
        # where `decoder` is a neural network
        loc = self.decoder.forward([zs, ys])
        pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)
def guide(self, xs, ys=None):
    """
    The guide corresponds to the following:
    q(y|x) = categorical(alpha(x))           # infer digit from an image
    q(z|x,y) = normal(loc(x,y), scale(x,y))  # infer handwriting style from an image and the digit

    loc, scale are given by a neural network `encoder_z`
    alpha is given by a neural network `encoder_y`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels i.e.
               the digit corresponding to the image(s)
    :return: None
    """
    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):
        # if the class label (the digit) is not supervised, sample
        # (and score) the digit with the variational distribution
        # q(y|x) = categorical(alpha(x))
        if ys is None:
            alpha = self.encoder_y.forward(xs)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha))

        # sample (and score) the latent handwriting-style with the variational
        # distribution q(z|x,y) = normal(loc(x,y), scale(x,y))
        loc, scale = self.encoder_z.forward([xs, ys])
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))
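A minimal sketch of how this model/guide pair is typically wired into SVI, assuming `ss_vae` is an instance of the surrounding class and `xs`, `ys` form a mini-batch (pass `None` for unlabeled data); the optimizer settings are illustrative:

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

svi = SVI(ss_vae.model, ss_vae.guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
supervised_loss = svi.step(xs, ys)      # labels observed: y is scored against the prior
unsupervised_loss = svi.step(xs, None)  # labels missing: y is sampled from q(y|x)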
def model_classify(self, xs, ys=None): pyro.module("ss_vae", self) with pyro.plate("data"): if ys is not None: z1, _ = self.encoder_z1(xs) alpha = self.encoder_y(z1) pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
def gmm_batch_guide(data):
    with pyro.iarange("data", len(data)) as batch:
        n = len(batch)
        probs = pyro.param("probs", torch.full((n, 1), 0.6, requires_grad=True))
        probs = torch.cat([probs, 1 - probs], dim=1)
        z = pyro.sample("z", dist.OneHotCategorical(probs))
        assert z.shape[-2:] == (n, 2)
def model(self, xs, ys=None):
    # ys may be passed in as labels, or omitted entirely for unlabelled data
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    with pyro.iarange("data", xs.shape[0]):
        # set up hyperparameters for prior p(z)
        z_loc = xs.new_zeros(torch.Size((xs.shape[0], self.z_dim)))
        z_scale = xs.new_ones(torch.Size((xs.shape[0], self.z_dim)))
        # sample from prior (value will be sampled by guide when computing the ELBO)
        zs = pyro.sample("z", dist.Normal(z_loc, z_scale).independent(1))

        # if there is a label y, observe it (i.e. score it against the prior);
        # otherwise, sample it from the prior
        # prior is the actual train-data label distribution (not uniform)
        alpha_prior = self.prior_probs.repeat(xs.shape[0], 1)
        ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

        # decoder outputs mean and square-root covariance; sample from a normal
        recon_loc, recon_scale = self.decoder.forward([zs, ys])
        pyro.sample("x", dist.Normal(recon_loc, recon_scale).independent(1),
                    obs=xs.reshape(-1, 1335))
def model(self, X_u: list, X_s: list, Z_s: dict, observations=None):
    """
    Model for generating names, representing p(x, z)
    x: training data (name string)
    z: optionally supervised latent values (dictionary of name/format values)
    """
    pyro.module("model_fn_lstm", self.model_fn_lstm)
    formatted_X_u = strings_to_tensor(X_u, MAX_NAME_LENGTH, printable_to_index)
    formatted_X_s = strings_to_tensor(X_s, MAX_NAME_LENGTH, printable_to_index)

    with pyro.plate("sup_batch", len(X_s)):
        _, first_names = self.generate_name_supervised(
            self.model_fn_lstm, FIRST_NAME_ADD, len(X_s),
            observed=Z_s[FIRST_NAME_ADD])
        full_names = list(map(lambda name: pad_string(name, MAX_NAME_LENGTH), first_names))
        probs = strings_to_probs(full_names, MAX_NAME_LENGTH, printable_to_index,
                                 true_index_prob=self.peak_prob)
        pyro.sample("sup_output",
                    dist.OneHotCategorical(probs.transpose(0, 1)).to_event(1),
                    obs=formatted_X_s.transpose(0, 1))

    with pyro.plate("unsup_batch", len(X_u)):
        _, first_names = self.generate_name(self.model_fn_lstm, FIRST_NAME_ADD, len(X_u))
        full_names = list(map(lambda name: pad_string(name, MAX_NAME_LENGTH), first_names))
        probs = strings_to_probs(full_names, MAX_NAME_LENGTH, printable_to_index,
                                 true_index_prob=self.peak_prob)
        pyro.sample("unsup_output",
                    dist.OneHotCategorical(probs.transpose(0, 1)).to_event(1),
                    obs=formatted_X_u.transpose(0, 1))

    return full_names
def remap_y(self, ys):
    new_ys = []
    options = dict(dtype=ys.dtype, device=ys.device)
    for i, label_length in enumerate(self.label_shape):
        prior = torch.ones(ys.size(0), label_length, **options) / (1.0 * label_length)
        new_ys.append(pyro.sample(
            "y_%s" % self.label_names[i],
            dist.OneHotCategorical(prior),
            obs=torch.nn.functional.one_hot(ys[:, i].to(torch.int64), int(label_length))))
    new_ys = torch.cat(new_ys, -1)
    return new_ys.to(torch.float32)
def model(self, images, labels=None, kl_factor=1.0):
    images = images.view(-1, 784)
    n_images = images.size(0)
    # Set up parameters for the distribution of weights for each layer `a<n>`
    # (the two hidden-to-hidden layers map n_hidden+1 -> n_hidden; only the
    # output layer maps to n_classes)
    a1_mean = torch.zeros(784, self.n_hidden)
    a1_scale = torch.ones(784, self.n_hidden)
    a1_dropout = torch.tensor(0.25)
    a2_mean = torch.zeros(self.n_hidden + 1, self.n_hidden)
    a2_scale = torch.ones(self.n_hidden + 1, self.n_hidden)
    a2_dropout = torch.tensor(1.0)
    a3_mean = torch.zeros(self.n_hidden + 1, self.n_hidden)
    a3_scale = torch.ones(self.n_hidden + 1, self.n_hidden)
    a3_dropout = torch.tensor(1.0)
    a4_mean = torch.zeros(self.n_hidden + 1, self.n_classes)
    a4_scale = torch.ones(self.n_hidden + 1, self.n_classes)
    # Mark batched calculations to be conditionally independent given parameters using `plate`
    with pyro.plate('data', size=n_images):
        # Sample first hidden layer
        h1 = pyro.sample('h1',
                         bnn.HiddenLayer(images, a1_mean, a1_dropout * a1_scale,
                                         non_linearity=nnf.leaky_relu,
                                         KL_factor=kl_factor))
        # Sample second hidden layer
        h2 = pyro.sample('h2',
                         bnn.HiddenLayer(h1, a2_mean, a2_dropout * a2_scale,
                                         non_linearity=nnf.leaky_relu,
                                         KL_factor=kl_factor))
        # Sample third hidden layer
        h3 = pyro.sample('h3',
                         bnn.HiddenLayer(h2, a3_mean, a3_dropout * a3_scale,
                                         non_linearity=nnf.leaky_relu,
                                         KL_factor=kl_factor))
        # Sample output logits
        logits = pyro.sample('logits',
                             bnn.HiddenLayer(h3, a4_mean, a4_scale,
                                             non_linearity=lambda x: nnf.log_softmax(x, dim=-1),
                                             KL_factor=kl_factor,
                                             include_hidden_bias=False))
        # One-hot encode labels
        labels = nnf.one_hot(labels) if labels is not None else None
        # Condition on observed labels, so the log-likelihood loss is computed when training with VI
        return pyro.sample('label', dist.OneHotCategorical(logits=logits), obs=labels)
def model_classify(self, xs, ys=None):
    # register PyTorch module `encoder_y` with Pyro
    pyro.module("encoder_y", self.encoder_y)
    with pyro.iarange("data", xs.shape[0]):
        # the extra term to yield an auxiliary loss that we do gradient descent on
        if ys is not None:
            alpha = self.encoder_y.forward(xs)
            with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
def gmm_batch_model(data):
    p = pyro.param("p", torch.tensor([0.3], requires_grad=True))
    p = torch.cat([p, 1 - p])
    scale = pyro.param("scale", torch.tensor([1.0], requires_grad=True))
    mus = torch.tensor([-1.0, 1.0])
    with pyro.iarange("data", len(data)) as batch:
        n = len(batch)
        z = pyro.sample("z", dist.OneHotCategorical(p).expand_by([n]))
        assert z.shape[-2:] == (n, 2)
        loc = (z * mus).sum(-1)
        pyro.sample("x", dist.Normal(loc, scale.expand(n)), obs=data[batch])
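A minimal training sketch for this model together with `gmm_batch_guide` (shown earlier in this section), assuming the older Pyro version that still provides `pyro.iarange`; the toy data tensor is illustrative:

import torch
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.tensor([-1.1, -0.8, 0.9, 1.2])  # toy 1-D observations
svi = SVI(gmm_batch_model, gmm_batch_guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(100):
    svi.step(data)  # the discrete z uses score-function gradients under Trace_ELBO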
def test_basevae_encode_xy():
    data_dim = (2, 64)
    x = torch.randn(*data_dim)
    alpha = torch.ones(data_dim[0], 3) / 3
    y = dist.OneHotCategorical(alpha).sample()
    vae = models.base.baseVAE(data_dim[1:], None)
    encoder_net = nets.fcEncoderNet(data_dim[1:], 2, 3)
    vae.set_encoder(encoder_net)
    encoded = vae._encode(x, y)
    assert_equal(encoded[:, :2].shape, (data_dim[0], 2))
    assert_equal(encoded[:, 2:].shape, (data_dim[0], 2))
def test_posterior_predictive_svi_one_hot():
    pseudocounts = torch.ones(3) * 0.1
    true_probs = torch.tensor([0.15, 0.6, 0.25])
    classes = dist.OneHotCategorical(true_probs).sample((10000,))
    guide = AutoDelta(one_hot_model)
    svi = SVI(one_hot_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
    for i in range(1000):
        svi.step(pseudocounts, classes=classes)
    posterior_samples = Predictive(guide, num_samples=10000).get_samples(pseudocounts)
    posterior_predictive = Predictive(one_hot_model, posterior_samples)
    marginal_return_vals = posterior_predictive.get_samples(pseudocounts)["obs"]
    assert_close(marginal_return_vals.mean(dim=0), true_probs.unsqueeze(0), rtol=0.1)
def guide_ncls(self, xs, ys=None):
    with pyro.plate("data"):
        z1_mean, z1_std = self.encoder_z1(xs)
        zs1 = pyro.sample("z1", dist.Normal(z1_mean, z1_std).to_event(1))
        if ys is None:
            alpha = self.encoder_y(zs1)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha))
        z2_mean, z2_std = self.encoder_z2(zs1, ys)
        zs2 = pyro.sample("z2", dist.Normal(z2_mean, z2_std).to_event(1))
def guide(self, data, label=None):
    pyro.module("encoder", self)
    with pyro.plate("data"):
        label_prior, deformation_loc, deformation_scale = self.encoder(data)
        pyro.sample("label", dist.OneHotCategorical(label_prior))
        # to_event(1) matches the event shape of the "deformation" site in the model
        pyro.sample("deformation",
                    dist.Normal(deformation_loc, deformation_scale).to_event(1))
def test_support_dims(dim, probs):
    probs = modify_params_using_dims(probs, dim)
    d = dist.OneHotCategorical(probs)
    support = d.enumerate_support()
    for s in support:
        assert_correct_dimensions(s, probs)
    n = len(support)
    assert support.shape == (n,) + d.batch_shape + d.event_shape
    support_expanded = d.enumerate_support(expand=True)
    assert support_expanded.shape == (n,) + d.batch_shape + d.event_shape
    support_unexpanded = d.enumerate_support(expand=False)
    assert support_unexpanded.shape == (n,) + (1,) * len(d.batch_shape) + d.event_shape
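A concrete instance of the shape contract tested above, assuming a (2, 3) probs tensor; with 3 categories, the enumerated support stacks the one-hot vectors along a new leading dimension:

import torch
import pyro.distributions as dist

d = dist.OneHotCategorical(torch.ones(2, 3) / 3)  # batch_shape (2,), event_shape (3,)
support = d.enumerate_support()
assert support.shape == (3,) + d.batch_shape + d.event_shape   # (3, 2, 3)
assert d.enumerate_support(expand=False).shape == (3, 1, 3)    # batch dims collapsed to 1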
def model(self, xs, ys=None):
    """
    The model corresponds to the following generative process:
    p(z) = normal(0, I)            # handwriting style (latent)
    p(y|x) = categorical(I/10.)    # which digit (semi-supervised)
    p(x|y,z) = bernoulli(mu(y,z))  # an image

    mu is given by a neural network `decoder`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels i.e.
               the digit corresponding to the image(s)
    :return: None
    """
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("ss_vae", self)

    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    batch_size = xs.size(0)
    with pyro.iarange("independent"):
        # sample the handwriting style from the constant prior distribution
        prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
        prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
        zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=1))

        # if the label y (which digit to write) is unsupervised, sample from the
        # constant prior; otherwise, observe the value (i.e. score it against the constant prior)
        alpha_prior = Variable(torch.ones([batch_size, self.y_dim]) / (1.0 * self.y_dim))
        if ys is None:
            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior))
        else:
            pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

        # finally, score the image (x) using the handwriting style (z) and
        # the class label y (which digit to write) against the
        # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
        # where `decoder` is a neural network
        mu = self.decoder.forward(zs, ys)
        pyro.sample("x", dist.Bernoulli(mu).reshape(extra_event_dims=1), obs=xs)
def model_classify(self, xs, ys=None):
    """
    This model is used to add an auxiliary (supervised) loss as described in
    Kingma et al., "Semi-Supervised Learning with Deep Generative Models".
    """
    # register all pytorch (sub)modules with pyro
    pyro.module("ss_vae", self)
    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):
        # the extra term to yield an auxiliary loss that we do gradient descent on
        if ys is not None:
            alpha = self.encoder_y.forward(xs)
            with pyro.poutine.scale(scale=self.aux_loss_multiplier):
                pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
def model_aux(self, xs: torch.Tensor, ys: Optional[torch.Tensor] = None,
              **kwargs: float) -> None:
    """
    Models an auxiliary (supervised) loss.
    """
    pyro.module("ss_vae", self)
    with pyro.plate("data"):
        # the extra term to yield an auxiliary loss
        aux_loss_multiplier = kwargs.get("aux_loss_multiplier", 20)
        if ys is not None:
            alpha = self.encoder_y.forward(xs)
            with pyro.poutine.scale(scale=aux_loss_multiplier):
                pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
def test_auxsvi_trainer_cls(invariances):
    data_dim = (5, 8, 8)
    train_unsup = torch.randn(data_dim[0], torch.prod(tt(data_dim[1:])).item())
    train_sup = train_unsup + .1 * torch.randn_like(train_unsup)
    labels = dist.OneHotCategorical(torch.ones(data_dim[0], 3)).sample()
    loader_unsup, loader_sup, loader_val = utils.init_ssvae_dataloaders(
        train_unsup, (train_sup, labels), (train_sup, labels), batch_size=2)
    vae = models.ssiVAE(data_dim[1:], 2, 3, invariances)
    trainer = trainers.auxSVItrainer(vae)
    weights_before = dc(vae.state_dict())
    for _ in range(2):
        trainer.step(loader_unsup, loader_sup, loader_val)
    weights_after = vae.state_dict()
    assert_(not torch.isnan(tt(trainer.history["training_loss"])).any())
    assert_(not assert_weights_equal(weights_before, weights_after))
def model(self, batch, reversed_batch, batch_mask, batch_seqlens, kl_anneal=1.0):
    """
    the model defines p(x_{1:T}|z_{1:T}) and p(z_{1:T})
    """
    # maximum duration of batch
    Tmax = batch.size(1)

    # register torch submodules w/ pyro
    pyro.module("dmm", self)

    # set up recursive conditioning for p(z_t|z_{t-1})
    z_prev = self.z_0.expand(batch.size(0), self.z_0.size(0))

    # sample conditionally independent text across the batch
    with pyro.plate("z_batch", len(batch)):
        # sample latent vars z and observed x w/ multiple samples from the guide for each z
        for t in pyro.markov(range(1, Tmax + 1)):
            # compute params of diagonal gaussian p(z_t|z_{t-1})
            z_loc, z_scale = self.transition(z_prev)
            # sample latent variable
            with poutine.scale(scale=kl_anneal):
                z_t = pyro.sample(
                    "z_%d" % t,
                    dist.Normal(z_loc, z_scale).mask(batch_mask[:, t - 1:t]).to_event(1),
                )
            # compute emission probability from latent variable
            emission_prob = self.emitter(z_t)
            # observe x_t according to the OneHotCategorical distribution defined by the emission probability
            pyro.sample(
                "obs_x_%d" % t,
                dist.OneHotCategorical(emission_prob).mask(batch_mask[:, t - 1:t]).to_event(1),
                obs=batch[:, t - 1, :],
            )
            # set conditional var for next time step
            z_prev = z_t