def beta_dist_loss(self, advantage_mask, phi, action, dist_info,
        old_dist_info, valid, opt_info):
    action = (action + 1) / 2  # Map actions (assumed in [-1, 1]) onto the Beta support [0, 1].
    distribution = torch.distributions.beta.Beta(dist_info.mean, dist_info.log_std)
    old_dist = torch.distributions.beta.Beta(old_dist_info.mean, old_dist_info.log_std)
    pi_loss = -torch.sum(advantage_mask *
        (phi.detach() * distribution.log_prob(action).sum(dim=-1)))
    kl = torch.distributions.kl_divergence(old_dist, distribution).sum(dim=-1)
    alpha_loss = valid_mean(
        self.alpha * (self.epsilon_alpha - kl.detach()) +
        self.alpha.detach() * kl, valid)
    entropy = valid_mean(distribution.entropy().sum(dim=-1), valid)
    alpha_loss -= 0.01 * entropy
    mode = self.agent.beta_dist_mode(old_dist_info.mean, old_dist_info.log_std)
    opt_info.alpha.append(self.alpha.item())
    opt_info.policy_kl.append(kl.mean().item())
    opt_info.pi_mu.append(mode.mean().item())
    opt_info.pi_log_std.append(old_dist.entropy().sum(dim=-1).mean().item())
    return pi_loss, alpha_loss, opt_info
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info):
    """
    Compute the training loss: policy_loss + value_loss + entropy_loss.
    Policy loss: min(likelihood_ratio * advantage,
        clip(likelihood_ratio, 1 - eps, 1 + eps) * advantage)
    Value loss: 0.5 * (estimated_value - return) ^ 2
    Calls the agent to compute forward pass on training data, and uses
    the ``agent.distribution`` to compute likelihoods and entropies.
    Valid for feedforward or recurrent agents.
    """
    dist_info, value = self.agent(*agent_inputs)
    dist = self.agent.distribution
    ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=dist_info)
    surr_1 = ratio * advantage
    clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip, 1. + self.ratio_clip)
    surr_2 = clipped_ratio * advantage
    surrogate = torch.min(surr_1, surr_2)
    pi_loss = - valid_mean(surrogate, valid)

    value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = - self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss  # + self.vae_loss_coeff * vae_loss
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
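
# Reference sketch (assumption): ``valid_mean`` as used throughout these losses is taken
# to behave like rlpyt's helper -- an ordinary mean when ``valid`` is None, otherwise a
# mean over only the entries marked valid.  Shown here only to make the masked averaging
# explicit; the actual helper in this repo may differ in detail.
import torch

def valid_mean(tensor, valid=None):
    """Masked mean: average ``tensor`` over entries where ``valid`` is nonzero (sketch)."""
    if valid is None:
        return tensor.mean()
    return (tensor * valid).sum() / valid.sum()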
def compute_loss(self, observations, next_observations, actions, valid):
    # ------------------------------------------------------------ #
    # Hacky dimension add for when you have only one environment (debugging).
    if actions.dim() == 2:
        actions = actions.unsqueeze(1)
    # ------------------------------------------------------------ #
    phi1, phi2, predicted_phi2, predicted_phi2_stacked, predicted_action = self.forward(
        observations, next_observations, actions)
    # Convert action to (T * B, action_size), then get target indexes.
    actions = torch.max(actions.view(-1, *actions.shape[2:]), 1)[1]
    inverse_loss = nn.functional.cross_entropy(
        predicted_action.view(-1, *predicted_action.shape[2:]),
        actions.detach(), reduction='none').view(phi1.shape[0], phi2.shape[1])
    inverse_loss = valid_mean(inverse_loss, valid)

    forward_loss = torch.tensor(0.0, device=self.device)
    forward_loss_1 = nn.functional.dropout(nn.functional.mse_loss(
        predicted_phi2[0], phi2.detach(), reduction='none'),
        p=0.2).sum(-1) / self.feature_size
    forward_loss += valid_mean(forward_loss_1, valid)
    forward_loss_2 = nn.functional.dropout(nn.functional.mse_loss(
        predicted_phi2[1], phi2.detach(), reduction='none'),
        p=0.2).sum(-1) / self.feature_size
    forward_loss += valid_mean(forward_loss_2, valid)
    forward_loss_3 = nn.functional.dropout(nn.functional.mse_loss(
        predicted_phi2[2], phi2.detach(), reduction='none'),
        p=0.2).sum(-1) / self.feature_size
    forward_loss += valid_mean(forward_loss_3, valid)
    forward_loss_4 = nn.functional.dropout(nn.functional.mse_loss(
        predicted_phi2[3], phi2.detach(), reduction='none'),
        p=0.2).sum(-1) / self.feature_size
    forward_loss += valid_mean(forward_loss_4, valid)

    return self.inverse_loss_wt * inverse_loss, self.forward_loss_wt * forward_loss
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
        init_rnn_state=None):
    if init_rnn_state is not None:
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
    else:
        dist_info, value = self.agent(*agent_inputs)
    dist = self.agent.distribution
    lr = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=dist_info)
    kl = dist.kl(old_dist_info=old_dist_info, new_dist_info=dist_info)
    if init_rnn_state is not None:
        raise NotImplementedError
    else:
        mean_kl = valid_mean(kl)
        surr_loss = -valid_mean(lr * advantage)
    loss = surr_loss
    entropy = dist.mean_entropy(dist_info, valid)
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
def pi_alpha_loss(self, samples, valid, conv_out):
    # PI LOSS.
    # Uses detached conv out; avoid re-computing.
    conv_detach = conv_out.detach()
    agent_inputs = samples.agent_inputs._replace(observation=conv_detach)
    new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
    if not self.reparameterize:
        # new_action = new_action.detach()  # No grad.
        raise NotImplementedError
    # Re-use the detached latent.
    log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
    min_log_target = torch.min(log_target1, log_target2)
    prior_log_pi = self.get_action_prior(new_action.cpu())
    if self.reparameterize:
        pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
    else:
        raise NotImplementedError
    # if self.policy_output_regularization > 0:
    #     pi_losses += self.policy_output_regularization * torch.mean(
    #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
    pi_loss = valid_mean(pi_losses, valid)

    # ALPHA LOSS.
    if self.target_entropy is not None:
        alpha_losses = -self._log_alpha * (log_pi.detach() + self.target_entropy)
        alpha_loss = valid_mean(alpha_losses, valid)
    else:
        alpha_loss = None

    return pi_loss, alpha_loss, pi_mean.detach(), pi_log_std.detach()
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
        init_rnn_state=None):
    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
    else:
        dist_info, value = self.agent(*agent_inputs)
    dist = self.agent.distribution
    ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=dist_info)
    surr_1 = ratio * advantage
    clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip, 1. + self.ratio_clip)
    surr_2 = clipped_ratio * advantage
    surrogate = torch.min(surr_1, surr_2)
    pi_loss = - valid_mean(surrogate, valid)

    value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = - self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
def loss(self, samples):
    agent_inputs = AgentInputs(
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    if self.agent.recurrent:
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T = 0.
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
    else:
        dist_info, value = self.agent(*agent_inputs)
    # TODO: try to compute everything on device.
    return_, advantage, valid = self.process_returns(samples)

    dist = self.agent.distribution
    logli = dist.log_likelihood(samples.agent.action, dist_info)
    pi_loss = -valid_mean(logli * advantage, valid)

    value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
def loss(self, samples):
    """
    Computes losses for twin Q-values against the min of twin target Q-values
    and an entropy term.  Computes reparameterized policy loss, and loss for
    tuning entropy weighting, alpha.

    Input samples have leading batch dimension [B,..] (but not time).
    """
    agent_inputs, target_inputs, action = buffer_to(
        (samples.agent_inputs, samples.target_inputs, samples.action))

    if self.mid_batch_reset and not self.agent.recurrent:
        valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
    else:
        valid = valid_from_done(samples.done)
    if self.bootstrap_timelimit:
        # To avoid non-use of bootstrap when environment is 'done' due to
        # time-limit, turn off training on these samples.
        valid *= (1 - samples.timeout_n.float())

    q1, q2 = self.agent.q(*agent_inputs, action)
    with torch.no_grad():
        target_action, target_log_pi, _ = self.agent.pi(*target_inputs)
        target_q1, target_q2 = self.agent.target_q(*target_inputs, target_action)
    min_target_q = torch.min(target_q1, target_q2)
    target_value = min_target_q - self._alpha * target_log_pi
    disc = self.discount ** self.n_step_return
    y = (self.reward_scale * samples.return_ +
        (1 - samples.done_n.float()) * disc * target_value)

    q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
    q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

    new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
    if not self.reparameterize:
        new_action = new_action.detach()  # No grad.
    log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
    min_log_target = torch.min(log_target1, log_target2)
    prior_log_pi = self.get_action_prior(new_action.cpu())
    if self.reparameterize:
        pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
    else:
        raise NotImplementedError
    # if self.policy_output_regularization > 0:
    #     pi_losses += self.policy_output_regularization * torch.mean(
    #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
    pi_loss = valid_mean(pi_losses, valid)

    if self.target_entropy is not None and self.fixed_alpha is None:
        alpha_losses = - self._log_alpha * (log_pi.detach() + self.target_entropy)
        alpha_loss = valid_mean(alpha_losses, valid)
    else:
        alpha_loss = None

    losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
    values = tuple(val.detach() for val in (q1, q2, pi_mean, pi_log_std))
    return losses, values
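
# Sketch (assumption): the temperature ``alpha`` used above is typically parameterized
# through its log so that gradient steps on the alpha loss keep it positive.  The
# ``log_alpha`` / ``alpha_optimizer`` names below are hypothetical; this only illustrates
# how an ``alpha_loss`` like the one computed here would be applied.
import torch

log_alpha = torch.zeros(1, requires_grad=True)        # hypothetical learnable parameter
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

def update_alpha(log_pi, target_entropy):
    """One temperature update: push log_alpha so the policy entropy tracks target_entropy."""
    alpha_loss = (-log_alpha * (log_pi.detach() + target_entropy)).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp().detach()                    # value to use as the new alpha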
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
        old_value, init_rnn_state=None):
    """
    Compute the training loss: policy_loss + value_loss + entropy_loss.
    Policy loss: min(likelihood_ratio * advantage,
        clip(likelihood_ratio, 1 - eps, 1 + eps) * advantage)
    Value loss: 0.5 * (estimated_value - return) ^ 2
    Calls the agent to compute forward pass on training data, and uses
    the ``agent.distribution`` to compute likelihoods and entropies.
    Valid for feedforward or recurrent agents.
    """
    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state,
            device=action.device)
    else:
        dist_info, value = self.agent(*agent_inputs, device=action.device)
    dist = self.agent.distribution

    # Clipped surrogate policy loss.
    ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=dist_info)
    surr_1 = ratio * advantage
    clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip, 1. + self.ratio_clip)
    surr_2 = clipped_ratio * advantage
    surrogate = torch.min(surr_1, surr_2)
    pi_loss = -valid_mean(surrogate, valid)

    # Value loss, optionally clipped against the old value estimate.
    if self.clip_vf_loss:
        v_loss_unclipped = (value - return_) ** 2
        v_clipped = old_value + torch.clamp(
            value - old_value, -self.ratio_clip, self.ratio_clip)
        v_loss_clipped = (v_clipped - return_) ** 2
        v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
        value_error = 0.5 * v_loss_max.mean()
    else:
        value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, pi_loss, value_loss, entropy, perplexity
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
        init_rnn_state=None):
    """
    Compute the training loss: policy_loss + value_loss + entropy_loss.
    Policy loss: min(likelihood_ratio * advantage,
        clip(likelihood_ratio, 1 - eps, 1 + eps) * advantage)
    Value loss: 0.5 * (estimated_value - return) ^ 2
    Calls the agent to compute forward pass on training data, and uses
    the ``agent.distribution`` to compute likelihoods and entropies.
    Valid for feedforward or recurrent agents.
    """
    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
    else:
        if self.agent.both_actions:
            dist_info, og_dist_info, value, og_value = self.agent(*agent_inputs)
        else:
            dist_info, value = self.agent(*agent_inputs)
    dist = self.agent.distribution
    ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=dist_info)
    surr_1 = ratio * advantage
    clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip, 1. + self.ratio_clip)
    surr_2 = clipped_ratio * advantage
    surrogate = torch.min(surr_1, surr_2)
    pi_loss = -valid_mean(surrogate, valid)

    value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss

    if self.similarity_loss:
        # Try KL next:
        # pi_sim = self.agent.distribution.kl(og_dist_info, dist_info)
        pi_sim = F.cosine_similarity(dist_info.prob, og_dist_info.prob)
        value_sim = (value - og_value) ** 2
        loss += -self.similarity_coeff * pi_sim.mean() + 0.5 * value_sim.mean()
        # loss += self.similarity_coeff * (pi_sim.mean() + value_sim)

    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
def loss(self, samples):
    """Samples have leading batch dimension [B,..] (but not time)."""
    agent_inputs, target_inputs, action = buffer_to(
        (samples.agent_inputs, samples.target_inputs, samples.action))
    q1, q2 = self.agent.q(*agent_inputs, action)
    with torch.no_grad():
        target_v = self.agent.target_v(*target_inputs)
    disc = self.discount ** self.n_step_return
    y = (self.reward_scale * samples.return_ +
        (1 - samples.done_n.float()) * disc * target_v)
    if self.mid_batch_reset and not self.agent.recurrent:
        valid = torch.ones_like(samples.done, dtype=torch.float)
    else:
        valid = valid_from_done(samples.done)
    if self.bootstrap_timelimit:
        # To avoid non-use of bootstrap when environment is 'done' due to
        # time-limit, turn off training on these samples.
        valid *= (1 - samples.timeout_n.float())

    q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
    q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

    v = self.agent.v(*agent_inputs)
    new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
    if not self.reparameterize:
        new_action = new_action.detach()  # No grad.
    log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
    min_log_target = torch.min(log_target1, log_target2)
    prior_log_pi = self.get_action_prior(new_action.cpu())
    v_target = (min_log_target - log_pi + prior_log_pi).detach()  # No grad.
    v_loss = 0.5 * valid_mean((v - v_target) ** 2, valid)

    if self.reparameterize:
        pi_losses = log_pi - min_log_target
    else:
        pi_factor = (v - v_target).detach()
        pi_losses = log_pi * pi_factor
    if self.policy_output_regularization > 0:
        pi_losses += self.policy_output_regularization * torch.mean(
            0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
    pi_loss = valid_mean(pi_losses, valid)

    losses = (q1_loss, q2_loss, v_loss, pi_loss)
    values = tuple(val.detach() for val in (q1, q2, v, pi_mean, pi_log_std))
    return losses, values
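
# Reference sketch (assumption): ``valid_from_done`` as used in these losses is taken to
# behave like rlpyt's helper, which builds a float mask that zeros out every entry after
# the first ``done`` along the leading dimension.  Shown only to make the masking
# explicit; the actual helper in this repo may differ.
import torch

def valid_from_done(done):
    """Sketch: 1.0 up to and including the first done along dim 0, then 0.0 after."""
    done = done.type(torch.float)
    valid = torch.ones_like(done)
    valid[1:] = 1 - torch.clamp(torch.cumsum(done[:-1], dim=0), max=1)
    return valid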
def loss(
    self,
    agent_inputs,
    action,
    return_,
    advantage,
    valid,
    old_dist_info,
    init_rnn_state=None,
):
    """
    Compute the training loss: policy_loss + value_loss + entropy_loss.
    Policy loss: min(likelihood_ratio * advantage,
        clip(likelihood_ratio, 1 - eps, 1 + eps) * advantage)
    Value loss: 0.5 * (estimated_value - return) ^ 2
    Calls the agent to compute forward pass on training data, and uses
    the ``agent.distribution`` to compute likelihoods and entropies.
    Valid for feedforward or recurrent agents.
    """
    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
    else:
        dist_info, value = self.agent(*agent_inputs)
    dist = self.agent.distribution
    ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=dist_info)
    ratio = ratio.clamp_max(1000)  # Added (to prevent ratio == inf).
    surr_1 = ratio * advantage
    clipped_ratio = torch.clamp(ratio, 1.0 - self.ratio_clip, 1.0 + self.ratio_clip)
    surr_2 = clipped_ratio * advantage
    surrogate = torch.min(surr_1, surr_2)
    pi_loss = -valid_mean(surrogate, valid)

    value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
def loss(self, samples):
    """Samples have leading batch dimension [B,..] (but not time)."""
    qs = self.agent(*samples.agent_inputs)
    q = select_at_indexes(samples.action, qs)
    with torch.no_grad():
        target_qs = self.agent.target(*samples.target_inputs)
        if self.double_dqn:
            next_qs = self.agent(*samples.target_inputs)
            next_a = torch.argmax(next_qs, dim=-1)
            target_q = select_at_indexes(next_a, target_qs)
        else:
            target_q = torch.max(target_qs, dim=-1).values
        disc_target_q = (self.discount ** self.n_step_return) * target_q
        y = samples.return_ + (1 - samples.done_n.float()) * disc_target_q
    delta = y - q
    losses = 0.5 * delta ** 2
    abs_delta = abs(delta)
    if self.delta_clip is not None:  # Huber loss.
        b = self.delta_clip * (abs_delta - self.delta_clip / 2)
        losses = torch.where(abs_delta <= self.delta_clip, losses, b)
    if self.prioritized_replay:
        losses *= samples.is_weights
    td_abs_errors = torch.clamp(abs_delta.detach(), 0, self.delta_clip)
    if not self.mid_batch_reset:
        valid = valid_from_done(samples.done)
        loss = valid_mean(losses, valid)
        td_abs_errors *= valid
    else:
        loss = torch.mean(losses)
    return loss, td_abs_errors
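
# Sketch (assumption): the ``delta_clip`` branch above is the standard Huber loss.  With
# a recent PyTorch (1.9+), the same per-element values can be reproduced with
# ``torch.nn.functional.huber_loss``; the numbers below are purely hypothetical.
import torch
import torch.nn.functional as F

delta_clip = 1.0
y = torch.tensor([0.0, 0.0, 0.0])
q = torch.tensor([0.5, 2.0, -3.0])
delta = y - q
abs_delta = delta.abs()
manual = torch.where(abs_delta <= delta_clip,
                     0.5 * delta ** 2,
                     delta_clip * (abs_delta - delta_clip / 2))
reference = F.huber_loss(q, y, reduction='none', delta=delta_clip)
assert torch.allclose(manual, reference)  # tensor([0.125, 1.5, 2.5]) both ways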
def loss(self, samples):
    """
    Computes the Distributional Q-learning loss, based on projecting the
    discounted rewards + target Q-distribution into the current Q-domain,
    with cross-entropy loss.

    Returns loss and KL-divergence-errors for use in prioritization.
    """
    delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
    z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
    # Make 2-D tensor of contracted z_domain for each data point,
    # with zeros where next value should not be added.
    next_z = z * (self.discount ** self.n_step_return)  # [P']
    next_z = torch.ger(1 - samples.done_n.float(), next_z)  # [B,P']
    ret = samples.return_.unsqueeze(1)  # [B,1]
    next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

    z_bc = z.view(1, -1, 1)  # [1,P,1]
    next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
    abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
    projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
    # projection_coeffs is a 3-D tensor: [B,P,P']
    # dim-0: independent data entries
    # dim-1: base_z atoms (remains after projection)
    # dim-2: next_z atoms (summed in projection)

    with torch.no_grad():
        target_ps = self.agent.target(*samples.target_inputs)  # [B,A,P']
        if self.double_dqn:
            next_ps = self.agent(*samples.target_inputs)  # [B,A,P']
            next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(next_qs, dim=-1)  # [B]
        else:
            target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(target_qs, dim=-1)  # [B]
        target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
        target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
        target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
    ps = self.agent(*samples.agent_inputs)  # [B,A,P]
    p = select_at_indexes(samples.action, ps)  # [B,P]
    p = torch.clamp(p, EPS, 1)  # NaN-guard.
    losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

    if self.prioritized_replay:
        losses *= samples.is_weights

    target_p = torch.clamp(target_p, EPS, 1)
    KL_div = torch.sum(target_p *
        (torch.log(target_p) - torch.log(p.detach())), dim=1)
    KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

    if not self.mid_batch_reset:
        valid = valid_from_done(samples.done)
        loss = valid_mean(losses, valid)
        KL_div *= valid
    else:
        loss = torch.mean(losses)
    return loss, KL_div
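
# Worked mini-example of the projection (hypothetical numbers): with V_min = -1,
# V_max = 1, and 3 atoms, z = [-1, 0, 1] and delta_z = 1.  A backed-up value of 0.3
# lands between atoms 0 and 1, so the projection splits its probability mass 0.7 / 0.3
# between them, exactly as the ``projection_coeffs`` computation above does.
import torch

z = torch.tensor([-1.0, 0.0, 1.0])        # base support, delta_z = 1
next_z = torch.tensor([[0.3]])            # one sample, one backed-up atom value  [B,P']
coeffs = torch.clamp(1 - (next_z.unsqueeze(1) - z.view(1, -1, 1)).abs() / 1.0, 0, 1)
print(coeffs.squeeze())                   # tensor([0.0000, 0.7000, 0.3000])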
def vae_loss(self, samples):
    observation = samples.observation[0]  # [T,B,C,H,W] -> [B,C,H,W]
    target_observation = samples.observation[self.delta_T]
    if self.delta_T > 0:
        action = samples.action[:-1]  # [T-1,B,A]; don't need the last one.
        if self.onehot_action:
            action = self.distribution.to_onehot(action)
        t, b = action.shape[:2]
        action = action.transpose(1, 0)  # [B,T-1,A]
        action = action.reshape(b, -1)
    else:
        action = None
    observation, target_observation, action = buffer_to(
        (observation, target_observation, action),
        device=self.device)
    h, conv_out = self.encoder(observation)
    z, mu, logvar = self.vae_head(h, action)
    recon_z = self.decoder(z)
    if target_observation.dtype == torch.uint8:
        target_observation = target_observation.type(torch.float)
        target_observation = target_observation.mul_(1 / 255.)
    b, c, h, w = target_observation.shape
    recon_losses = F.binary_cross_entropy(
        input=recon_z.reshape(b * c, h, w),
        target=target_observation.reshape(b * c, h, w),
        reduction="none",
    )
    if self.delta_T > 0:
        valid = valid_from_done(samples.done).type(torch.bool)  # [T,B]
        valid = valid[-1]  # [B]
        valid = valid.to(self.device)
    else:
        valid = None  # All are valid.
    recon_losses = recon_losses.view(b, c, h, w).sum(dim=(2, 3))  # Sum over H,W.
    recon_losses = recon_losses.mean(dim=1)  # Mean over C (o/w loss is HUGE).
    recon_loss = valid_mean(recon_losses, valid=valid)  # Mean over batch.

    kl_losses = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_losses = kl_losses.sum(dim=-1)  # Sum over latent dimension.
    kl_loss = -0.5 * valid_mean(kl_losses, valid=valid)  # Mean over batch.
    kl_loss = self.kl_coeff * kl_loss
    return recon_loss, kl_loss, conv_out
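
# Sketch (assumption): the ``kl_losses`` term above is the usual closed-form
# KL( N(mu, exp(logvar)) || N(0, I) ) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
# A quick sanity check against torch.distributions, with hypothetical values:
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 8)
logvar = torch.randn(4, 8)
closed_form = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)
reference = kl_divergence(
    Normal(mu, (0.5 * logvar).exp()),
    Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum(dim=-1)
assert torch.allclose(closed_form, reference, atol=1e-5)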
def loss(self, samples):
    """Samples have leading batch dimension [B,..] (but not time)."""
    agent_inputs, target_inputs, action = buffer_to(
        (samples.agent_inputs, samples.target_inputs, samples.action),
        device=self.agent.device)  # Move to device once, re-use.
    q1, q2 = self.agent.q(*agent_inputs, action)
    with torch.no_grad():
        target_v = self.agent.target_v(*target_inputs)
    disc = self.discount ** self.n_step_return
    y = (self.reward_scale * samples.return_ +
        (1 - samples.done_n.float()) * disc * target_v)
    if self.mid_batch_reset and not self.agent.recurrent:
        valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
    else:
        valid = valid_from_done(samples.done)

    q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
    q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

    v = self.agent.v(*agent_inputs)
    new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
    if not self.reparameterize:
        new_action = new_action.detach()  # No grad.
    log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
    min_log_target = torch.min(log_target1, log_target2)
    prior_log_pi = self.get_action_prior(new_action.cpu())
    v_target = (min_log_target - log_pi + prior_log_pi).detach()  # No grad.
    v_loss = 0.5 * valid_mean((v - v_target) ** 2, valid)

    if self.reparameterize:
        pi_losses = log_pi - min_log_target
    else:
        pi_factor = (v - v_target).detach()  # No grad.
        pi_losses = log_pi * pi_factor
    if self.policy_output_regularization > 0:
        pi_losses += torch.sum(
            self.policy_output_regularization * 0.5 * pi_mean ** 2 +
            pi_log_std ** 2, dim=-1)
    pi_loss = valid_mean(pi_losses, valid)

    losses = (q1_loss, q2_loss, v_loss, pi_loss)
    values = tuple(val.detach() for val in (q1, q2, v, pi_mean, pi_log_std))
    return losses, values
def rl_loss(self, latent, action, return_n, done_n, prev_action, prev_reward,
        next_state, next_prev_action, next_prev_reward, is_weights, done):
    delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
    z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
    # Make 2-D tensor of contracted z_domain for each data point,
    # with zeros where next value should not be added.
    next_z = z * (self.discount ** self.n_step_return)  # [P']
    next_z = torch.ger(1 - done_n.float(), next_z)  # [B,P']
    ret = return_n.unsqueeze(1)  # [B,1]
    next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

    z_bc = z.view(1, -1, 1)  # [1,P,1]
    next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
    abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
    projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
    # projection_coeffs is a 3-D tensor: [B,P,P']
    # dim-0: independent data entries
    # dim-1: base_z atoms (remains after projection)
    # dim-2: next_z atoms (summed in projection)

    with torch.no_grad():
        target_ps = self.agent.target(next_state, next_prev_action,
            next_prev_reward)  # [B,A,P']
        if self.double_dqn:
            next_ps = self.agent(next_state, next_prev_action,
                next_prev_reward)  # [B,A,P']
            next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(next_qs, dim=-1)  # [B]
        else:
            target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(target_qs, dim=-1)  # [B]
        target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
        target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
        target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
    ps = self.agent.head_forward(latent, prev_action, prev_reward)  # [B,A,P]
    p = select_at_indexes(action, ps)  # [B,P]
    p = torch.clamp(p, EPS, 1)  # NaN-guard.
    losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

    if self.prioritized_replay:
        losses *= is_weights

    target_p = torch.clamp(target_p, EPS, 1)
    KL_div = torch.sum(target_p *
        (torch.log(target_p) - torch.log(p.detach())), dim=1)
    KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

    if not self.mid_batch_reset:
        valid = valid_from_done(done[1])
        loss = valid_mean(losses, valid)
        KL_div *= valid
    else:
        loss = torch.mean(losses)
    return loss, KL_div
def q_loss(self, samples):
    if self.mid_batch_reset and not self.agent.recurrent:
        valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
    else:
        valid = valid_from_done(samples.done)
    if self.bootstrap_timelimit:
        # To avoid non-use of bootstrap when environment is 'done' due to
        # time-limit, turn off training on these samples.
        valid *= 1 - samples.timeout_n.float()

    # Run the convolution only once, return so pi_loss can use it.
    if self.store_latent:
        conv_out = None
        q_inputs = samples.agent_inputs
    else:
        conv_out = self.agent.conv(samples.agent_inputs.observation)
        if self.stop_conv_grad:
            conv_out = conv_out.detach()
        q_inputs = samples.agent_inputs._replace(observation=conv_out)

    # Q LOSS.
    q1, q2 = self.agent.q(*q_inputs, samples.action)
    with torch.no_grad():
        # Run the target convolution only once.
        if self.store_latent:
            target_inputs = samples.target_inputs
        else:
            target_conv_out = self.agent.target_conv(
                samples.target_inputs.observation)
            target_inputs = samples.target_inputs._replace(
                observation=target_conv_out)
        target_action, target_log_pi, _ = self.agent.pi(*target_inputs)
        target_q1, target_q2 = self.agent.target_q(*target_inputs, target_action)
    min_target_q = torch.min(target_q1, target_q2)
    target_value = min_target_q - self._alpha * target_log_pi
    disc = self.discount ** self.n_step_return
    y = (self.reward_scale * samples.return_ +
        (1 - samples.done_n.float()) * disc * target_value)
    q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
    q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

    return q1_loss, q2_loss, valid, conv_out, q1.detach(), q2.detach()
def loss(self, samples):
    """
    Computes the training loss: policy_loss + value_loss + entropy_loss.
    Policy loss: log-likelihood of actions * advantages
    Value loss: 0.5 * (estimated_value - return) ^ 2
    Organizes agent inputs from training samples, calls the agent instance
    to run forward pass on training data, and uses the ``agent.distribution``
    to compute likelihoods and entropies.  Valid for feedforward or recurrent
    agents.
    """
    agent_inputs = AgentInputs(
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    if self.agent.recurrent:
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T = 0.
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state,
            device=agent_inputs.prev_action.device)
    else:
        dist_info, value = self.agent(*agent_inputs,
            device=agent_inputs.prev_action.device)
    # TODO: try to compute everything on device.
    return_, advantage, valid = self.process_returns(samples)

    dist = self.agent.distribution
    logli = dist.log_likelihood(samples.agent.action, dist_info)
    pi_loss = -valid_mean(logli * advantage, valid)

    value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, pi_loss, value_loss, entropy, perplexity
def compute_loss(self, observations, next_observations, actions, valid):
    # Dimension add for when you have only one environment.
    if actions.dim() == 2:
        actions = actions.unsqueeze(1)
    phi1, phi2, predicted_phi2, predicted_action = self.forward(
        observations, next_observations, actions)
    # Convert action to (T * B, action_size).
    actions = torch.max(actions.view(-1, *actions.shape[2:]), 1)[1]
    inverse_loss = nn.functional.cross_entropy(
        predicted_action.view(-1, *predicted_action.shape[2:]),
        actions.detach(), reduction='none').view(phi1.shape[0], phi1.shape[1])
    forward_loss = nn.functional.mse_loss(
        predicted_phi2, phi2.detach(),
        reduction='none').sum(-1) / self.feature_size
    inverse_loss = valid_mean(inverse_loss, valid.detach())
    forward_loss = valid_mean(forward_loss, valid.detach())
    return self.inverse_loss_wt * inverse_loss, self.forward_loss_wt * forward_loss
def loss(self, samples):
    """Samples have leading batch dimension [B,..] (but not time)."""
    agent_inputs, target_inputs, action = buffer_to(
        (samples.agent_inputs, samples.target_inputs, samples.action))
    qs = self.agent.q(*agent_inputs, action)
    with torch.no_grad():
        target_v = self.agent.target_v(*target_inputs).detach()
    disc = self.discount ** self.n_step_return
    y = (self.reward_scale * samples.return_ +
        (1 - samples.done_n.float()) * disc * target_v)
    if self.mid_batch_reset and not self.agent.recurrent:
        valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
    else:
        valid = valid_from_done(samples.done)

    q_losses = [0.5 * valid_mean((y - q) ** 2, valid) for q in qs]

    new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
    if not self.reparameterize:
        new_action = new_action.detach()  # No grad.
    log_targets = self.agent.q(*agent_inputs, new_action)
    min_log_target = torch.min(torch.stack(log_targets, dim=0), dim=0)[0]
    prior_log_pi = self.get_action_prior(new_action.cpu())

    if self.reparameterize:
        alpha = self.agent.log_alpha.exp().detach()
        pi_losses = alpha * log_pi - min_log_target - prior_log_pi
    if self.policy_output_regularization > 0:
        pi_losses += torch.sum(
            self.policy_output_regularization * 0.5 * pi_mean ** 2 +
            pi_log_std ** 2, dim=-1)
    pi_loss = valid_mean(pi_losses, valid)

    # Calculate log_alpha loss.
    alpha_loss = -valid_mean(self.agent.log_alpha *
        (log_pi + self.target_entropy).detach())

    losses = (pi_loss, alpha_loss)
    values = tuple(val.detach() for val in (pi_mean, pi_log_std, alpha))
    q_values = tuple(q.detach() for q in qs)
    return q_losses, losses, values, q_values
def beta_kl_losses(
    self,
    agent_inputs,
    action,
    return_,
    advantage,
    valid,
    old_dist_info,
    c_return,
    c_advantage,
    init_rnn_state=None,
):
    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        r_dist_info, c_dist_info = self.agent.beta_dist_infos(
            *agent_inputs, init_rnn_state)
    else:
        r_dist_info, c_dist_info = self.agent.beta_dist_infos(*agent_inputs)
    dist = self.agent.distribution

    # Reward surrogate: standard clipped objective (torch.min), negated for minimization.
    r_ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=r_dist_info)
    surr_1 = r_ratio * advantage
    r_clipped_ratio = torch.clamp(r_ratio, 1.0 - self.ratio_clip,
        1.0 + self.ratio_clip)
    surr_2 = r_clipped_ratio * advantage
    surrogate = torch.min(surr_1, surr_2)
    beta_r_loss = -valid_mean(surrogate, valid)

    # Cost surrogate: clipping mirrored with torch.max, since this term enters the
    # loss with positive sign.
    c_ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=c_dist_info)
    c_surr_1 = c_ratio * c_advantage
    c_clipped_ratio = torch.clamp(c_ratio, 1.0 - self.ratio_clip,
        1.0 + self.ratio_clip)
    c_surr_2 = c_clipped_ratio * c_advantage
    c_surrogate = torch.max(c_surr_1, c_surr_2)
    beta_c_loss = valid_mean(c_surrogate, valid)

    return beta_r_loss, beta_c_loss
def compute_loss(self, observations, next_observations, actions, valid):
    # ------------------------------------------------------------ #
    # Hacky dimension add for when you have only one environment (debugging).
    if actions.dim() == 2:
        actions = actions.unsqueeze(1)
    # ------------------------------------------------------------ #
    phi2, predicted_phi2, _ = self.forward(observations, next_observations, actions)

    forward_loss = torch.tensor(0.0, device=self.device)
    forward_loss_1 = nn.functional.dropout(
        nn.functional.mse_loss(predicted_phi2[0], phi2.detach(),
            reduction='none'), p=0.2).sum(-1) / self.feature_size
    forward_loss += valid_mean(forward_loss_1, valid)
    forward_loss_2 = nn.functional.dropout(
        nn.functional.mse_loss(predicted_phi2[1], phi2.detach(),
            reduction='none'), p=0.2).sum(-1) / self.feature_size
    forward_loss += valid_mean(forward_loss_2, valid)
    forward_loss_3 = nn.functional.dropout(
        nn.functional.mse_loss(predicted_phi2[2], phi2.detach(),
            reduction='none'), p=0.2).sum(-1) / self.feature_size
    forward_loss += valid_mean(forward_loss_3, valid)
    forward_loss_4 = nn.functional.dropout(
        nn.functional.mse_loss(predicted_phi2[3], phi2.detach(),
            reduction='none'), p=0.2).sum(-1) / self.feature_size
    forward_loss += valid_mean(forward_loss_4, valid)
    forward_loss_5 = nn.functional.dropout(
        nn.functional.mse_loss(predicted_phi2[4], phi2.detach(),
            reduction='none'), p=0.2).sum(-1) / self.feature_size
    forward_loss += valid_mean(forward_loss_5, valid)

    return self.forward_loss_wt * forward_loss
def q_loss(self, samples, valid):
    """Samples have leading batch dimension [B,..] (but not time)."""
    q = self.agent.q(*samples.agent_inputs, samples.action)
    with torch.no_grad():
        target_q = self.agent.target_q_at_mu(*samples.target_inputs)
    disc = self.discount ** self.n_step_return
    y = samples.return_ + (1 - samples.done_n.float()) * disc * target_q
    y = torch.clamp(y, -self.q_target_clip, self.q_target_clip)
    q_losses = 0.5 * (y - q) ** 2
    q_loss = valid_mean(q_losses, valid)  # valid can be None.
    return q_loss
def compute_loss(self, observations, valid):
    phi, predicted_phi, T, B = self.forward(observations, done=None)
    forward_loss = nn.functional.mse_loss(
        predicted_phi, phi.detach(),
        reduction='none').sum(-1) / self.feature_size
    # Randomly keep only a fraction of the prediction errors for training.
    mask = torch.rand(forward_loss.shape)
    mask = (mask > self.drop_probability).type(torch.FloatTensor).to(self.device)
    forward_loss = forward_loss * mask.detach()
    forward_loss = valid_mean(forward_loss, valid.detach())
    return forward_loss
def q_loss(self, samples, valid):
    q1, q2 = self.agent.q(*samples.agent_inputs, samples.action)
    with torch.no_grad():
        target_q1, target_q2 = self.agent.target_q_at_mu(
            *samples.target_inputs)  # Includes target action noise.
    target_q = torch.min(target_q1, target_q2)
    disc = self.discount ** self.n_step_return
    y = samples.return_ + (1 - samples.done_n.float()) * disc * target_q
    q1_losses = 0.5 * (y - q1) ** 2
    q2_losses = 0.5 * (y - q2) ** 2
    q_loss = valid_mean(q1_losses + q2_losses, valid)  # valid can be None.
    return q_loss
def continuous_actions_loss(self, advantage_mask, phi, action, dist_info,
        old_dist_info, valid, opt_info):
    d = np.prod(action.shape[-1])
    distribution = torch.distributions.normal.Normal(
        loc=dist_info.mean, scale=dist_info.log_std)
    pi_loss = -torch.sum(advantage_mask *
        (phi.detach() * distribution.log_prob(action).sum(dim=-1)))
    # pi_loss = -torch.sum(advantage_mask * (phi.detach() *
    #     self.agent.distribution.log_likelihood(action, dist_info)))

    # Decoupled KL between old and new Gaussians: a mean term and a covariance term.
    new_std = dist_info.log_std
    old_std = old_dist_info.log_std
    old_covariance = torch.diag_embed(old_std)
    old_covariance_inverse = torch.diag_embed(1 / old_std)
    new_covariance_inverse = torch.diag_embed(1 / new_std)
    old_covariance_determinant = torch.prod(old_std, dim=-1)
    new_covariance_determinant = torch.prod(new_std, dim=-1)
    mu_kl = 0.5 * utils.batched_quadratic_form(
        dist_info.mean - old_dist_info.mean, old_covariance_inverse)
    trace = utils.batched_trace(
        torch.matmul(new_covariance_inverse, old_covariance))
    sigma_kl = 0.5 * (trace - d + torch.log(
        new_covariance_determinant / old_covariance_determinant))

    alpha_mu_loss = valid_mean(
        self.alpha_mu * (self.epsilon_alpha_mu - mu_kl.detach()) +
        self.alpha_mu.detach() * mu_kl, valid)
    alpha_sigma_loss = valid_mean(
        self.alpha_sigma * (self.epsilon_alpha_sigma - sigma_kl.detach()) +
        self.alpha_sigma.detach() * sigma_kl, valid)

    opt_info.alpha_mu.append(self.alpha_mu.item())
    opt_info.alpha_sigma.append(self.alpha_sigma.item())
    opt_info.alpha_mu_loss.append(alpha_mu_loss.item())
    opt_info.mu_kl.append(valid_mean(mu_kl, valid).item())
    opt_info.sigma_kl.append(valid_mean(sigma_kl, valid).item())
    opt_info.alpha_sigma_loss.append(
        valid_mean(self.epsilon_alpha_sigma - sigma_kl, valid).item())
    opt_info.pi_mu.append(dist_info.mean.mean().item())
    opt_info.pi_log_std.append(dist_info.log_std.mean().item())
    return pi_loss, alpha_mu_loss + alpha_sigma_loss, opt_info
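
# Note (assumption): treating the ``log_std`` field above as the diagonal of the
# covariance Sigma (as the code does when it builds diag_embed / determinants from it),
# the two terms correspond to a decoupled Gaussian KL with separate mean and covariance
# constraints:
#
#   mu_kl    = 0.5 * (mu_new - mu_old)^T Sigma_old^{-1} (mu_new - mu_old)
#   sigma_kl = 0.5 * (tr(Sigma_new^{-1} Sigma_old) - d
#                     + log(det(Sigma_new) / det(Sigma_old)))
#
# A self-contained diagonal-covariance sketch of the same quantities:
import torch

def decoupled_gaussian_kl(mu_old, mu_new, var_old, var_new):
    """Sketch: decoupled diagonal-Gaussian KL terms matching the formulas above."""
    mu_kl = 0.5 * ((mu_new - mu_old) ** 2 / var_old).sum(dim=-1)
    sigma_kl = 0.5 * ((var_old / var_new).sum(dim=-1)
                      - var_old.shape[-1]
                      + torch.log(var_new.prod(dim=-1) / var_old.prod(dim=-1)))
    return mu_kl, sigma_kl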
def loss(self, samples):
    """Samples have leading batch dimension [B,..] (but not time)."""
    agent_inputs, target_inputs, action = buffer_to(
        (samples.agent_inputs, samples.target_inputs, samples.action))
    q1, q2 = self.agent.q(*agent_inputs, action)
    with torch.no_grad():
        target_v = self.agent.target_v(*target_inputs).detach()
    disc = self.discount ** self.n_step_return
    y = (self.reward_scale * samples.return_ +
        (1 - samples.done_n.float()) * disc * target_v)
    if self.mid_batch_reset and not self.agent.recurrent:
        valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
    else:
        valid = valid_from_done(samples.done)

    q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
    q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

    new_action, log_pi, _ = self.agent.pi(*agent_inputs)
    if not self.reparameterize:
        new_action = new_action.detach()  # No grad.
    log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
    min_log_target = torch.min(log_target1, log_target2)
    prior_log_pi = self.get_action_prior(new_action.cpu())

    if self.reparameterize:
        alpha = self.agent.log_alpha.exp().detach()
        pi_losses = alpha * log_pi - min_log_target - prior_log_pi
    pi_loss = valid_mean(pi_losses, valid)

    # Calculate log_alpha loss.
    alpha_loss = -valid_mean(self.agent.log_alpha *
        (log_pi + self.target_entropy).detach())

    losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
    values = tuple(val.detach() for val in (q1, q2, alpha))
    return losses, values
def discrete_actions_loss(self, advantage_mask, phi, action, dist_info,
        old_dist_info, valid, opt_info):
    dist = self.agent.distribution
    pi_loss = -torch.sum(advantage_mask * (phi.detach() *
        dist.log_likelihood(action.contiguous(), dist_info)))
    policy_kl = dist.kl(old_dist_info, dist_info)
    alpha_loss = valid_mean(
        self.alpha * (self.epsilon_alpha - policy_kl.detach()) +
        self.alpha.detach() * policy_kl, valid)
    opt_info.alpha_loss.append(alpha_loss.item())
    opt_info.alpha.append(self.alpha.item())
    opt_info.policy_kl.append(policy_kl.mean().item())
    opt_info.entropy.append(dist.entropy(dist_info).mean().item())
    return pi_loss, alpha_loss, opt_info
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
        init_rnn_state=None):
    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
    else:
        dist_info, value = self.agent(*agent_inputs)
    # NOTE on the multi-agent case: no reshaping appears necessary here, since the
    # dist.* functions operate on multi-dimensional tensors (entropy and likelihood
    # ratios are computed along the last dim); double-check if this changes.
    dist = self.agent.distribution
    ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=dist_info)
    surr_1 = ratio * advantage
    clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip, 1. + self.ratio_clip)
    surr_2 = clipped_ratio * advantage
    surrogate = torch.min(surr_1, surr_2)
    pi_loss = - valid_mean(surrogate, valid)

    value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = - self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity