Esempio n. 1
0
    def get_latent_dist(self, obs):
        """Build the joint distribution over all latent layers given ``obs``.

        The chain is seeded with the distribution of the last stochastic layer
        (index ``num_stochastic_layers - 1``), which is the only layer that is
        given ``obs`` directly. Layers ``num_stochastic_layers - 2`` down to 0
        are then chained on one at a time. Each is produced by
        ``get_latent_layer_dist`` from the ``previous_latent_layer`` value that
        ``util.ChainDistribution`` supplies (presumably the sample of the layer
        added just before it; confirm against util). The finished chain is
        wrapped in ``util.ReversedChainDistribution``.

        Args:
            obs: tensor of shape [batch_size, obs_dim] of values in {0, 1}.
                It is consumed only by the last stochastic layer
                (``layer_idx == num_stochastic_layers - 1``).

        Returns: distribution for all latent layers:

            dist.sample(sample_shape=[sample_shape]) returns
            (latent_0, ..., latent_N) where each latent_n
            is of shape [sample_shape, latent_dim] and latent_0
            corresponds to the latent furthest away from obs
            (NOTE(review): with a batched obs there is presumably also a
            batch dimension; confirm)

            if latent_n is of shape [batch_shape, latent_dim]
            dist.log_prob(latent_0, ..., latent_N) returns
            sum_n log_prob(latent_n) which is of shape [batch_shape]"""
        last_layer_idx = self.num_stochastic_layers - 1

        def make_layer_dist_fn(layer_idx):
            # Factory closure: pins layer_idx at creation time. This sidesteps
            # Python's late binding of loop variables without exposing
            # layer_idx as an extra, overridable parameter on the callable.
            def layer_dist_fn(previous_latent_layer):
                return self.get_latent_layer_dist(
                    layer_idx=layer_idx, previous_latent_layer=previous_latent_layer)
            return layer_dist_fn

        # The layer nearest the observation is the only one that sees obs.
        latent_dist = util.ChainDistributionFromSingle(
            self.get_latent_layer_dist(layer_idx=last_layer_idx, obs=obs))
        # Walk back toward layer 0, chaining each layer onto the previous one.
        for layer_idx in reversed(range(last_layer_idx)):
            latent_dist = util.ChainDistribution(
                latent_dist, make_layer_dist_fn(layer_idx))
        return util.ReversedChainDistribution(latent_dist)
Esempio n. 2
0
 def get_prop_network(self):
     """Build the proposal distribution over every stochastic latent layer.

     The chain is seeded with ``get_prop_latent_layer`` for the last
     stochastic layer (``num_stochastic_layers - 1``). The lower layers are
     then appended in descending index order. Each one is obtained from
     ``get_prop_latent_layer`` given the ``previous_latent_layer`` value that
     ``util.ChainDistribution`` supplies. The result is wrapped in
     ``util.ReversedChainDistribution`` (presumably so latents come out
     ordered from layer 0 upward; TODO confirm against util).
     """
     top_idx = self.num_stochastic_layers - 1
     chain = util.ChainDistributionFromSingle(
         self.get_prop_latent_layer(layer_idx=top_idx))
     for idx in range(top_idx - 1, -1, -1):
         # The default argument captures the current index; a plain closure
         # would see only the loop's final value (Python binds late).
         def conditional(previous_latent_layer, layer_idx=idx):
             return self.get_prop_latent_layer(
                 layer_idx=layer_idx, previous_latent_layer=previous_latent_layer)
         chain = util.ChainDistribution(chain, conditional)
     return util.ReversedChainDistribution(chain)
Esempio n. 3
0
    def get_latent_dist(self):
        """Build the distribution over every latent layer, starting at layer 0.

        Layer 0 comes from ``get_latent_layer_dist(layer_idx=0)``. Each layer
        ``1 .. num_stochastic_layers - 1`` is then chained on in ascending
        order. Its distribution comes from ``get_latent_layer_dist`` given the
        ``previous_latent_layer`` value that ``util.ChainDistribution``
        supplies.

        Returns:
            A chained distribution ``dist`` over all latent layers.
            ``dist.sample(sample_shape=[sample_shape])`` yields the tuple
            ``(latent_0, ..., latent_N)``. Each ``latent_n`` has shape
            ``[sample_shape, latent_dim]``, and ``latent_0`` is the latent
            furthest away from obs.

            When each ``latent_n`` has shape ``[batch_shape, latent_dim]``,
            ``dist.log_prob(latent_0, ..., latent_N)`` returns
            ``sum_n log_prob(latent_n)`` of shape ``[batch_shape]``.
        """
        chain = util.ChainDistributionFromSingle(
            self.get_latent_layer_dist(layer_idx=0))
        layer_indices = range(1, self.num_stochastic_layers)
        for idx in layer_indices:
            # The default argument pins this iteration's index; a bare
            # closure would observe only the final loop value.
            def next_layer_dist(previous_latent_layer, layer_idx=idx):
                return self.get_latent_layer_dist(
                    layer_idx=layer_idx, previous_latent_layer=previous_latent_layer)
            chain = util.ChainDistribution(chain, next_layer_dist)
        return chain