def _compute_loss_world(self, state, data):
    """Compute the world-model loss terms for one batch.

    Args:
        state: 7-tuple ``(beliefs, prior_states, prior_means, prior_std_devs,
            posterior_states, posterior_means, posterior_std_devs)``.
        data: 3-tuple ``(observations, rewards, nonterminals)``.

    Returns:
        tuple: ``(observation_loss, reward_scale * reward_loss, kl_loss,
        pcont_term)``. ``pcont_term`` is ``pcont_scale * pcont_loss`` when
        ``self.args.pcont`` is truthy and the integer ``0`` otherwise.
    """
    (beliefs, _prior_states, prior_means, prior_std_devs,
     posterior_states, posterior_means, posterior_std_devs) = state
    observations, rewards, nonterminals = data
    model_inputs = (beliefs, posterior_states)

    # Observation reconstruction: per-element squared error, summed over the
    # observation dims, then averaged over the two leading dims.
    predicted_observations = bottle(self.observation_model, model_inputs)
    observation_dims = 2 if self.args.symbolic else (2, 3, 4)
    observation_error = F.mse_loss(
        predicted_observations, observations, reduction='none')
    observation_loss = observation_error.sum(dim=observation_dims).mean(dim=(0, 1))

    # Reward prediction error, averaged over the two leading dims.
    predicted_rewards = bottle(self.reward_model, model_inputs)
    reward_loss = F.mse_loss(
        predicted_rewards, rewards, reduction='none').mean(dim=(0, 1))

    # TODO: 5  (marker carried over from the original source; meaning not
    # recoverable from this block)
    # Transition loss: KL(posterior || prior), element-wise floored at
    # self.free_nats before averaging.
    posterior = Independent(Normal(posterior_means, posterior_std_devs), 1)
    prior = Independent(Normal(prior_means, prior_std_devs), 1)
    kl_loss = torch.max(
        kl_divergence(posterior, prior), self.free_nats).mean(dim=(0, 1))

    if self.args.pcont:
        predicted_pcont = bottle(self.pcont_model, model_inputs)
        pcont_term = self.args.pcont_scale * F.binary_cross_entropy(
            predicted_pcont, nonterminals)
    else:
        pcont_term = 0

    return (observation_loss, self.args.reward_scale * reward_loss, kl_loss,
            pcont_term)
