def sample(self, tgt_sents: torch.Tensor, tgt_masks: torch.Tensor, src_enc: torch.Tensor, src_masks: torch.Tensor, nsamples: int = 1, random=True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw latent samples from the posterior and score them.

    ``self.core`` maps the target/source inputs to ``(mu, logvar)``.
    ``Posterior.reparameterize`` turns those into latent samples ``z``
    plus the noise ``eps`` used to produce them. ``Posterior.log_probability``
    then evaluates the log-density of each sample.

    Args:
        tgt_sents: target-side input passed straight to ``self.core``.
        tgt_masks: target mask; forwarded to the core, the reparameterization
            and the log-probability computation.
        src_enc: source encoding passed to ``self.core``.
        src_masks: source mask passed to ``self.core``.
        nsamples: number of samples requested from ``reparameterize``.
        random: forwarded to ``reparameterize``. Presumably it toggles
            stochastic sampling vs. returning the mean; confirm in
            ``Posterior.reparameterize``.

    Returns:
        ``(z, log_probs)``, the samples and their log-probabilities, exactly
        as produced by the two ``Posterior`` helpers.
    """
    mean, log_variance = self.core(tgt_sents, tgt_masks, src_enc, src_masks)
    latent, noise = Posterior.reparameterize(mean, log_variance, tgt_masks, nsamples=nsamples, random=random)
    return latent, Posterior.log_probability(latent, noise, mean, log_variance, tgt_masks)
def init(self, tgt_sents, tgt_masks, src_enc, src_masks, init_scale=1.0, init_mu=True, init_var=True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the core's initialization pass, then draw and score one sample.

    This mirrors ``sample`` with two differences. It calls
    ``self.core.init`` (forwarding the ``init_*`` options) instead of
    ``self.core``. It always samples with ``random=True`` and the
    default ``nsamples``.

    Args:
        tgt_sents: target-side input passed to ``self.core.init``.
        tgt_masks: target mask; forwarded to the core, the reparameterization
            and the log-probability computation.
        src_enc: source encoding passed to ``self.core.init``.
        src_masks: source mask passed to ``self.core.init``.
        init_scale: forwarded unchanged to ``self.core.init``.
        init_mu: forwarded unchanged to ``self.core.init``.
        init_var: forwarded unchanged to ``self.core.init``.

    Returns:
        ``(z, log_probs)`` with dimension 1 squeezed from each tensor.
    """
    mean, log_variance = self.core.init(
        tgt_sents, tgt_masks, src_enc, src_masks,
        init_scale=init_scale, init_mu=init_mu, init_var=init_var,
    )
    latent, noise = Posterior.reparameterize(mean, log_variance, tgt_masks, random=True)
    log_density = Posterior.log_probability(latent, noise, mean, log_variance, tgt_masks)
    # NOTE(review): presumably reparameterize inserts a sample axis at dim 1
    # (size 1 with the default nsamples); squeeze(1) drops it. It is a no-op
    # if that dim is not size 1. Confirm against Posterior.reparameterize.
    return latent.squeeze(1), log_density.squeeze(1)