Example #1
0
class IndependentNormal(Distribution):
    """Diagonal Gaussian with trailing batch dimensions folded into the event.

    Thin wrapper around ``Independent(Normal(loc, scale), len(loc.shape) - 1)``:
    the last ``len(loc.shape) - 1`` batch dimensions of the underlying
    ``Normal`` are reinterpreted as event dimensions, so ``log_prob`` and
    ``entropy`` are reduced over them.

    Args:
        loc: Tensor of means (expected to have at least one dimension).
        scale: Tensor of standard deviations.
        validate_args: Forwarded to the wrapped ``Normal`` / ``Independent``
            and to ``Distribution.__init__``.
    """
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        normal = Normal(loc=loc, scale=scale, validate_args=validate_args)
        self.base_dist = Independent(normal,
                                     len(loc.shape) - 1,
                                     validate_args=validate_args)
        # When validation is enabled, Distribution.__init__ reads every key of
        # ``arg_constraints`` from the instance, so the parameters have to be
        # stored before delegating to it.
        self.loc = normal.loc
        self.scale = normal.scale
        super().__init__(self.base_dist.batch_shape,
                         self.base_dist.event_shape,
                         validate_args=validate_args)

    @constraints.dependent_property
    def support(self):
        """Support of the wrapped distribution (real-valued, not positive)."""
        return self.base_dist.support

    def log_prob(self, value):
        """Return the log-density of ``value`` under the wrapped distribution."""
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        """Mean of the wrapped distribution."""
        return self.base_dist.mean

    @property
    def variance(self):
        """Variance of the wrapped distribution."""
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        """Draw a sample of shape ``sample_shape`` from the wrapped distribution."""
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample from the wrapped distribution."""
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        """Return the entropy of the wrapped distribution."""
        return self.base_dist.entropy()
Example #2
0
 def test_independent_shape(self):
     """Check that ``Independent`` keeps shapes consistent with its base.

     For every example distribution and every valid number of reinterpreted
     batch dimensions, the wrapped distribution must agree with the base
     distribution on sample/rsample, support-enumeration, mean and variance
     shapes, while ``log_prob`` and ``entropy`` lose the reinterpreted
     trailing dimensions.
     """
     for Dist, params in EXAMPLES:
         for param in params:
             base_dist = Dist(**param)
             x = base_dist.sample()
             base_log_prob_shape = base_dist.log_prob(x).shape
             for reinterpreted_batch_ndims in range(
                     len(base_dist.batch_shape) + 1):
                 indep_dist = Independent(base_dist,
                                          reinterpreted_batch_ndims)
                 # log_prob/entropy are summed over the reinterpreted dims.
                 indep_log_prob_shape = base_log_prob_shape[:len(
                     base_log_prob_shape) - reinterpreted_batch_ndims]
                 self.assertEqual(
                     indep_dist.log_prob(x).shape, indep_log_prob_shape)
                 self.assertEqual(indep_dist.sample().shape,
                                  base_dist.sample().shape)
                 self.assertEqual(indep_dist.has_rsample,
                                  base_dist.has_rsample)
                 if indep_dist.has_rsample:
                     # Exercise the reparameterized path, not sample() again.
                     self.assertEqual(indep_dist.rsample().shape,
                                      base_dist.rsample().shape)
                 # Each optional feature gets its own try block so that one
                 # unimplemented feature does not skip the other checks.
                 try:
                     self.assertEqual(
                         indep_dist.enumerate_support().shape,
                         base_dist.enumerate_support().shape,
                     )
                 except NotImplementedError:
                     pass
                 try:
                     self.assertEqual(indep_dist.mean.shape,
                                      base_dist.mean.shape)
                 except NotImplementedError:
                     pass
                 try:
                     self.assertEqual(indep_dist.variance.shape,
                                      base_dist.variance.shape)
                 except NotImplementedError:
                     pass
                 try:
                     self.assertEqual(indep_dist.entropy().shape,
                                      indep_log_prob_shape)
                 except NotImplementedError:
                     pass
Example #3
0
class IndependentRescaledBeta(Distribution):
    """``RescaledBeta`` with trailing batch dimensions folded into the event.

    Thin wrapper around
    ``Independent(RescaledBeta(concentration1, concentration0), n)`` with
    ``n = len(concentration1.shape) - 1``, so ``log_prob`` and ``entropy`` are
    reduced over the last ``n`` batch dimensions of the base distribution.

    Args:
        concentration1: Tensor holding the first concentration parameter
            (expected to have at least one dimension).
        concentration0: Second concentration parameter.
        validate_args: Forwarded to ``RescaledBeta``, ``Independent`` and
            ``Distribution.__init__``.
    """
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive
    }
    support = constraints.interval(-1., 1.)
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        self.base_dist = Independent(RescaledBeta(concentration1,
                                                  concentration0,
                                                  validate_args),
                                     len(concentration1.shape) - 1,
                                     validate_args=validate_args)
        # When validation is enabled, Distribution.__init__ reads every key of
        # ``arg_constraints`` from the instance, so the parameters have to be
        # stored before delegating to it.
        self.concentration1 = concentration1
        self.concentration0 = concentration0
        super().__init__(self.base_dist.batch_shape,
                         self.base_dist.event_shape,
                         validate_args=validate_args)

    def log_prob(self, value):
        """Return the log-density of ``value`` under the wrapped distribution."""
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        """Mean of the wrapped distribution."""
        return self.base_dist.mean

    @property
    def variance(self):
        """Variance of the wrapped distribution."""
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        """Draw a sample of shape ``sample_shape`` from the wrapped distribution."""
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample from the wrapped distribution."""
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        """Return the entropy of the wrapped distribution."""
        return self.base_dist.entropy()