def choose_action(self, state):
    """Sample an action for a single observation and return its log-prob.

    The observation is moved to the device the policy's parameters live on
    instead of unconditionally calling ``.cuda()``, so the method also works
    on CPU-only hosts and with a policy placed on a non-default GPU.

    Args:
        state (numpy.ndarray): A single (unbatched) observation.

    Returns:
        tuple: ``(action, log_prob)`` where ``action`` is a numpy array and
        ``log_prob`` is a Python float.
    """
    obs = torch.from_numpy(state).float().unsqueeze(0)

    # Prefer the policy's own device; fall back to CUDA when available (the
    # original behaviour) if the policy exposes no parameters.
    parameters = getattr(self.policy, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    mean, logstd = self.policy(obs.to(device))
    dist = Independent(Normal(mean.squeeze(), torch.exp(logstd)), 1)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    return action.squeeze().cpu().numpy(), log_prob.item()
def get_logp(mean, std, action, expand=None):
    """Return the log-probability of ``action`` under a diagonal Gaussian.

    Args:
        mean: Location of the Normal.
        std: Scale of the Normal.
        action: Value(s) to score.
        expand: Optional batch shape; when given, the distribution is expanded
            to it before scoring.

    Returns:
        The result of ``log_prob(action)``, with the last dimension treated as
        the event dimension.
    """
    dist = Independent(Normal(mean, std), reinterpreted_batch_ndims=1)
    if expand is not None:
        dist = dist.expand(expand)
    return dist.log_prob(action)
def get_log_prob_entropy(self, x, actions):
    """Score ``actions`` under the policy and return the policy entropy.

    Args:
        x: Input passed to ``self.forward``, which must return
            ``(action_mean, action_log_std, action_std)``.
        actions: Actions to evaluate.

    Returns:
        tuple: ``(log_prob, entropy)``. ``log_prob`` has an extra dimension
        inserted at position 1; ``entropy`` is the entropy at index 0 of the
        batch only.
    """
    action_mean, _action_log_std, action_std = self.forward(x)
    policy = Independent(Normal(loc=action_mean, scale=action_std), 1)
    log_prob = policy.log_prob(actions).unsqueeze(dim=1)
    # Only the first entry is returned.
    # NOTE(review): presumably std is state-independent so all rows agree — confirm.
    entropy = policy.entropy()[0]
    return log_prob, entropy
def forward(self, state):
    """Map ``state`` to a diagonal Gaussian.

    The parent ``forward`` output is split in half along its last dimension:
    the first half is the mean, the remainder is the log-variance.

    Args:
        state: Input forwarded to the parent module.

    Returns:
        Independent: ``Normal(mean + self._center, std * self._scale)`` with
        the last dimension as the event dimension.
    """
    outputs = super().forward(state)
    half = outputs.shape[-1] // 2
    means = outputs[..., :half]
    logvars = outputs[..., half:]
    # std = exp(logvar / 2)
    std = torch.exp(0.5 * logvars)
    base = Normal(means + self._center, std * self._scale)
    return Independent(base, 1)
def __init__(self, base_distribution, transforms, validate_args=None):
    """Build a distribution by pushing ``base_distribution`` through transforms.

    Args:
        base_distribution: The distribution being transformed.
        transforms: A single ``Transform`` or a list of ``Transform`` objects.
        validate_args: Forwarded to the parent constructor.

    Raises:
        ValueError: If ``transforms`` is neither a ``Transform`` nor a list of
            them, or if the base distribution has fewer dimensions than the
            composed transform's domain requires.
    """
    if isinstance(transforms, Transform):
        self.transforms = [transforms]
    elif isinstance(transforms, list):
        for candidate in transforms:
            if not isinstance(candidate, Transform):
                raise ValueError("transforms must be a Transform or a list of Transforms")
        self.transforms = transforms
    else:
        raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))

    # Reshape base_distribution according to transforms.
    full_base_shape = base_distribution.batch_shape + base_distribution.event_shape
    base_event_ndims = len(base_distribution.event_shape)
    composed = ComposeTransform(self.transforms)
    domain_event_ndims = composed.domain.event_dim
    if len(full_base_shape) < domain_event_ndims:
        raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
                         .format(domain_event_ndims, full_base_shape))

    # Round-trip the shape through the transform; if it changes, the base
    # distribution must be expanded to the shape the transform implies.
    shape = composed.forward_shape(full_base_shape)
    round_trip_shape = composed.inverse_shape(shape)
    if full_base_shape != round_trip_shape:
        n_batch_dims = len(round_trip_shape) - base_event_ndims
        base_distribution = base_distribution.expand(round_trip_shape[:n_batch_dims])

    # Promote trailing batch dims to event dims when the transform's domain
    # needs more event dims than the base provides.
    extra_event_ndims = domain_event_ndims - base_event_ndims
    if extra_event_ndims > 0:
        base_distribution = Independent(base_distribution, extra_event_ndims)
    self.base_dist = base_distribution

    # Compute shapes.
    event_ndims = composed.codomain.event_dim + max(base_event_ndims - domain_event_ndims, 0)
    assert len(shape) >= event_ndims
    split = len(shape) - event_ndims
    super(TransformedDistribution, self).__init__(
        shape[:split], shape[split:], validate_args=validate_args)
def forward(self, *inputs):
    """Forward method.

    Args:
        *inputs: Input to the module.

    Returns:
        torch.distributions.Distribution: ``Independent(dist, 1)`` for any
        distribution class other than ``TanhNormal``; a bare ``TanhNormal``
        otherwise.

    """
    mean, log_std_uncentered = self._get_mean_and_log_std(*inputs)

    min_log_std = self._min_std_param
    max_log_std = self._max_std_param
    # Compare against None explicitly: a bound whose log value is exactly 0
    # (i.e. a std bound of 1.0) is falsy and would otherwise be ignored.
    if min_log_std is not None or max_log_std is not None:
        log_std_uncentered = log_std_uncentered.clamp(
            min=(None if min_log_std is None else min_log_std.item()),
            max=(None if max_log_std is None else max_log_std.item()))

    if self._std_parameterization == 'exp':
        std = log_std_uncentered.exp()
    else:
        # softplus applied to exp(log_std): log(1 + exp(exp(log_std)))
        std = log_std_uncentered.exp().exp().add(1.).log()

    dist = self._norm_dist_class(mean, std)
    # This control flow is needed because if a TanhNormal distribution is
    # wrapped by torch.distributions.Independent, then custom methods of
    # TanhNormal (such as its pre-tanh sampling helper) are not accessible.
    if not isinstance(dist, TanhNormal):
        # Makes it so that a sample from the distribution is treated as a
        # single sample and not dist.batch_shape samples.
        dist = Independent(dist, 1)

    return dist
