Ejemplo n.º 1
0
 def make_latent_distributions(self, batch, mask, no_proposal=False):
     """Build the proposal and prior latent distributions for a batch.

     Args:
         batch: input tensor. Concatenated with ``mask`` along dim 1 to
             form the proposal network's input (``torch.cat`` requires
             every other dimension of the two to match).
         mask: mask tensor. Passed to ``self.make_observed`` together with
             ``batch``, and concatenated along dim 1 into the inputs of
             both the proposal and the prior network.
         no_proposal: when truthy, the proposal network is not run and
             ``None`` is returned in place of the proposal distribution.

     Returns:
         A ``(proposal, prior)`` pair, each built by
         ``normal_parse_params(<network output>, 1e-3)``; ``proposal`` is
         ``None`` when ``no_proposal`` is truthy.
     """
     observed_part = self.make_observed(batch, mask)

     proposal_dist = None
     if not no_proposal:
         # The proposal network sees the raw batch alongside the mask.
         proposal_raw = self.proposal_network(torch.cat([batch, mask], 1))
         proposal_dist = normal_parse_params(proposal_raw, 1e-3)

     # The prior network instead sees make_observed(batch, mask) plus the mask.
     prior_input = torch.cat([observed_part, mask], 1)
     prior_dist = normal_parse_params(self.prior_network(prior_input), 1e-3)

     return proposal_dist, prior_dist
Ejemplo n.º 2
0
    def make_latent_distributions(self, batch):
        """Build the prior latent distribution for ``batch``.

        This variant takes no mask and builds no proposal distribution:
        ``batch`` is fed straight into ``self.prior_network``.

        Args:
            batch: input passed unchanged to ``self.prior_network``.

        Returns:
            The result of ``normal_parse_params(<prior network output>, 1e-3)``.
        """
        return normal_parse_params(self.prior_network(batch), 1e-3)
Ejemplo n.º 3
0
    def make_latent_distributions(self, latent_params, no_proposal=False):
        """Turn precomputed ``latent_params`` into a latent distribution.

        Args:
            latent_params: raw distribution parameters, handed to
                ``normal_parse_params`` with ``1e-3`` as its second argument
                (presumably a minimum scale; confirm against that helper).
            no_proposal: accepted for signature compatibility only; this
                implementation never reads it.

        Returns:
            Whatever ``normal_parse_params(latent_params, 1e-3)`` returns.
        """
        return normal_parse_params(latent_params, 1e-3)
Ejemplo n.º 4
0
 def make_latent_distributions(self, batch, mask, no_proposal=False,
                               need_observed=True):
     """Build the proposal and prior latent distributions for a batch.

     Args:
         batch: input tensor. Concatenated with ``mask`` along dim 1 to
             form the proposal network's input (``torch.cat`` requires
             every other dimension of the two to match).
         mask: mask tensor. Concatenated along dim 1 into the inputs of
             both networks, and passed to ``self.make_observed`` when
             ``need_observed`` is truthy.
         no_proposal: when truthy, the proposal network is not run and
             ``None`` is returned in place of the proposal distribution.
         need_observed: when truthy, the prior network's input is built
             from ``self.make_observed(batch, mask)``; otherwise ``batch``
             itself is used as-is.

     Returns:
         A ``(proposal, prior)`` pair, each built by
         ``normal_parse_params(<network output>, 1e-3)``; ``proposal`` is
         ``None`` when ``no_proposal`` is truthy.
     """
     observed_part = self.make_observed(batch, mask) if need_observed else batch

     proposal_dist = None
     if not no_proposal:
         # The proposal network always sees the raw batch alongside the mask.
         proposal_raw = self.proposal_network(torch.cat([batch, mask], 1))
         proposal_dist = normal_parse_params(proposal_raw, 1e-3)

     # The prior network sees the (optionally observed-only) batch plus mask.
     prior_input = torch.cat([observed_part, mask], 1)
     prior_dist = normal_parse_params(self.prior_network(prior_input), 1e-3)

     return proposal_dist, prior_dist
Ejemplo n.º 5
0
def sampler(params):
    """Return the ``.mean`` of the distribution parsed from ``params``.

    ``params`` is passed to ``normal_parse_params`` with that helper's
    default arguments.
    """
    distribution = normal_parse_params(params)
    return distribution.mean
Ejemplo n.º 6
0
def sampler(params, multiple=False):
    """Return the ``.mean`` of the distribution parsed from ``params``.

    Args:
        params: raw distribution parameters for ``normal_parse_params``.
        multiple: forwarded unchanged as the ``multiple`` keyword of
            ``normal_parse_params``. The default ``False`` matches what was
            always passed before, so existing callers see no change.

    Returns:
        The ``.mean`` attribute of the parsed distribution.
    """
    # Forward the caller's choice instead of hard-coding False, which made
    # this parameter a silent no-op.
    return normal_parse_params(params, multiple=multiple).mean