Ejemplo n.º 1
0
    def forward(self, input: torch.Tensor):
        """Encode ``input``, sample one relaxed one-hot code per position, decode it.

        The output of ``self.proposal_network`` is rearranged so that each row
        handed to ``self.proposal_logits_head`` corresponds to one position of
        the two trailing axes. The sampled codes are then rearranged back to
        the encoder output shape before being passed to
        ``self.generative_network``.

        NOTE(review): assumes the encoder output is (batch, channels, height,
        width) with channels == ``self.latent_dim`` -- TODO confirm.

        Args:
            input: Tensor fed to ``self.proposal_network``.

        Returns:
            A tuple ``(reconstruction, proposal_distribution, proposal_sample)``
            where ``proposal_sample`` is the sample *before* it is reshaped
            back to the encoder output shape.
        """
        features = self.proposal_network(input)
        feature_shape = features.shape
        num_positions = feature_shape[-2] * feature_shape[-1]

        # Merge the two trailing axes, move that merged axis in front of the
        # last remaining one, then fold the two leading axes together.
        per_position = features.flatten(start_dim=-2)
        per_position = per_position.transpose(-1, -2)
        per_position = per_position.flatten(end_dim=1)

        proposal_distribution = distributions.RelaxedOneHotCategorical(
            self.temperature,
            logits=self.proposal_logits_head(per_position))

        # rsample() keeps the draw differentiable w.r.t. the logits.
        flat_sample = proposal_distribution.rsample()

        # Invert the rearrangement above so the decoder sees the encoder layout.
        decoder_input = flat_sample.reshape(
            -1, num_positions, self.latent_dim)
        decoder_input = decoder_input.transpose(-1, -2).reshape(feature_shape)

        return (self.generative_network(decoder_input),
                proposal_distribution,
                flat_sample)
Ejemplo n.º 2
0
 def forward(self, state_features):
     """Return the distribution selected by ``self.dist`` for the given features.

     ``state_features`` is passed through ``self.feedforward_model``; the
     result parameterises one of four distribution families.

     Args:
         state_features: Input forwarded to ``self.feedforward_model``.

     Returns:
         * ``'tanh_normal'``: ``SampleDist`` wrapping a tanh-transformed
           ``Normal`` with the last dimension treated as the event.
         * ``'trunc_normal'``: ``Independent`` truncated normal on [-1, 1].
         * ``'one_hot'``: ``OneHotCategoricalStraightThrough``.
         * ``'relaxed_one_hot'``: ``RelaxedOneHotCategorical`` with
           temperature 0.1.

     Raises:
         NotImplementedError: If ``self.dist`` is none of the names above.
     """
     x = self.feedforward_model(state_features)
     kind = self.dist

     if kind == 'tanh_normal':
         # The network output holds mean and std side by side on the last axis.
         raw_mean, raw_std = th.chunk(x, 2, -1)
         # Soft-clip the mean into (-mean_scale, mean_scale).
         mean = self.mean_scale * th.tanh(raw_mean / self.mean_scale)
         # softplus keeps std positive; min_std is added as a floor.
         std = F.softplus(raw_std + self.raw_init_std) + self.min_std
         # TODO: fix nan problem
         squashed = td.TransformedDistribution(
             td.Normal(mean, std), td.TanhTransform(cache_size=1))
         return SampleDist(td.Independent(squashed, 1))

     if kind == 'trunc_normal':
         raw_mean, raw_std = th.chunk(x, 2, -1)
         # Sigmoid-based parameterisation: std lies in
         # (min_std, 2 + min_std).
         std = 2 * th.sigmoid((raw_std + self.raw_init_std) / 2) + self.min_std
         # Imported lazily, only when this branch is actually used.
         from rls.nn.dists.TruncatedNormal import \
             TruncatedNormal as TruncNormalDist
         bounded = TruncNormalDist(th.tanh(raw_mean), std, -1, 1)
         return td.Independent(bounded, 1)

     if kind == 'one_hot':
         return td.OneHotCategoricalStraightThrough(logits=x)

     if kind == 'relaxed_one_hot':
         return td.RelaxedOneHotCategorical(th.tensor(0.1), logits=x)

     raise NotImplementedError(f"{self.dist} is not implemented.")