def forward(self, *inputs):
    """Forward method.

    Args:
        *inputs: Input to the module.

    Returns:
        torch.distributions.Independent: Diagonal Normal whose last
        dimension is the event dimension.

    """
    mean, log_std_uncentered = self._get_mean_and_log_std(*inputs)

    # Compare against None explicitly: a bound whose log value is exactly 0
    # (i.e. a std bound of 1.0) is falsy and would otherwise be ignored.
    if self._min_std_param is not None or self._max_std_param is not None:
        log_std_uncentered = log_std_uncentered.clamp(
            min=self._to_scalar_if_not_none(self._min_std_param),
            max=self._to_scalar_if_not_none(self._max_std_param))

    if self._std_parameterization == 'exp':
        std = log_std_uncentered.exp()
    else:
        # softplus applied to exp(log_std): log(1 + exp(exp(log_std)))
        std = log_std_uncentered.exp().exp().add(1.).log()

    return Independent(Normal(mean, std), 1)
def forward(self, state):
    """Map ``state`` to a diagonal Gaussian with squashed means.

    The parent ``forward`` output is split along dim 1: the first half goes
    through ``tanh`` and ``self._squash`` to give the mean; the remainder is
    multiplied by ``self._scale`` and exponentiated to give the scale.

    Args:
        state: Input forwarded to the parent module.

    Returns:
        Independent: Normal with the last dimension as the event dimension.
    """
    outputs = super().forward(state)
    half = outputs.shape[1] // 2
    raw_means = outputs[:, :half]
    raw_spread = outputs[:, half:]
    means = self._squash(torch.tanh(raw_means))
    std = torch.exp(raw_spread * self._scale)
    return Independent(Normal(means, std), 1)
def forward(self, *inputs):
    """Forward method.

    Args:
        *inputs: Input to the module.

    Returns:
        torch.distributions.Independent: ``self._norm_dist_class(mean, std)``
        with all of its batch dimensions reinterpreted as event dimensions.

    """
    mean, log_std_uncentered = self._get_mean_and_log_std(*inputs)

    # Compare against None explicitly: a bound whose log value is exactly 0
    # (i.e. a std bound of 1.0) is falsy and would otherwise be ignored.
    if self._min_std_param is not None or self._max_std_param is not None:
        log_std_uncentered = log_std_uncentered.clamp(
            min=self._to_scalar_if_not_none(self._min_std_param),
            max=self._to_scalar_if_not_none(self._max_std_param))

    if self._std_parameterization == 'exp':
        std = log_std_uncentered.exp()
    else:
        # softplus applied to exp(log_std): log(1 + exp(exp(log_std)))
        std = log_std_uncentered.exp().exp().add(1.).log()

    dist = self._norm_dist_class(mean, std)
    # Makes it so that a sample from the distribution is treated as a
    # single sample and not dist.batch_shape samples.
    dist = Independent(dist, len(dist.batch_shape))

    return dist
def spectral_density(self, smk) -> MixtureSameFamily:
    """Return the mixture of Gaussians that models the spectral density of
    the given spectral mixture kernel.

    The kernel's ``mixture_means``, ``mixture_scales`` and ``mixture_weights``
    are detached; means and scales are reshaped to one column each so every
    row becomes one mixture component with a one-element event.
    """
    component_means = smk.mixture_means.detach().reshape(-1, 1)
    component_scales = smk.mixture_scales.detach().reshape(-1, 1)
    mixing = Categorical(smk.mixture_weights.detach())
    components = Independent(Normal(component_means, component_scales), 1)
    return MixtureSameFamily(mixing, components)