Example #4
0
    def __call__(self, x, out_keys=None, info=None, **kwargs):
        """Run the Gaussian policy on ``x`` and return the requested outputs.

        Args:
            x: Input forwarded to ``self.network``.
            out_keys: Names of the outputs to compute. Recognised optional
                keys are ``'state_value'``, ``'action_logprob'``,
                ``'entropy'`` and ``'perplexity'``; ``'action'`` is always
                returned. Defaults to ``['action']``.
            info: Optional dict; its ``'mask'`` entry (``None`` if absent) is
                passed to the network when the policy is recurrent. Defaults
                to an empty dict.
            **kwargs: Accepted and ignored.

        Returns:
            dict: Always contains ``'action'`` (the sampled action after
            ``self.constraint_action``), plus one entry per requested
            optional key.

        Raises:
            ValueError: If ``self.std_style`` is neither ``'exp'`` nor
                ``'softplus'``.
            RuntimeError: If the sampled action contains NaN.
        """
        # None sentinels avoid sharing mutable default objects across calls.
        if out_keys is None:
            out_keys = ['action']
        if info is None:
            info = {}

        # Output dictionary
        out_policy = {}

        # Forward pass of feature networks to obtain features
        if self.recurrent:
            out_network = self.network(x=x,
                                       hidden_states=self.rnn_states,
                                       mask=info.get('mask', None))
            features = out_network['output']
            # Update the tracking of current RNN hidden states
            self.rnn_states = out_network['hidden_states']
        else:
            features = self.network(x)

        # Forward pass through mean head to obtain mean values for Gaussian distribution
        mean = self.network.mean_head(features)
        # Obtain logvar based on the options
        if isinstance(self.network.logvar_head, nn.Linear):
            # Linear layer: state-dependent logvar via a forward pass
            logvar = self.network.logvar_head(features)
        else:
            # Either Tensor or nn.Parameter: expand to the same shape as mean
            logvar = self.network.logvar_head.expand_as(mean)

        # Forward pass of value head to obtain value function if required
        if 'state_value' in out_keys:
            out_policy['state_value'] = self.network.value_head(
                features).squeeze(-1)  # squeeze final single dim

        # Get std from logvar
        if self.std_style == 'exp':
            std = torch.exp(0.5 * logvar)
        elif self.std_style == 'softplus':
            std = F.softplus(logvar)
        else:
            # Fail with a clear message instead of an unbound `std` below.
            raise ValueError(
                f"std_style must be 'exp' or 'softplus', got {self.std_style!r}")

        # Lower bound threshold for std
        std = torch.max(std, torch.full_like(std, self.min_std))

        # Create independent Gaussian distributions i.e. Diagonal Gaussian
        action_dist = Independent(Normal(loc=mean, scale=std), 1)

        # Sample action from the distribution (no gradient)
        # Do not use `rsample()`, it leads to zero gradient of mean head !
        action = action_dist.sample()

        # Sanity check for NaN: abort with diagnostics rather than hanging
        if torch.any(torch.isnan(action)):
            msg = 'NaN ! A workaround is to learn state-independent std or use tanh rather than relu'
            msg2 = f'check: \n\t mean: {mean}, logvar: {logvar}'
            raise RuntimeError(msg + '\n' + msg2)

        # Calculate log-probability of the sampled (unconstrained) action
        if 'action_logprob' in out_keys:
            out_policy['action_logprob'] = action_dist.log_prob(action)

        # Calculate policy entropy conditioned on state
        if 'entropy' in out_keys:
            out_policy['entropy'] = action_dist.entropy()

        # Calculate policy perplexity i.e. exp(entropy)
        if 'perplexity' in out_keys:
            out_policy['perplexity'] = action_dist.perplexity()

        # Constraint action in valid range
        out_policy['action'] = self.constraint_action(action)

        return out_policy