Ejemplo n.º 3
0
    def forward(self, input: torch.Tensor,
                proposal: distributions.RelaxedOneHotCategorical,
                proposal_sample: torch.Tensor,
                reconstruction: torch.Tensor
                ) -> 'tuple[torch.Tensor, torch.Tensor, torch.Tensor]':
        """Compute the negative evidence lower bound and its two components.

        The likelihood is ``Bernoulli(probs=reconstruction)`` when
        ``self.likelihood == 'bernoulli'`` and a unit-variance ``Normal``
        centred on ``reconstruction`` otherwise. The regularisation term is a
        single-sample estimate of ``log q(z) - log p(z)`` where the prior
        ``p`` is a relaxed categorical with constant (all-ones) logits and the
        proposal's temperature.

        Args:
            input: Target tensor, scored under the likelihood.
            proposal: Relaxed categorical proposal; its logits must be 2-D.
            proposal_sample: Sample drawn from ``proposal``.
            reconstruction: Decoder output parameterising the likelihood.

        Returns:
            A tuple ``(loss, negative_log_likelihood, regularization)`` with
            ``loss = -(log_likelihood - self.beta * regularization)``.

        Raises:
            AssertionError: If ``proposal.logits`` is not 2-dimensional.
        """
        if self.likelihood == 'bernoulli':
            likelihood = distributions.Bernoulli(probs=reconstruction)
        else:
            likelihood = distributions.Normal(reconstruction,
                                              torch.ones_like(reconstruction))

        # Treat every dimension except the leading one as part of a single
        # event so that log_prob sums over them. The count is spelled out
        # explicitly: a negative ``reinterpreted_batch_ndims`` is not part of
        # the documented API and leaves the wrapper with a wrong event_shape.
        event_ndims = max(reconstruction.dim() - 1, 0)
        likelihood = distributions.Independent(
            likelihood, reinterpreted_batch_ndims=event_ndims)
        log_likelihood = likelihood.log_prob(input).mean()

        assert proposal.logits.dim(
        ) == 2, "proposal.shape == [*, D], D is shape of isotopic gaussian"

        # Prior with constant logits, i.e. equal weight on every category.
        prior = distributions.RelaxedOneHotCategorical(
            proposal.temperature, logits=torch.ones_like(proposal.logits))
        log_ratio = proposal.log_prob(proposal_sample) \
            - prior.log_prob(proposal_sample)
        regularization = log_ratio.mean()

        # evidence lower bound (maximize)
        elbo = log_likelihood - self.beta * regularization

        return -elbo, -log_likelihood, regularization
 def sample_q(self, k, mode):
     """Draw ``k`` samples of the latent variable from ``self.q_dist``.

     In ``ModeKeys.TRAIN`` the samples are reparameterised draws from a
     relaxed one-hot categorical that shares ``self.q_dist``'s logits and
     uses temperature ``self.temp``. In ``ModeKeys.EVAL`` they are drawn
     directly from ``self.q_dist``.

     Args:
         k: Number of samples; used as the leading sample dimension.
         mode: ``ModeKeys.TRAIN`` or ``ModeKeys.EVAL``.

     Returns:
         Tensor reshaped to ``(k, -1, self.z_dim)``.

     Raises:
         ValueError: If ``mode`` is neither TRAIN nor EVAL. Previously this
             case failed with an unbound-variable error at the reshape.
     """
     if mode == ModeKeys.TRAIN:
         relaxed = td.RelaxedOneHotCategorical(self.temp,
                                               logits=self.q_dist.logits)
         z_NK = relaxed.rsample((k, ))
     elif mode == ModeKeys.EVAL:
         z_NK = self.q_dist.sample((k, ))
     else:
         raise ValueError(f"sample_q does not support mode {mode!r}")
     return torch.reshape(z_NK, (k, -1, self.z_dim))