def _compute_loss_world(self, state, data):
    """Return the loss terms used to train the world model.

    Args:
        state: 7-tuple ``(beliefs, prior_states, prior_means, prior_std_devs,
            posterior_states, posterior_means, posterior_std_devs)``.
        data: 3-tuple ``(observations, rewards, nonterminals)``.

    Returns:
        tuple: ``(observation_loss, reward_scale * reward_loss, kl_loss,
        pcont_term)``; ``pcont_term`` is the integer ``0`` unless
        ``self.args.pcont`` is truthy.
    """
    # unpackage data
    (beliefs, _unused_prior_states, prior_means, prior_std_devs,
     posterior_states, posterior_means, posterior_std_devs) = state
    observations, rewards, nonterminals = data
    latent = (beliefs, posterior_states)

    # Squared observation error, summed over observation dims and averaged
    # over the two leading dims.
    obs_error = F.mse_loss(
        bottle(self.observation_model, latent), observations, reduction='none')
    if self.args.symbolic:
        obs_error = obs_error.sum(dim=2)
    else:
        obs_error = obs_error.sum(dim=(2, 3, 4))
    observation_loss = obs_error.mean(dim=(0, 1))

    reward_error = F.mse_loss(
        bottle(self.reward_model, latent), rewards, reduction='none')
    reward_loss = reward_error.mean(dim=(0, 1))

    # TODO: 5  (marker carried over from the original source; meaning not
    # recoverable from this block)
    # transition loss
    divergence = kl_divergence(
        Independent(Normal(posterior_means, posterior_std_devs), 1),
        Independent(Normal(prior_means, prior_std_devs), 1))
    kl_loss = torch.max(divergence, self.free_nats).mean(dim=(0, 1))

    pcont_term = 0
    if self.args.pcont:
        pcont_loss = F.binary_cross_entropy(
            bottle(self.pcont_model, latent), nonterminals)
        pcont_term = self.args.pcont_scale * pcont_loss

    scaled_reward_loss = self.args.reward_scale * reward_loss
    return observation_loss, scaled_reward_loss, kl_loss, pcont_term
def latent_priors(self):
    """Sample latent variables from fixed priors.

    Returns:
        tuple: ``(z_where, z_what, z_pres)``. ``z_where`` is drawn from a
        3-dimensional Normal (scale, then two position entries) expanded to
        ``n`` rows; ``z_what`` from a 50-dimensional standard Normal;
        ``z_pres`` is a tensor of ones with shape ``(n, 1)``. ``n`` is read
        from the enclosing scope.
    """
    scale_prior_mu, scale_prior_sd = 3.0, 0.1
    pos_prior_mu, pos_prior_sd = 0.0, 1.0

    where_mu = torch.FloatTensor([scale_prior_mu, pos_prior_mu, pos_prior_mu])
    where_sd = torch.FloatTensor([scale_prior_sd, pos_prior_sd, pos_prior_sd])
    z_where_mu_prior = nn.Parameter(where_mu.expand(n, -1), requires_grad=False)
    z_where_sd_prior = nn.Parameter(where_sd.expand(n, -1), requires_grad=False)

    z_what_mu_prior = nn.Parameter(torch.zeros(50))
    z_what_sd_prior = nn.Parameter(torch.ones(50))

    # Draw z_where before z_what so the random stream is consumed in the same
    # order as before.
    where_prior = Independent(Normal(z_where_mu_prior, z_where_sd_prior), 1)
    z_where = where_prior.sample()
    what_prior = Independent(Normal(z_what_mu_prior, z_what_sd_prior), 1)
    z_what = what_prior.sample()

    z_pres = torch.ones(n, 1)
    return z_where, z_what, z_pres
