def sample(self, obs_with_hidden):
    mean, log_std, hidden = self.actor.forward(obs_with_hidden)
    log_std = torch.tanh(log_std)
    log_std = self.hyperps['log_std_min'] + 0.5 * (self.hyperps['log_std_max'] - self.hyperps['log_std_min']) * (log_std + 1)
    std = log_std.exp()
    normal = Normal(mean, std)
    x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0, 1))
    y_t = torch.tanh(x_t)
    action = y_t * self.hyperps['action_scale'] + self.hyperps['action_bias']
    log_prob = normal.log_prob(x_t)
    # Enforcing Action Bound
    log_prob = log_prob - torch.log(self.hyperps['action_scale'] * (1 - y_t.pow(2)) + self.hyperps['epsilon'])
    log_prob = log_prob.sum(1, keepdim=True)
    mean = torch.tanh(mean) * self.hyperps['action_scale'] + self.hyperps['action_bias']
    return action, log_prob, mean, hidden
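# A minimal sketch (not part of the original snippet) cross-checking the manual
# tanh/affine log-prob correction above against torch.distributions'
# TransformedDistribution. Assumes scalar action_scale/action_bias and a
# PyTorch version with TanhTransform (>= 1.5).
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform, AffineTransform

mean, std = torch.zeros(1, 2), torch.ones(1, 2)
scale, bias = 2.0, 0.5
base = Normal(mean, std)
squashed = TransformedDistribution(
    base, [TanhTransform(), AffineTransform(loc=bias, scale=scale)])

x_t = base.rsample()
action = torch.tanh(x_t) * scale + bias
manual = base.log_prob(x_t) - torch.log(scale * (1 - torch.tanh(x_t).pow(2)) + 1e-6)
print(torch.allclose(manual, squashed.log_prob(action), atol=1e-4))  # ~True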
def forward(self, batch_size, temp=1.0):
    first_dist = Normal(self.prior.first_mean, self.prior.first_logvar.exp())
    results = []
    sample = temp * first_dist.sample(sample_shape=batch_size)[:, 0, :]
    out = self.bottom_up.first(sample)
    out = out.view(out.size(0), out.size(1), 1, 1)
    out = func.interpolate(out, scale_factor=self.scale)
    for idx, (block, mean, logvar, mf, lf, mod, zero) in enumerate(
            zip(self.bottom_up.blocks, self.prior.mean, self.prior.logvar,
                self.prior.mean_factor, self.prior.logvar_factor,
                self.bottom_up.modifiers, self.bottom_up.zeros)):
        # NOTE: the zipped prior factors are immediately overridden with 1 here.
        mf = 1
        lf = 1
        results.append(out)
        pos = self.prior.position_embedding(out)
        dpos = torch.cat((out, pos), dim=1)
        dist = Normal(mean(dpos) * mf, (logvar(dpos) * lf).exp())
        sample = temp * dist.rsample()
        out = block(out + 0.1 * mod(sample))
        if (idx + 1) % self.bottom_up.level_repeat == 0 and idx < len(self.bottom_up.blocks) - 1:
            out = func.interpolate(out, scale_factor=2)
    res = DiscretizedMixtureLogits(10, self.decoder.block(results[-1])).sample()
    res = ((res + 1) / 2).clamp(0, 1)
    return res
def sample(self, obs, msg):
    mean, log_std, hidden, msg = self.actor.forward(obs, msg)
    std = log_std.exp()
    normal = Normal(mean, std)
    x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0, 1))
    y_t = torch.tanh(x_t)
    action = y_t * self.hyperps['action_scale']  # + self.hyperps['action_bias']
    action[:, 0] += self.hyperps['action_bias']
    log_prob = normal.log_prob(x_t)
    # Enforcing Action Bound
    log_prob -= torch.log(self.hyperps['action_scale'] * (1 - y_t.pow(2)) + self.hyperps['epsilon'])
    log_prob = log_prob.sum(1, keepdim=True)
    mean = torch.tanh(mean) * self.hyperps['action_scale'] + self.hyperps['action_bias']
    entropy = normal.entropy()
    entropy1, entropy2 = entropy[0][0].item(), entropy[0][1].item()
    # print('Std: {:2.3f}, {:2.3f}, log_std: {:2.3f}, {:2.3f}, entropy: {:2.3f}, {:2.3f}'.format(
    #     std[0][0].item(), std[0][1].item(), log_std[0][0].item(), log_std[0][1].item(), entropy1, entropy2))
    return action, log_prob, mean, std, hidden, msg
def get_action(self, s):
    s = torch.tensor(data=s, dtype=torch.float)
    mean, std = self.actor(s)
    normal = Normal(mean, std)
    z = normal.rsample()
    a = torch.tanh(z)
    return a.detach().numpy().tolist()
def get_action(self, state, epsilon=1e-6, reparam=True):
    mean, log_std = self.forward(state)
    std = log_std.exp()
    normal = Normal(mean, std)
    if reparam:  # fixed: `if reparam=True:` is a syntax error
        z = normal.rsample()  # reparameterization trick
    else:
        z = normal.sample()
    # The original snippet is truncated here; the tanh squash and bound
    # correction below follow the pattern of the other SAC snippets (and the
    # otherwise-unused `epsilon` argument).
    action = torch.tanh(z)
    log_prob = (normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)).sum(-1, keepdim=True)
    return action, log_prob
def predict(self, x) -> dict:
    """
    :param x: tensor of shape [batch_size, num_features]
    :return: A dictionary containing the prediction, i.e.
        - latent_dist = torch.distributions.Normal instance of latent space
        - latent_mu = torch.Tensor mu (mean) parameter of latent Normal distribution
        - latent_sigma = torch.Tensor sigma (std) parameter of latent Normal distribution
        - recon_mu = torch.Tensor mu (mean) parameter of reconstructed Normal distribution
        - recon_sigma = torch.Tensor sigma (std) parameter of reconstructed Normal distribution
        - z = torch.Tensor sampled latent space from latent distribution
    """
    batch_size = len(x)
    latent_mu, latent_sigma = self.encoder(x).chunk(2, dim=1)  # both of size [batch_size, latent_size]
    latent_sigma = softplus(latent_sigma)
    dist = Normal(latent_mu, latent_sigma)
    z = dist.rsample([self.L])  # shape: [L, batch_size, latent_size]
    z = z.view(self.L * batch_size, self.latent_size)
    recon_mu, recon_sigma = self.decoder(z).chunk(2, dim=1)
    recon_sigma = softplus(recon_sigma)
    recon_mu = recon_mu.view(self.L, *x.shape)
    recon_sigma = recon_sigma.view(self.L, *x.shape)
    return dict(latent_dist=dist, latent_mu=latent_mu, latent_sigma=latent_sigma,
                recon_mu=recon_mu, recon_sigma=recon_sigma, z=z)
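# Hedged usage sketch (not from the original source): given the dict returned by
# `predict`, score each input by its reconstruction log-likelihood, averaged over
# the L latent samples. `model` and `x` are assumed names; the field names follow
# the snippet above.
pred = model.predict(x)                          # x: [batch_size, num_features]
recon_dist = Normal(pred['recon_mu'], pred['recon_sigma'])
log_lik = recon_dist.log_prob(x.unsqueeze(0))    # [L, batch_size, num_features]
per_sample = log_lik.sum(-1).mean(0)             # [batch_size], averaged over L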
def produce_action_and_action_info(self, state, return_stats: bool = False):
    """Given the state, produces an action, the log probability of the action,
    and the tanh of the mean action."""
    if return_stats:
        actor_output, actor_stats = self.actor_local(state, return_stats)
    else:
        actor_output = self.actor_local(state)
    mean = actor_output[:, :self.action_size]
    log_std = actor_output[:, self.action_size:]
    std = log_std.exp()
    normal = Normal(mean, std)
    x_t = normal.rsample()  # rsample means it is sampled using the reparameterisation trick
    action = torch.tanh(x_t)
    log_prob = normal.log_prob(x_t)
    log_prob -= torch.log(1 - action.pow(2) + EPSILON)
    log_prob = log_prob.sum(1, keepdim=True)
    if return_stats:
        actor_stats['action'] = {
            'mean': mean.detach().cpu().numpy(),
            'std': std.detach().cpu().numpy(),
        }
        return action, log_prob, torch.tanh(mean), actor_stats
    return action, log_prob, torch.tanh(mean)
def forward(
    self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, ...]]:
    n = int(torch.randint(1, 5, (1,)).item())
    # perm = torch.argsort(torch.rand(x.size(0), x.size(0), device=x.device), dim=1)
    # subsets = perm[:, :n]
    d = ((x.unsqueeze(0) - x.unsqueeze(1)) ** 2).sum(dim=2)
    subsets = torch.argsort(d, dim=1)[:, :n]
    # create the subset index which will be used to get the subset distances from
    # the pairwise distance matrix. linspace is needed because it needs to be a
    # tuple of (x, y) coordinates
    # fmt: off
    sub_idx = (
        torch.linspace(0, x.size(0) - 1, x.size(0), device=x.device).repeat(n, 1).T.flatten().long(),  # type: ignore
        subsets.flatten().long())
    # fmt: on
    # print(f"subsets: {subsets.size()} sub idx: {sub_idx.size()}")
    z = x[subsets]
    mu = z.mean(dim=1)
    mx, _ = z.max(dim=1)
    z = torch.cat((mu, mx), dim=1)
    z = self.decoder(z)
    dist = Normal(z[:, :self.in_dim], torch.exp(z[:, self.in_dim:] / 2))
    x = x + dist.rsample()
    return x, dist.entropy().mean(), sub_idx
def normal_tanh_reparameterised_sample(dis: Normal, epsilon=1e-6) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    The log-likelihood here is for the TanhNorm distribution instead of only the
    Gaussian distribution. The TanhNorm forces the Gaussian with infinite action
    range to be finite.

    For the two terms in this log-likelihood estimation:
    (1) the first term is the log probability of the action as in a common
        stochastic Gaussian action policy (without tanh);
    (2) the second term is caused by the tanh(), as shown in appendix C,
        "Enforcing Action Bounds", of https://arxiv.org/pdf/1801.01290.pdf;
        the epsilon prevents a non-positive argument to the log.
    @param dis: the Gaussian distribution to sample from
    @param epsilon: small constant for numerical stability
    @return: squashed action and its log-likelihood
    """
    z = dis.rsample()  # for reparameterisation trick (mean + std * N(0, 1))
    action = torch.tanh(z)
    log_prob = torch.sum(dis.log_prob(z) - torch.log(1 - action.pow(2) + epsilon),
                         dim=-1, keepdim=True)
    return action, log_prob
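# Hedged usage sketch (assumed shapes, not from the original source), relying on
# the imports of the snippet above: sample a bounded action from a policy head
# that outputs mean and log_std for a batch of 4 states and 2 action dims.
mean, log_std = torch.zeros(4, 2), torch.zeros(4, 2)
action, log_prob = normal_tanh_reparameterised_sample(Normal(mean, log_std.exp()))
assert action.shape == (4, 2) and log_prob.shape == (4, 1)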
class TanhNormal(Distribution):
    """Distribution of X ~ tanh(Z) where Z ~ N(mean, std).

    Adapted from https://github.com/vitchyr/rlkit
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon
        super().__init__(self.normal.batch_shape, self.normal.event_shape)

    def log_prob(self, x):
        assert hasattr(x, "pre_tanh_value")
        assert x.dim() == 2 and x.pre_tanh_value.dim() == 2
        return self.normal.log_prob(x.pre_tanh_value) - torch.log(
            1 - x * x + self.epsilon)

    def sample(self, sample_shape=torch.Size()):
        z = self.normal.sample(sample_shape)
        out = torch.tanh(z)
        out.pre_tanh_value = z
        return out

    def rsample(self, sample_shape=torch.Size()):
        z = self.normal.rsample(sample_shape)
        out = torch.tanh(z)
        out.pre_tanh_value = z
        return out
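# Hedged usage sketch (not from rlkit): TanhNormal stashes the pre-tanh value on
# the returned sample so log_prob can avoid an unstable atanh round-trip.
dist = TanhNormal(torch.zeros(4, 2), torch.ones(4, 2))
a = dist.rsample()              # a.pre_tanh_value holds z with tanh(z) == a
lp = dist.log_prob(a).sum(-1)   # per-sample log-likelihood, shape (4,)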
def forward(self, x):
    # Reshape data for net
    if len(x.shape) == 4:
        batch, chan, h, w = x.shape
        x = x.view(batch, chan, h * w).squeeze(1)
    x = F.relu(self.linear1(x))
    x = F.relu(self.linear2(x))
    mean = self.mean_linear(x)
    log_std = self.log_std_linear(x)
    log_std = torch.clamp(log_std, min=self.LOG_SIG_MIN, max=self.LOG_SIG_MAX)
    std = log_std.exp()
    normal = Normal(mean, std)
    delta = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
    log_prob = normal.log_prob(delta)
    if self.decode:
        delta = self.linear(delta)
        # Problem here is log_prob can't be for delta unless same size...
    # Enforcing Action Bound
    # log_prob -= torch.log(1 - action.pow(2) + self.epsilon)
    # log_prob = log_prob.sum(-1, keepdim=True)
    # Shape noise to match original data
    # (NOTE: batch/chan/h/w are only defined when the 4-D branch above was taken)
    delta = delta.unsqueeze(1)
    delta = delta.view(batch, chan, h, w)
    return delta, mean, log_std, log_prob
def forward(self, x, reparam=True):
    x = F.relu(self.l1(x))
    x = F.relu(self.l2(x))
    action = self.max_action * torch.tanh(self.l3(x))
    mean = self.mean_linear(x)
    log_std = self.log_std_linear(x)
    log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
    std = log_std.exp()
    normal = Normal(mean, std)
    if reparam:
        x_t = normal.rsample()
    else:
        x_t = normal.sample()
    y_t = torch.tanh(x_t)
    log_prob = normal.log_prob(x_t)
    # Enforcing Action Bound: corrected to use the squashed sample y_t = tanh(x_t);
    # the original applied the correction with the separate deterministic head above.
    log_prob -= torch.log(1 - y_t.pow(2) + epsilon)
    log_prob = log_prob.sum(-1, keepdim=True)
    entropy = normal.entropy()
    dist_entropy = entropy.sum(-1).mean()
    return action, dist_entropy, mean, log_std, log_prob
def _sample_z(self, mean, var, name, cell_coord):
    '''
    Performs the sampling step in VAE and stores the distribution for KL computation.

    :param mean: mean of the distribution
    :param var: scale passed to Normal; note that Normal's `scale` is a standard
        deviation, so this should already be a std rather than a variance
    :param name: name of the distribution
    :param cell_coord: (x, y) cell coordinate at which to store the parameters
    :return: sampled value
    '''
    dist = Normal(loc=mean, scale=var)
    if name not in self.dist_param.keys():
        _, H, W = self.feature_space_dim
        self.dist_param[name] = {}
        self.dist_param[name]['mean'] = torch.empty(
            self.batch_size, mean.shape[-1], H, W).to(self.device)
        self.dist_param[name]['sigma'] = torch.empty(
            self.batch_size, mean.shape[-1], H, W).to(self.device)
    x, y = cell_coord
    self.dist_param[name]['mean'][:, :, x, y] = mean
    self.dist_param[name]['sigma'][:, :, x, y] = var
    return dist.rsample()
def forward(self, x):
    # Sample the weights and forward it.
    # Perform all operations in the forward rather than in __init__
    # (including log[1 + exp(rho)]).
    variational_posterior = Normal(self.mu, torch.log1p(torch.exp(self.rho)))
    variational_posterior_bias = Normal(self.mu_bias, torch.log1p(torch.exp(self.rho_bias)))
    w = variational_posterior.rsample()
    b = variational_posterior_bias.rsample()
    # Get the log prob
    self.log_variational_posterior = (variational_posterior.log_prob(w).sum()
                                      + variational_posterior_bias.log_prob(b).sum())
    self.log_prior = self.prior_weights.log_prob(w).sum() + self.prior_bias.log_prob(b).sum()
    return F.linear(x, w, b)
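# Hedged sketch (assumed names, not from the original source): in Bayes-by-backprop,
# the per-layer terms stored above are summed across layers to form the KL part of
# the loss. `model.bayesian_layers`, `num_batches`, `x`, and `y` are hypothetical.
kl = sum(layer.log_variational_posterior - layer.log_prior
         for layer in model.bayesian_layers)
loss = kl / num_batches + F.cross_entropy(model(x), y, reduction='sum')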
def sample(self, state):
    """
    Samples actions and their log probabilities from the distribution.

    Arguments:
        state : State vector containing state variables
    """
    epsilon = 1e-6
    mu, log_std = self.forward(state)
    std = log_std.exp()
    # Gaussian distribution with mu and std from the network
    normal = Normal(mu, std)
    # Action is sampled from the distribution
    z = normal.rsample()
    # tanh squeezes the action between (-1, 1)
    action = torch.tanh(z)
    # Log probability for the action is calculated using the log-likelihood
    # formula; refer to equation 21 of https://arxiv.org/pdf/1801.01290.pdf
    log_prob = normal.log_prob(z) - torch.log(1 - torch.tanh(z).pow(2) + epsilon)
    log_prob = log_prob.sum(1, keepdim=True)
    return action, log_prob
def forward(self, x):
    logN = torch.log(x.sum(axis=-1)).view(-1, 1)
    varz = torch.stack([self.variational_logvars] * len(logN))
    varz = torch.cat((varz, logN), dim=1)
    z_var = self.sigma_net(varz)
    z_mean = self.encode(x)
    qz = Normal(z_mean, torch.exp(0.5 * z_var))
    ql = Normal(0, torch.exp(0.5 * self.log_sigma_sq))
    z_sample = qz.rsample()
    l_sample = ql.rsample()
    x_out = self.decoder(z_sample) + l_sample
    kl_div = kl_divergence(qz, Normal(0, 1)).mean(0).sum()
    recon_loss = self.recon_model_loglik(x, x_out).mean(0).sum()
    elbo = recon_loss - kl_div
    loss = -elbo
    return loss
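# Hedged check (not from the original source): for the diagonal Gaussian posterior
# above, kl_divergence(qz, Normal(0, 1)) equals the closed form, elementwise:
#   KL(N(mu, s^2) || N(0, 1)) = 0.5 * (s^2 + mu^2 - 1 - log s^2)
from torch.distributions import kl_divergence

mu, s = torch.randn(3, 5), torch.rand(3, 5) + 0.1
auto = kl_divergence(Normal(mu, s), Normal(0, 1))
manual = 0.5 * (s.pow(2) + mu.pow(2) - 1 - torch.log(s.pow(2)))
assert torch.allclose(auto, manual, atol=1e-5)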
def get_z_sup_sample(self, zp_mean, zp_std):
    """Get z sample and log_lik of sample from state code.

    Args:
        zp_mean, zp_std (torch.Tensor), 2 * (nTo, 4): State dist. parameters.

    Returns:
        z_obj (torch.Tensor), (nTo, 4): Sampled states.
        log_q_xz (torch.Tensor), (nTo): Likelihood of samples for ELBO.
    """
    # Get z from sampling; each Gaussian has dim (4).
    # We need nTo samples per Gaussian, so the sample dim is again (nTo, 4).
    z_dist = Normal(zp_mean, zp_std)
    # rsample can propagate gradients, no explicit reparametrization needed
    z_tmp = z_dist.rsample().to(self.c.device)
    # Get log-lik of sample: approximate E_q(z|x)[log q(z|x)] with a single sample.
    # Sum (in log-domain) the probabilities for the z's [per image]: (nTo, 4) -> (nTo)
    log_q_xz = z_dist.log_prob(z_tmp).sum(-1)
    # Get sy from the sy/sx quotient.
    z_obj = self.sy_from_quotient(z_tmp)
    return z_obj, log_q_xz
def sample_conditional_a(self, resid_image, var_so_far, pixel_1d):
    is_on = (pixel_1d < (self.n_discrete_latent - 1)).float()

    # pass through galaxy encoder
    pixel_2d = self.one_galaxy_vae.pixel_1d_to_2d(pixel_1d)
    z_mean, z_var = self.one_galaxy_vae.enc(resid_image, pixel_2d)

    # sample z
    q_z = Normal(z_mean, z_var.sqrt())
    z_sample = q_z.rsample()

    # kl term for continuous latent vars
    log_q_z = q_z.log_prob(z_sample).sum(1)
    p_z = Normal(torch.zeros_like(z_sample), torch.ones_like(z_sample))
    log_p_z = p_z.log_prob(z_sample).sum(1)
    kl_z = is_on * (log_q_z - log_p_z)

    # run through decoder
    recon_mean, recon_var = self.one_galaxy_vae.dec(is_on, pixel_2d, z_sample)

    # NOTE: we will have to add the recon means once we do more detections
    # recon_means = recon_mean + image_so_far
    # recon_vars = recon_var + var_so_far

    return recon_mean, recon_var, is_on, kl_z
def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward method implementation; returns (action, log_prob)."""
    x = self.relu(self.bn1(self.conv1(state)))
    x = self.relu(self.bn2(self.conv2(x)))
    x = self.relu(self.bn3(self.conv3(x)))
    x = self.lrelu(self.fc1(x.view(x.size(0), -1)))
    x = self.lrelu(self.fc2(x))
    x = self.lrelu(self.fc3(x))

    # get mean
    mu = self.mu_layer(x).tanh()

    # get std
    log_std = self.log_std_layer(x).tanh()
    log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1)
    std = torch.exp(log_std)

    # sample actions
    dist = Normal(mu, std)
    z = dist.rsample()

    # normalize action and log_prob; see appendix C of [2]
    action = z.tanh()
    log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-7)
    log_prob = log_prob.sum(-1, keepdim=True)

    return action, log_prob
def sample_action(self, state: np.ndarray, deterministic: bool = False) -> np.ndarray:
    """
    Sample an action from a normal distribution parameterized by the policy network.

    :param state: Observation state
    :param deterministic: Is the greedy action being chosen?
    :type state: int, float, ...
    :type deterministic: bool
    :returns: action, log likelihood of the policy, and the scaled mean of the
        normal distribution
    """
    mean, log_std = self.policy.forward(state)
    std = log_std.exp()

    # reparameterization trick
    distribution = Normal(mean, std)
    xi = distribution.rsample()
    yi = torch.tanh(xi)
    action = yi * self.action_scale + self.action_bias
    log_pi = distribution.log_prob(xi)

    # enforcing action bound (appendix of paper)
    log_pi -= torch.log(self.action_scale * (1 - yi.pow(2)) + np.finfo(np.float32).eps)
    log_pi = log_pi.sum(1, keepdim=True)
    mean = torch.tanh(mean) * self.action_scale + self.action_bias
    return action.float(), log_pi, mean
def reparameters(self, means, stds):
    distribution = Normal(means, stds)
    actions = distribution.rsample()
    new_actions = torch.tanh(actions)
    # NOTE: the log-probs here are taken w.r.t. the detached pre-tanh sample and
    # omit the tanh Jacobian correction used in the other SAC snippets.
    log_probs = distribution.log_prob(actions.detach()).sum(dim=1, keepdim=True)
    return new_actions, log_probs
def forward(self, x, n_samples, squeeze=True, reparam=True):
    q = self.encoder(x)
    q_m = self.mean_encoder(q)
    q_v = self.var_encoder(q)
    # q_v = 16.0 * self.tanh(q_v)
    # q_v = torch.clamp(q_v, min=-17., max=14.)  # PREVIOUS TO KEEP
    # q_m = torch.clamp(q_m, min=-1000, max=1000)
    # q_v = torch.clamp(q_v, min=-17.0, max=8.0)
    q_v = torch.clamp(q_v, min=-17.0, max=10.0)
    q_v = q_v.exp()
    # q_v = 1e-16 + q_v.exp()
    variational_dist = Normal(loc=q_m, scale=q_v.sqrt())
    if n_samples == 1 and squeeze:
        sample_shape = []
    else:
        sample_shape = (n_samples,)
    if reparam:
        latent = variational_dist.rsample(sample_shape=sample_shape)
    else:
        latent = variational_dist.sample(sample_shape=sample_shape)
    return dict(q_m=q_m, q_v=q_v, latent=latent)
def forward(self, input: Tensor) -> Tensor:
    params_ = self.conv(input)
    mu = params_[..., :128]
    std = params_[..., 128:]
    # Normal requires a strictly positive scale, but the raw conv output can be
    # negative; mapping it through softplus is an assumed fix, not in the original.
    std = torch.nn.functional.softplus(std)
    n = Normal(mu, std)
    z = n.rsample()  # latent variable
    return z
def forward(self, state):
    """Given states input [batch, state_dim], returns (action, log_pi)."""
    state = state.to(self.device)
    # x = F.relu(self.linear1(state))
    # x = F.relu(self.linear2(x))
    x = self.linear(state)
    mean = self.mean_linear(x)
    # return torch.tanh(mean), 0
    log_std = self.log_std_linear(x)
    # log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
    std = torch.exp(log_std)
    std = torch.clamp(std, self.std_min, self.std_max)
    normal = Normal(mean, std)
    z = normal.rsample()
    a = torch.tanh(z)
    # compute log probability
    log_pi = normal.log_prob(z) - torch.log(1 - a.pow(2) + 1e-6)
    return a, log_pi
def get_action(self, state: torch.Tensor, deterministic: bool = False):
    state = torch.as_tensor(state).float()
    if self.actor.sac:
        mean, log_std = self.actor(state)
        std = log_std.exp()
        distribution = Normal(mean, std)
        action_probs = distribution.rsample()
        log_probs = distribution.log_prob(action_probs)
        action_probs = torch.tanh(action_probs)
        action = action_probs * self.action_scale + self.action_bias
        # enforcing action bound (appendix of SAC paper)
        log_probs -= torch.log(self.action_scale * (1 - action_probs.pow(2))
                               + np.finfo(np.float32).eps)
        log_probs = log_probs.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        action = (action.float(), log_probs, mean)
    else:
        action = self.actor.get_action(state, deterministic=deterministic)
    return action
def sample(self, state):
    '''
    :param state: (batch_num, state_dim)
    :return: action: (batch_num, action_dim, option_num)
             log_prob: (batch_num, option_num)
             mean_mat: (batch_num, action_dim, option_num)
    '''
    mean_mat, log_std_mat = self.forward(state)
    std_mat = log_std_mat.exp()
    normal = Normal(mean_mat, std_mat)
    x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
    # print('x_t', x_t.shape)
    y_t = torch.tanh(x_t)
    action = y_t * self.action_scale + self.action_bias
    log_prob = normal.log_prob(x_t)  # log(pi(at|st))
    # Enforcing Action Bound, because the support changes from (-inf, inf) to (-1, 1)
    log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
    # print('log_prob', log_prob.shape)
    log_prob = log_prob.sum(1, keepdim=True)
    # print('log_prob_sum', log_prob.shape)
    mean_mat = torch.tanh(mean_mat) * self.action_scale + self.action_bias
    return action, log_prob, mean_mat
def forward(self, x, a=None):
    x = self.net(x)
    mu = self.mu(x)
    log_sigma = self.log_sigma(x)
    """
    Note from Josh Achiam @ OpenAI:

    Because the algorithm maximizes a trade-off of reward and entropy, entropy
    must be unique to state -- and therefore log_stds need to be a neural network
    output instead of a shared-across-states learnable parameter vector. But for
    deep ReLU and other nets, simply sticking an activationless dense layer at
    the end would be quite bad -- at the beginning of training, a randomly
    initialized net could produce extremely large values for the log_stds, which
    would result in some actions being either entirely deterministic or too
    random to come back to earth. Either of these introduces numerical
    instability which could break the algorithm. To protect against that, we'll
    constrain the output range of the log_stds to lie within
    [LOG_STD_MIN, LOG_STD_MAX]. This is slightly different from the trick used
    by the original authors of SAC -- they used tf.clip_by_value instead of
    squashing and rescaling. I prefer this approach because it allows gradient
    propagation through log_std where clipping wouldn't, but I don't know if it
    makes much of a difference.
    """
    # (The rescaling below assumes log_sigma is already in [-1, 1], e.g. via tanh.)
    log_sigma = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_sigma + 1)
    sigma = torch.exp(log_sigma)
    dist = Normal(mu, sigma)
    # rsample() - https://pytorch.org/docs/stable/distributions.html#pathwise-derivative
    pi = dist.rsample()  # reparametrization
    logp_pi = dist.log_prob(pi).sum(dim=1)
    mu *= self.act_limit
    pi *= self.act_limit
    mu, pi, logp_pi = apply_squashing_func(mu, pi, logp_pi)
    return mu, pi, logp_pi
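# Hedged illustration (not from the original source) of the clip-vs-squash point
# in the note above: clamp kills the gradient outside the bounds, while the
# tanh-based rescaling keeps a small but nonzero gradient everywhere.
# LOG_STD_MIN / LOG_STD_MAX are assumed to be the module-level constants used above.
raw = torch.tensor(5.0, requires_grad=True)
torch.clamp(raw, LOG_STD_MIN, LOG_STD_MAX).backward()
print(raw.grad)   # tensor(0.) -- no gradient once outside the bounds

raw2 = torch.tensor(5.0, requires_grad=True)
(LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (torch.tanh(raw2) + 1)).backward()
print(raw2.grad)  # small but nonzero gradient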
def noisy_action(self, state, return_only_action=True):
    if self.policy_type == 'GaussianPolicy':
        mean, log_std = self.clean_action(state, return_only_action=False)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        action = torch.tanh(x_t)
        if return_only_action:
            return action
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bound
        log_prob -= torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)
        # log_prob.clamp(-10, 0)
        return action, log_prob, x_t, mean, log_std

    elif self.policy_type == 'DeterministicPolicy':
        mean = self.clean_action(state)
        action = mean + self.noise.normal_(0., std=0.4)
        if return_only_action:
            return action
        return action, torch.tensor(0.), torch.tensor(0.), mean, torch.tensor(0.)
def forward(self, x):
    x = super(ReparamGaussianPolicy, self).forward(x)
    mu = self.mu_layer(x)
    log_std = torch.tanh(self.log_std_layer(x))
    log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)
    std = torch.exp(log_std)

    # https://pytorch.org/docs/stable/distributions.html#normal
    dist = Normal(mu, std)
    pi = dist.rsample()  # reparameterization trick (mean + std * N(0,1))
    log_pi = dist.log_prob(pi).sum(dim=-1)
    mu, pi, log_pi = self.apply_squashing_func(mu, pi, log_pi)

    if self.log_type == 'log':
        # make sure actions are in the correct range
        mu = mu * self.action_scale
        pi = pi * self.action_scale
        return mu, pi, log_pi
    elif self.log_type == 'log-q':
        if self.q == 1.:
            log_q_pi = log_pi
        else:
            exp_log_pi = torch.exp(log_pi)
            log_q_pi = self.tsallis_entropy_log_q(exp_log_pi, self.q)
        # make sure actions are in the correct range
        mu = mu * self.action_scale
        pi = pi * self.action_scale
        return mu, pi, log_q_pi
def select_action(self, state, deterministic=False):
    """
    Compute an action or vector of actions given a state or vector of states

    :param state: the input state(s)
    :param deterministic: whether the policy should be considered deterministic or not
    :return: the resulting action(s)
    """
    with torch.no_grad():
        # Forward pass
        mu, std = self.forward(state)
        pi_distribution = Normal(mu, std)

        if deterministic:
            # Only used for evaluating the policy at test time.
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        # Finally apply tanh for squashing.
        # If env is Pendulum:
        pi_action = 2 * torch.tanh(pi_action)
        # pi_action = torch.tanh(pi_action)

        if len(pi_action) == 1:
            pi_action = pi_action[0]
        return pi_action.data.numpy().astype(float)
def forward(self, sample_shape=()):
    if isinstance(self.prior, FactorisedPrior):
        sqrt_prec = 1. / math.sqrt(self.in_features)
        post_mean = self.post_mean * sqrt_prec
        post_log_var = self.post_log_var_scaled * self.log_var_lr + 2. * math.log(sqrt_prec)
        prior_prec = self.prior(1)
        KL_term = (0.5 * ((post_mean**2).sum() + post_log_var.exp().sum()) * prior_prec.scale
                   - 0.5 * post_mean.numel()
                   - 0.5 * post_mean.numel() * t.log(prior_prec.scale)
                   - 0.5 * post_log_var.sum())
        self.logpq = -KL_term * t.ones(*sample_shape, device=KL_term.device)
        return post_mean, post_log_var.exp()
    else:
        post_log_var = self.post_log_var_scaled * self.log_var_lr
        sqrt_prec = 1. / math.sqrt(self.in_features)
        post_mean = self.post_mean * sqrt_prec
        Qw = Normal(post_mean, sqrt_prec * t.exp(0.5 * post_log_var))
        w = Qw.rsample(sample_shape=t.Size([sample_shape[0]]))
        prior_prec = self.prior(sample_shape[0])
        logP = mvnormal_log_prob(prior_prec, w.transpose(-1, -2))
        logQ = Qw.log_prob(w).sum((-1, -2))
        self.logpq = logP - logQ
        return post_mean, post_log_var.exp()
def forward(self, x_src):
    # Example variational parameters lambda
    mu, logvar = self.encoder(x_src)
    q_normal = Normal(loc=mu, scale=logvar.mul(0.5).exp())
    # Reparameterized sample.
    z_sample = q_normal.rsample()
    # z_sample = mu  (no sampling)
    return self.decoder(z_sample), q_normal
# TODO: to make this stochastic, shuffle and make smaller batches.
start = time.time()
theta.train()
for epoch in range(args.num_epochs * 2):
    # Keep track of reconstruction loss and total kl
    total_recon_loss = 0
    total_kl = 0
    total = 0
    for img, _ in loader:
        # no need to Variable(img).cuda()
        optim1.zero_grad()
        optim2.zero_grad()
        q = Normal(loc=mu, scale=logvar.mul(0.5).exp())
        # Reparameterized sample.
        qsamp = q.rsample()
        kl = kl_divergence(q, p).sum()  # KL term
        out = theta(qsamp)
        recon_loss = criterion(out, img)  # reconstruction term
        loss = (recon_loss + args.alpha * kl) / args.batch_size
        total_recon_loss += recon_loss.item() / args.batch_size
        total_kl += kl.item() / args.batch_size
        total += 1
        loss.backward()
        if args.clip:
            # clip_grad_norm is deprecated; use the in-place clip_grad_norm_
            # (the original also clipped theta's gradients twice)
            torch.nn.utils.clip_grad_norm_(theta.parameters(), args.clip)
            torch.nn.utils.clip_grad_norm_(mu, args.clip)
        if epoch % 2:
            optim1.step()
            wv = 'Theta'