import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Distribution, Independent, Normal, constraints


class IndependentNormal(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real  # a Normal has support over all reals
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # Reinterpret all but the leading batch dimension of a Normal as event
        # dimensions, i.e. a diagonal Gaussian over the trailing dimensions.
        self.base_dist = Independent(Normal(loc=loc, scale=scale, validate_args=validate_args),
                                     len(loc.shape) - 1,
                                     validate_args=validate_args)
        super(IndependentNormal, self).__init__(self.base_dist.batch_shape,
                                                self.base_dist.event_shape,
                                                validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        return self.base_dist.entropy()
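# A minimal usage sketch of IndependentNormal; the shapes below are
# assumptions chosen only for illustration. With `loc` of shape (batch, dim),
# the trailing dimension becomes the event dimension, so log_prob returns one
# value per batch element.
loc = torch.zeros(4, 3)
scale = torch.ones(4, 3)
dist = IndependentNormal(loc, scale)
x = dist.rsample()                 # shape (4, 3), differentiable sample
print(dist.log_prob(x).shape)      # torch.Size([4])
print(dist.entropy().shape)        # torch.Size([4])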
def test_independent_shape(self):
    # Check that Independent preserves sample shapes and strips the
    # reinterpreted batch dimensions from the log_prob/entropy shapes.
    for Dist, params in EXAMPLES:
        for param in params:
            base_dist = Dist(**param)
            x = base_dist.sample()
            base_log_prob_shape = base_dist.log_prob(x).shape
            for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                indep_log_prob_shape = base_log_prob_shape[:len(base_log_prob_shape) - reinterpreted_batch_ndims]
                self.assertEqual(indep_dist.log_prob(x).shape, indep_log_prob_shape)
                self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape)
                self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
                if indep_dist.has_rsample:
                    # Exercise rsample() here; sample() is already checked above
                    self.assertEqual(indep_dist.rsample().shape, base_dist.sample().shape)
                try:
                    self.assertEqual(indep_dist.enumerate_support().shape,
                                     base_dist.enumerate_support().shape)
                    self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
                except NotImplementedError:
                    pass
                try:
                    self.assertEqual(indep_dist.variance.shape, base_dist.variance.shape)
                except NotImplementedError:
                    pass
                try:
                    self.assertEqual(indep_dist.entropy().shape, indep_log_prob_shape)
                except NotImplementedError:
                    pass
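# `EXAMPLES` is defined elsewhere in the test suite as a list of
# (Distribution class, list of parameter dicts) pairs. A minimal stand-in for
# running this test in isolation (an assumption, not the original fixture)
# could look like:
EXAMPLES = [
    (Normal, [{'loc': torch.randn(2, 3), 'scale': torch.rand(2, 3) + 0.1}]),
]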
class IndependentRescaledBeta(Distribution):
    # `RescaledBeta` is assumed to be defined elsewhere: a Beta distribution
    # rescaled from (0, 1) to the interval (-1, 1).
    arg_constraints = {'concentration1': constraints.positive,
                       'concentration0': constraints.positive}
    support = constraints.interval(-1., 1.)
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        self.base_dist = Independent(RescaledBeta(concentration1, concentration0, validate_args),
                                     len(concentration1.shape) - 1,
                                     validate_args=validate_args)
        super(IndependentRescaledBeta, self).__init__(self.base_dist.batch_shape,
                                                      self.base_dist.event_shape,
                                                      validate_args=validate_args)

    def log_prob(self, value):
        return self.base_dist.log_prob(value)

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def variance(self):
        return self.base_dist.variance

    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)

    def entropy(self):
        return self.base_dist.entropy()
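# `RescaledBeta` is not shown above. A plausible sketch (an assumption, not
# the original implementation) maps a Beta on (0, 1) to (-1, 1) via x -> 2x - 1:
from torch.distributions import Beta, TransformedDistribution
from torch.distributions.transforms import AffineTransform


class RescaledBeta(TransformedDistribution):
    def __init__(self, concentration1, concentration0, validate_args=None):
        base = Beta(concentration1, concentration0, validate_args=validate_args)
        super().__init__(base, AffineTransform(loc=-1., scale=2.),
                         validate_args=validate_args)


# Usage sketch with illustrative shapes:
c1 = torch.ones(4, 3)
c0 = torch.ones(4, 3)
dist = IndependentRescaledBeta(c1, c0)
y = dist.rsample()                 # shape (4, 3), values in (-1, 1)
print(dist.log_prob(y).shape)      # torch.Size([4])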
def __call__(self, x, out_keys=['action'], info={}, **kwargs):
    # Output dictionary
    out_policy = {}

    # Forward pass of the feature network to obtain features
    if self.recurrent:
        out_network = self.network(x=x,
                                   hidden_states=self.rnn_states,
                                   mask=info.get('mask', None))
        features = out_network['output']
        # Update the tracking of the current RNN hidden states
        self.rnn_states = out_network['hidden_states']
    else:
        features = self.network(x)

    # Forward pass through the mean head to obtain the Gaussian mean
    mean = self.network.mean_head(features)

    # Obtain logvar depending on whether it is state-dependent
    if isinstance(self.network.logvar_head, nn.Linear):
        # Linear layer: state-dependent logvar via a forward pass
        logvar = self.network.logvar_head(features)
    else:
        # Either a Tensor or an nn.Parameter: state-independent logvar
        logvar = self.network.logvar_head
        # Expand to the same shape as the mean
        logvar = logvar.expand_as(mean)

    # Forward pass of the value head to obtain the state value, if requested
    if 'state_value' in out_keys:
        out_policy['state_value'] = self.network.value_head(features).squeeze(-1)  # squeeze the final singleton dim

    # Get std from logvar
    if self.std_style == 'exp':
        std = torch.exp(0.5*logvar)
    elif self.std_style == 'softplus':
        std = F.softplus(logvar)
    else:
        # Guard against an unsupported std_style, which would leave std undefined
        raise ValueError(f'expected std_style to be exp or softplus, got {self.std_style}')

    # Lower-bound threshold for std
    min_std = torch.full(std.size(), self.min_std).type_as(std).to(self.device)
    std = torch.max(std, min_std)

    # Create an independent Gaussian distribution, i.e. a diagonal Gaussian
    action_dist = Independent(Normal(loc=mean, scale=std), 1)

    # Sample an action from the distribution (no gradient through sampling).
    # Do not use `rsample()`: it leads to zero gradient for the mean head!
    action = action_dist.sample()
    out_policy['action'] = action

    # Calculate the log-probability of the sampled action
    if 'action_logprob' in out_keys:
        out_policy['action_logprob'] = action_dist.log_prob(action)

    # Calculate the policy entropy conditioned on the state
    if 'entropy' in out_keys:
        out_policy['entropy'] = action_dist.entropy()

    # Calculate the policy perplexity, i.e. exp(entropy)
    if 'perplexity' in out_keys:
        out_policy['perplexity'] = action_dist.perplexity()

    # Sanity check: fail fast on NaN actions
    if torch.any(torch.isnan(action)):
        msg = 'NaN! A workaround is to learn a state-independent std or use tanh rather than relu.'
        msg2 = f'\ncheck:\n\tmean: {mean}, logvar: {logvar}'
        raise ValueError(msg + msg2)

    # Constrain the action to the valid range
    out_policy['action'] = self.constraint_action(action)

    return out_policy
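# A standalone sketch of the policy's distribution step, using hypothetical
# shapes and a hypothetical min-std value (both assumptions) to show how the
# diagonal Gaussian is built from a mean head and a state-independent logvar:
features = torch.randn(8, 16)
mean_head = nn.Linear(16, 2)
logvar = nn.Parameter(torch.zeros(2))             # state-independent logvar
mean = mean_head(features)
std = torch.exp(0.5*logvar).expand_as(mean)       # std_style == 'exp'
std = torch.max(std, torch.full_like(std, 1e-4))  # lower-bound the std
action_dist = Independent(Normal(loc=mean, scale=std), 1)
action = action_dist.sample()                     # no gradient through sample
print(action.shape, action_dist.log_prob(action).shape)  # (8, 2) and (8,)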