def log_prob(self, value_mean, value_precision):
    """Joint log-density of a mean and a diagonal precision.

    The result is the sum of a ``DiagonalWishart(self.precision_diag,
    self.df)`` log-density on ``value_precision`` and a diagonal Normal
    log-density on ``value_mean`` with location ``self.loc`` and standard
    deviation ``(1 / (self.belief.unsqueeze(-1) * value_precision)) ** 0.5``.

    Args:
        value_mean: Mean value to score.
        value_precision: Precision value to score; must be strictly positive.

    Returns:
        Sum of the Normal and Wishart log-probabilities.

    Raises:
        ValueError: If any entry of ``value_precision`` is ``<= 0``.
    """
    if self._validate_args:
        self._validate_sample(value_mean)
        self._validate_sample(value_precision)
    # NOTE(review): the flattened source does not show whether this check was
    # nested under the _validate_args guard; it is kept unconditional here.
    if (value_precision <= 0).any():
        raise ValueError("desired precision must be greater that 0")

    wishart = DiagonalWishart(self.precision_diag, self.df)
    wishart_log_prob = wishart.log_prob(value_precision)

    conditional_std = (1 / (self.belief.unsqueeze(-1) * value_precision)).pow(0.5)
    conditional_normal = Independent(Normal(self.loc, conditional_std), 1)
    normal_log_prob = conditional_normal.log_prob(value_mean)

    return normal_log_prob + wishart_log_prob
def encode(self, x, z_where_prev, z_what_prev, z_pres_prev, h_prev, c_prev):
    """Run one encoding step and sample the latents for it.

    Args:
        x: Input passed to the recurrent step and to the attention encoder.
        z_where_prev, z_what_prev, z_pres_prev: Latents from the previous step.
        h_prev, c_prev: Previous recurrent state.

    Returns:
        tuple: ``(z_where, z_what, z_pres, h, c, kld_loss)`` where
        ``kld_loss`` is the sum of ``self.latent_loss`` for the ``z_where``
        and ``z_what`` parameters.
    """
    h, c = compute_hidden_state(
        self.rnn, x, z_where_prev, z_what_prev, z_pres_prev, h_prev, c_prev)
    z_pres_proba, z_where_mu, z_where_sd = self.predict(h)

    kld_loss = 0
    kld_loss += self.latent_loss(z_where_mu, z_where_sd)

    # Presence probability is gated by the previous step's presence.
    presence = Independent(Bernoulli(z_pres_proba * z_pres_prev), 1)
    z_pres = presence.sample()
    z_where = self._reparameterized_sample(z_where_mu, z_where_sd)

    attended = attentive_stn_encode(z_where, x)
    z_what_mu, z_what_sd = self.obj_encode(attended)
    kld_loss += self.latent_loss(z_what_mu, z_what_sd)
    z_what = self._reparameterized_sample(z_what_mu, z_what_sd)

    return z_where, z_what, z_pres, h, c, kld_loss
def forward(self, *inputs):
    """Build the action distribution for ``inputs``.

    Args:
        *inputs: Input to the module, forwarded to
            ``self._get_mean_and_log_std``.

    Returns:
        torch.distributions.Independent: ``self._norm_dist_class(mean, std)``
        with its last batch dimension reinterpreted as the event dimension.
    """
    mean, log_std_uncentered = self._get_mean_and_log_std(*inputs)

    min_log_std = self._min_std_param
    max_log_std = self._max_std_param
    # Compare against None explicitly: a bound whose log value is exactly 0
    # (i.e. a std bound of 1.0) is falsy and would otherwise be ignored.
    if min_log_std is not None or max_log_std is not None:
        log_std_uncentered = log_std_uncentered.clamp(
            min=(None if min_log_std is None else min_log_std.item()),
            max=(None if max_log_std is None else max_log_std.item()))

    if self._std_parameterization == 'exp':
        std = log_std_uncentered.exp()
    else:
        # softplus applied to exp(log_std): log(1 + exp(exp(log_std)))
        std = log_std_uncentered.exp().exp().add(1.).log()

    dist = self._norm_dist_class(mean, std)
    # Makes it so that a sample from the distribution is treated as a
    # single sample and not dist.batch_shape samples.
    dist = Independent(dist, 1)

    return dist