Ejemplo n.º 5
0
    def __init__(self, n_clusters, cluster_to_params_graph):
        """Create the learnable cluster logits, temperature and distribution.

        Args:
            n_clusters: Number of clusters; one logit is created per cluster.
            cluster_to_params_graph: Stored unchanged on the instance.
        """
        super().__init__()

        # One learnable logit per cluster, initialised from a standard normal.
        initial_logits = torch.randn(n_clusters)
        self.cluster_logits = nn.Parameter(initial_logits)

        self.cluster_to_params_graph = cluster_to_params_graph

        # Learnable scalar; softplus maps it to a positive temperature.
        self.cluster_temperature_logit = nn.Parameter(torch.tensor(1.))
        temperature = F.softplus(self.cluster_temperature_logit)

        # NOTE(review): the distribution is built once here from the current
        # parameter values -- confirm it is rebuilt elsewhere after the
        # parameters are updated.
        self.cluster_distr = td.RelaxedOneHotCategorical(
            temperature=temperature,
            logits=self.cluster_logits)
 def _build_posterior(self, param_loc, param_spi_scale):
     """Build the relaxed categorical posterior over ``self._categories``.

     A truncated normal on ``[self.min_size, self.max_size]`` is evaluated
     at ``self._categories``; the resulting log-probabilities (floored at
     ``_LOG_EPSILON``) become the logits of the returned distribution.

     Args:
         param_loc: Location parameter; clamped to
             ``[self.min_size - 1, self.max_size + 1]``.
         param_spi_scale: Unconstrained scale; passed through softplus and
             floored at 1e-3.

     Returns:
         ``RelaxedOneHotCategorical`` with temperature ``self.temperature``.
     """
     # Let the location overshoot the truncation bounds by at most one unit.
     loc = torch.clamp(param_loc, self.min_size - 1, self.max_size + 1)
     # softplus makes the scale positive; the floor keeps it away from zero.
     scale = torch.clamp_min(nn.functional.softplus(param_spi_scale), 1e-3)

     truncated = TruncatedNormal(loc=loc,
                                 scale=scale,
                                 low=self.min_size,
                                 high=self.max_size)

     category_log_probs = truncated.log_prob(self._categories)
     floored_logits = torch.clamp_min(category_log_probs, _LOG_EPSILON)

     return distributions.RelaxedOneHotCategorical(
         logits=floored_logits,
         temperature=self.temperature)
 def _build_prior(self, prior_loc, prior_spi_scale, prior_temperature):
     """Build the relaxed categorical prior over ``self._categories``.

     A truncated normal on ``[self.min_size, self.max_size]`` is evaluated
     at ``self._categories``; the resulting log-probabilities (floored at
     ``_LOG_EPSILON``) become the logits of the returned distribution.

     Args:
         prior_loc: Location of the truncated normal (used as given).
         prior_spi_scale: Unconstrained scale; passed through softplus and
             floored at 1e-3.
         prior_temperature: Temperature value; converted to a float32 tensor.

     Returns:
         ``RelaxedOneHotCategorical`` prior distribution.
     """
     # softplus makes the scale positive; the floor keeps it away from zero.
     scale = torch.clamp_min(nn.functional.softplus(prior_spi_scale), 1e-3)

     truncated = TruncatedNormal(loc=prior_loc,
                                 scale=scale,
                                 low=self.min_size,
                                 high=self.max_size)

     category_log_probs = truncated.log_prob(self._categories)
     floored_logits = torch.clamp_min(category_log_probs, _LOG_EPSILON)
     temperature = torch.tensor(prior_temperature, dtype=torch.float32)

     return distributions.RelaxedOneHotCategorical(
         logits=floored_logits,
         temperature=temperature)
Ejemplo n.º 8
0
    def __init__(self, temp, latent_num, latent_dim):
        """Set up encoder, decoder and the uniform relaxed-categorical prior.

        Args:
            temp: Relaxation temperature. Anything that is not already a
                ``torch.Tensor`` (or a subclass of it) is converted with
                ``torch.tensor``.
            latent_num: Forwarded to ``Encoder`` and ``Decoder``.
            latent_dim: Forwarded to ``Encoder`` and ``Decoder``; also the
                number of categories of the prior.
        """
        super(Model, self).__init__()
        # isinstance (rather than an exact type comparison) lets Tensor
        # subclasses such as nn.Parameter pass through without being copied.
        if not isinstance(temp, torch.Tensor):
            temp = torch.tensor(temp)
        self.__temp = temp
        self.latent_num = latent_num
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_num=latent_num, latent_dim=latent_dim)
        self.decoder = Decoder(latent_num=latent_num, latent_dim=latent_dim)

        # Same placement as before when CUDA is present, but construction no
        # longer crashes on CPU-only machines.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Equal (unnormalised) weight for every category.
        uniform_probs = torch.ones(latent_dim, device=device)
        # Classes whose string representation contains 'ExpTDModel' use the
        # log-space (Exp) variant of the relaxed categorical.
        if 'ExpTDModel' in str(self.__class__):
            self.prior = ExpRelaxedCategorical(temp, probs=uniform_probs)
        else:
            self.prior = dist.RelaxedOneHotCategorical(temp,
                                                       probs=uniform_probs)
        self.initialize()

        self.softmax = nn.Softmax(dim=-1)