def __init__(self, obs_shape, action_dim, hidden_size, inner_hidden,
             discrete_actions=False, base=None, base_kwargs=None):
    """Build the policy and value networks.

    Args:
        obs_shape: Observation shape passed to the networks.
        action_dim: Action dimensionality.
        hidden_size: Hidden size passed to the networks.
        inner_hidden: Extra size argument passed to ``HyperOption`` only.
        discrete_actions: When truthy, use an ``AdvantageNetwork`` with a
            ``Categorical`` distribution; otherwise a ``HyperOption`` with a
            diagonal Normal.
        base: Accepted but not used by this constructor.
        base_kwargs: Extra keyword arguments for ``HyperOption``; ``None``
            means no extra arguments.
    """
    super().__init__()
    extra_kwargs = {} if base_kwargs is None else base_kwargs

    self.hyper = not discrete_actions
    if discrete_actions:
        self.dist_type = Categorical
        self.base = AdvantageNetwork(obs_shape, hidden_size, action_dim)
    else:
        self.base = HyperOption(obs_shape, hidden_size, action_dim,
                                inner_hidden, **extra_kwargs)
        self.dist_type = lambda mean, sigma: Independent(
            Normal(mean, sigma), 1)

    # The value network is the same in both modes and is created after the
    # policy network, as before.
    self.value_base = ValueNetwork(obs_shape, hidden_size)
def MultivariateNormalDiag(loc, scale_diag):
    """Build a multivariate Gaussian with diagonal covariance.

    The last dimension of ``loc`` / ``scale_diag`` is the event dimension.

    Raises:
        ValueError: If ``loc`` is zero-dimensional.
    """
    if loc.dim() == 0:
        raise ValueError("loc must be at least one-dimensional.")
    per_dimension = Normal(loc, scale_diag)
    return Independent(per_dimension, 1)
def get_dist(self, mean, std, dims=1, normal=True):
    """Build an ``Independent`` Normal or Bernoulli distribution.

    Args:
        mean: Normal location, or Bernoulli logits when ``normal`` is falsy.
        std: Normal scale; ignored when ``normal`` is falsy.
        dims: Number of trailing batch dims to treat as event dims.
        normal: Selects Normal (truthy) or Bernoulli (falsy).

    Returns:
        Independent: The wrapped distribution.
    """
    base = Normal(mean, std) if normal else Bernoulli(logits=mean)
    return Independent(base, dims)
def choose_action(self, state):
    """Sample an action for a single observation.

    Args:
        state (numpy.ndarray): A single (unbatched) observation.

    Returns:
        numpy.ndarray: The sampled action with singleton dimensions removed.
    """
    obs = torch.from_numpy(state).float().unsqueeze(0)
    mean, logstd = self.policy(obs)
    policy_dist = Independent(Normal(mean, logstd.exp()), 1)
    sampled = policy_dist.sample()
    return sampled.squeeze().numpy()
def StandardNormal(d, device=torch.device('cuda:0')):
    """Return a ``d``-dimensional standard Normal on ``device``.

    Args:
        d: Size passed to ``torch.zeros`` / ``torch.ones``.
        device: Device holding the parameters; defaults to ``cuda:0``.

    Returns:
        Independent: Zero-mean, unit-scale Normal with one event dimension.
    """
    loc = torch.zeros(d, device=device)
    scale = torch.ones(d, device=device)
    return Independent(Normal(loc, scale), 1)
class TanhNormal(torch.distributions.Distribution):
    r"""A distribution induced by applying a tanh transformation to a Gaussian random variable.

    Algorithms like SAC and Pearl use this transformed distribution.
    It can be thought of as a distribution of X where
        :math:`Y ~ \mathcal{N}(\mu, \sigma)`
        :math:`X = tanh(Y)`

    Args:
        loc (torch.Tensor): The mean of this distribution.
        scale (torch.Tensor): The stdev of this distribution.

    """  # noqa: 501

    def __init__(self, loc, scale):
        self._normal = Independent(Normal(loc, scale), 1)
        # This class declares no ``arg_constraints``, so base-class argument
        # validation has nothing to check; turning it off avoids a warning on
        # every construction. ``batch_shape``/``event_shape`` are deliberately
        # left at the base-class defaults, as before.
        super().__init__(validate_args=False)

    def log_prob(self, value, pre_tanh_value=None, epsilon=1e-6):
        """The log likelihood of a sample on the this Tanh Distribution.

        Args:
            value (torch.Tensor): The sample whose loglikelihood is being
                computed.
            pre_tanh_value (torch.Tensor): The value prior to having the tanh
                function applied to it but after it has been sampled from the
                normal distribution.
            epsilon (float): Regularization constant. Making this value larger
                makes the computation more stable but less precise.

        Note:
              when pre_tanh_value is None, an estimate is made of what the
              value is. This leads to a worse estimation of the log_prob.
              If the value being used is collected from `rsample`, one can
              instead use `rsample_with_pre_tanh_value`, which also returns
              the pre-tanh value.

        Returns:
            torch.Tensor: The log likelihood of value on the distribution.

        """
        # pylint: disable=arguments-differ
        if pre_tanh_value is None:
            # Regularised inverse tanh: 0.5 * log((1 + v) / (1 - v)).
            pre_tanh_value = torch.log(
                (1 + epsilon + value) / (1 + epsilon - value)) / 2
        norm_lp = self._normal.log_prob(pre_tanh_value)
        # Change of variables: subtract the log-derivative of tanh,
        # log(1 - tanh(y)^2), summed over the event dimension.
        ret = (norm_lp - torch.sum(
            torch.log(self._clip_but_pass_gradient((1. - value**2)) + epsilon),
            axis=-1))
        return ret

    def sample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal Distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients `do not` pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal Distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        z = self._normal.rsample(sample_shape)
        return torch.tanh(z)

    def rsample_with_pre_tanh_value(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal distribution.

        Returns the sampled value before the tanh transform is applied and the
        sampled value with the tanh transform applied to it.

        Args:
            sample_shape (list): shape of the return.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Samples from the underlying normal distribution,
                prior to being transformed with `tanh`.
            torch.Tensor: Samples from this distribution (the first return
                value with `tanh` applied).

        """
        z = self._normal.rsample(sample_shape)
        return z, torch.tanh(z)

    def cdf(self, value):
        """Returns the CDF at the value.

        Returns the cumulative density/mass function evaluated at
        `value` on the underlying normal distribution.

        Args:
            value (torch.Tensor): The element where the cdf is being evaluated
                at.

        Returns:
            torch.Tensor: the result of the cdf being computed.

        """
        return self._normal.cdf(value)

    def icdf(self, value):
        """Returns the icdf function evaluated at `value`.

        Returns the icdf function evaluated at `value` on the underlying
        normal distribution.

        Args:
            value (torch.Tensor): The element where the cdf is being evaluated
                at.

        Returns:
            torch.Tensor: the result of the cdf being computed.

        """
        return self._normal.icdf(value)

    @classmethod
    def _from_distribution(cls, new_normal):
        """Construct a new TanhNormal distribution from a normal distribution.

        Args:
            new_normal (Independent(Normal)): underlying normal dist for
                the new TanhNormal distribution.

        Returns:
            TanhNormal: A new distribution whose underlying normal dist
                is new_normal.

        """
        # pylint: disable=protected-access
        # Placeholder parameters only; the scale must be strictly positive or
        # Normal's argument validation rejects it.
        new = cls(torch.zeros(1), torch.ones(1))
        new._normal = new_normal
        return new

    def expand(self, batch_shape, _instance=None):
        """Returns a new TanhNormal distribution.

        (or populates an existing instance provided by a derived class) with
        batch dimensions expanded to `batch_shape`. This method calls
        :class:`~torch.Tensor.expand` on the distribution's parameters. As
        such, this does not allocate new memory for the expanded distribution
        instance. Additionally, this does not repeat any args checking or
        parameter broadcasting in `__init__.py`, when an instance is first
        created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance(instance): new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            Instance: New distribution instance with batch dimensions expanded
            to `batch_size`.

        """
        new_normal = self._normal.expand(batch_shape, _instance)
        new = self._from_distribution(new_normal)
        return new

    def enumerate_support(self, expand=True):
        """Returns tensor containing all values supported by a discrete dist.

        The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Note:
            Calls the enumerate_support function of the underlying normal
            distribution.

        Returns:
            torch.Tensor: Tensor iterating over dimension 0.

        """
        return self._normal.enumerate_support(expand)

    @property
    def mean(self):
        """torch.Tensor: tanh of the underlying normal distribution's mean."""
        return torch.tanh(self._normal.mean)

    @property
    def variance(self):
        """torch.Tensor: variance of the underlying normal distribution."""
        return self._normal.variance

    def entropy(self):
        """Returns entropy of the underlying normal distribution.

        Returns:
            torch.Tensor: entropy of the underlying normal distribution.

        """
        return self._normal.entropy()

    @staticmethod
    def _clip_but_pass_gradient(x, lower=0., upper=1.):
        """Clipping function that allows for gradients to flow through.

        Args:
            x (torch.Tensor): value to be clipped
            lower (float): lower bound of clipping
            upper (float): upper bound of clipping

        Returns:
            torch.Tensor: x clipped between lower and upper.

        """
        clip_up = (x > upper).float()
        clip_low = (x < lower).float()
        # The correction is computed without gradient tracking, so the
        # gradient of the result with respect to x stays the identity.
        with torch.no_grad():
            clip = ((upper - x) * clip_up + (lower - x) * clip_low)
        return x + clip

    def __repr__(self):
        """Returns the parameterization of the distribution.

        Returns:
            str: The name of this distribution's class.

        """
        return self.__class__.__name__
def __init__(self, loc, scale):
    """Store the underlying diagonal Normal, then run the parent initialiser.

    Args:
        loc: Location of the underlying Normal.
        scale: Scale of the underlying Normal.
    """
    base_normal = Normal(loc, scale)
    # Treat the last dimension as the event dimension.
    self._normal = Independent(base_normal, 1)
    super().__init__()
[ torch.tensor(sample[3]).float().unsqueeze(0) for sample in rollout_batch ] ).cuda() old_values = torch.cat( [torch.tensor(sample[4]).unsqueeze(0) for sample in rollout_batch] ).cuda() old_log_probs = torch.cat( [torch.tensor(sample[5]).unsqueeze(0) for sample in rollout_batch] ).cuda() means, logstd = policy(states) dist = Independent( Normal( means, torch.exp(logstd.unsqueeze(0).expand(batch_size, -1)) ), 1, ) log_probs = dist.log_prob(actions.squeeze()) values = value_fn(states) clipped_values = old_values + torch.clamp( values - old_values, -args.clip_range, args.clip_range ) l_vf1 = (values - targets).pow(2) l_vf2 = (clipped_values - targets).pow(2) value_fn_loss = 0.5 * torch.max(l_vf1, l_vf2).mean() value_fn_loss.backward() clip_grad_norm_(value_fn.parameters(), args.max_grad_norm) value_fn_opt.step() value_fn_opt.zero_grad()
def get_act(mean, std, amount=None):
    """Sample from a diagonal Gaussian.

    Args:
        mean: Location of the Normal.
        std: Scale of the Normal.
        amount: Optional sample shape; when ``None`` a single draw is taken.

    Returns:
        The sampled tensor.
    """
    policy = Independent(Normal(mean, std), reinterpreted_batch_ndims=1)
    return policy.sample() if amount is None else policy.sample(amount)