Ejemplo n.º 9
0
    def forward(self, input: torch.Tensor):
        """Encode ``input``, sample relaxed one-hot latents and decode them.

        The encoder output is passed through ``self.prenet``, turned into
        logits for groups of ``self.latent_dim`` categories, sampled, and the
        sample is resized back to the size of the encoder output's two
        trailing axes before decoding.

        Args:
            input: Tensor fed to ``self.proposal_network``.

        Returns:
            A tuple ``(reconstruction, proposal_distribution, proposal_sample)``
            where ``proposal_sample`` has one row per latent, i.e. it is the
            sample *before* reshaping and interpolation.

        Raises:
            AssertionError: If the last dimension of the prenet output is not
                ``self.latent_dim * self.num_latents``.
        """
        features = self.proposal_network(input)
        # Remember the size of the two trailing axes before the prenet
        # changes the layout; the sample is resized to it below.
        target_size = features.size()[-2:]
        features = self.prenet(features)

        assert features.size(-1) == self.latent_dim * self.num_latents

        # One row of ``latent_dim`` logits per latent variable.
        logits = self.proposal_logits_head(features)
        logits = logits.reshape(-1, self.num_latents, self.latent_dim)
        logits = logits.flatten(end_dim=-2)

        proposal_distribution = distributions.RelaxedOneHotCategorical(
            self.temperature, logits=logits)
        flat_sample = proposal_distribution.rsample()

        # Stack all latents on the channel axis of a 1x1 map, then resize it
        # to the remembered size.
        decoder_input = flat_sample.reshape(
            -1, self.num_latents * self.latent_dim, 1, 1)
        decoder_input = F.interpolate(decoder_input, target_size)

        return (self.generative_network(decoder_input),
                proposal_distribution,
                flat_sample)
Ejemplo n.º 10
0
 def _sample_categorical(self, probs):
     """Sample a (relaxed) one-hot vector from category probabilities.

     Args:
         probs: Category probabilities along the last dimension.

     Returns:
         When ``self.test_mode`` is true, a hard one-hot sample from
         ``OneHotCategorical`` (no gradient). Otherwise a reparameterised
         sample from ``RelaxedOneHotCategorical`` with temperature
         ``self.temperature``.
     """
     if self.test_mode:
         return dist.OneHotCategorical(probs=probs).sample()

     relaxed = dist.RelaxedOneHotCategorical(temperature=self.temperature,
                                             probs=probs)
     # rsample() instead of sample(): sample() runs under no_grad, which
     # would cut the gradient path from the sample back to ``probs``.
     return relaxed.rsample()
Ejemplo n.º 11
0
 def sample(self, log_alpha, temp):
     """Draw a reparameterised relaxed one-hot sample.

     Args:
         log_alpha: Logits of the relaxed categorical distribution.
         temp: Relaxation temperature.

     Returns:
         A tuple ``(sample, distribution)``: the ``rsample()`` draw and the
         ``RelaxedOneHotCategorical`` it was drawn from.
     """
     relaxed = dist.RelaxedOneHotCategorical(temp, logits=log_alpha)
     return relaxed.rsample(), relaxed
Ejemplo n.º 12
0
 def temp(self, value):
     """Store a new temperature and rebuild the uniform prior with it.

     NOTE(review): presumably the setter of a ``temp`` property -- the
     decorator is not visible in this excerpt; confirm.

     Args:
         value: New relaxation temperature. Anything that is not already a
             ``torch.Tensor`` (or a subclass of it) is converted with
             ``torch.tensor``, so the prior always receives a tensor
             temperature.
     """
     if not isinstance(value, torch.Tensor):
         value = torch.tensor(value)
     self.__temp = value

     # Same placement as before when CUDA is present, but the setter no
     # longer crashes on CPU-only machines.
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     # Equal (unnormalised) weight for every category.
     uniform_probs = torch.ones(self.latent_dim, device=device)
     # Classes whose string representation contains 'ExpTDModel' use the
     # log-space (Exp) variant of the relaxed categorical.
     if 'ExpTDModel' in str(self.__class__):
         self.prior = ExpRelaxedCategorical(value, probs=uniform_probs)
     else:
         self.prior = dist.RelaxedOneHotCategorical(value,
                                                    probs=uniform_probs)