def MultivariateNormalDiag(loc, scale_diag):
    """Return a Normal over ``loc``/``scale_diag`` whose last dimension is
    the event dimension (a diagonal-covariance multivariate Gaussian).

    Raises:
        ValueError: If ``loc`` has no dimensions.
    """
    if not loc.dim() >= 1:
        raise ValueError("loc must be at least one-dimensional.")
    return Independent(Normal(loc, scale_diag), reinterpreted_batch_ndims=1)
observation_loss = F.mse_loss( bottle(observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum(dim=2 if args.symbolic else (2, 3, 4)).mean( dim=(0, 1)) reward_loss = F.mse_loss(bottle(reward_model, (beliefs, posterior_states)), rewards[1:], reduction='none').mean(dim=(0, 1)) # transition loss kl_loss = torch.max( kl_divergence( Independent(Normal(posterior_means, posterior_std_devs), 1), Independent(Normal(prior_means, prior_std_devs), 1)), free_nats).mean(dim=(0, 1)) # print("check the reward", bottle(pcont_model, (beliefs, posterior_states)).shape, nonterminals[:-1].shape) if args.pcont: pcont_pred = torch.distributions.Bernoulli( logits=bottle(pcont_model, (beliefs, posterior_states))) # print("check pcont", pcont_pred) # print("nonterminal", nonterminals[1:]) pcont_loss = -pcont_pred.log_prob( nonterminals[1:]).mean(dim=(0, 1)) print(pcont_loss) # Update model parameters world_optimizer.zero_grad() (observation_loss + reward_loss + kl_loss +