def ppo_value_loss(
    self,
    values: Dict[str, torch.Tensor],
    old_values: Dict[str, torch.Tensor],
    returns: Dict[str, torch.Tensor],
    epsilon: float,
    loss_masks: torch.Tensor,
) -> torch.Tensor:
    """
    Computes the clipped value loss used by PPO, averaged over all value heads.
    :param values: Value output of the current network, keyed by head name.
    :param old_values: Value stored with experiences in buffer, looked up with
        the same keys as `values`.
    :param returns: Computed returns, looked up with the same keys as `values`.
    :param epsilon: Clipping value for value estimate.
    :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out
        experiences.
    :return: Mean of the per-head masked losses.
    """
    per_head_losses = []
    for name, estimate in values.items():
        previous = old_values[name]
        target = returns[name]

        # Limit how far the new estimate may move from the stored one.
        bounded_delta = torch.clamp(estimate - previous, -epsilon, epsilon)
        clipped_estimate = previous + bounded_delta

        unclipped_error = (target - estimate) ** 2
        clipped_error = (target - clipped_estimate) ** 2

        # Element-wise, keep whichever squared error is larger.
        worst_error = torch.max(unclipped_error, clipped_error)
        per_head_losses.append(ModelUtils.masked_mean(worst_error, loss_masks))

    return torch.mean(torch.stack(per_head_losses))
def test_tanh_gaussian_dist_instance():
    """
    Samples from a TanhGaussianDistInstance must have shape (1, act_size) and
    lie strictly inside the open interval (-1, 1).
    """
    torch.manual_seed(0)
    act_size = 4
    expected_shape = (1, act_size)
    zeros = torch.zeros(1, act_size)
    ones = torch.ones(1, act_size)
    dist_instance = TanhGaussianDistInstance(zeros, ones)
    for _ in range(10):
        sample = dist_instance.sample()
        assert sample.shape == expected_shape
        # Check the upper bound first, then the lower bound.
        assert torch.max(sample) < 1.0
        assert torch.min(sample) > -1.0
def compute_loss(
    self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Given a policy mini_batch and an expert mini_batch, computes the loss of the
    discriminator.
    :param policy_batch: Mini-batch handed to `compute_estimate` as the policy
        sample.
    :param expert_batch: Mini-batch handed to `compute_estimate` as the expert
        sample.
    :return: A tuple `(total_loss, stats_dict)`. `total_loss` is a one-element
        tensor holding the sum of the discriminator loss and, when enabled, the
        VAIL and gradient-magnitude terms. `stats_dict` maps stat names to
        Python floats.
    """
    total_loss = torch.zeros(1)
    # Every value stored below comes from `.item()`, i.e. a plain Python float.
    stats_dict: Dict[str, float] = {}

    # Keep the policy call before the expert call, matching the original order.
    policy_estimate, policy_mu = self.compute_estimate(
        policy_batch, use_vail_noise=True
    )
    expert_estimate, expert_mu = self.compute_estimate(
        expert_batch, use_vail_noise=True
    )
    stats_dict["Policy/GAIL Policy Estimate"] = policy_estimate.mean().item()
    stats_dict["Policy/GAIL Expert Estimate"] = expert_estimate.mean().item()

    # EPSILON is added inside both logs, presumably to keep them away from log(0).
    discriminator_loss = -(
        torch.log(expert_estimate + self.EPSILON)
        + torch.log(1.0 - policy_estimate + self.EPSILON)
    ).mean()
    stats_dict["Losses/GAIL Loss"] = discriminator_loss.item()
    total_loss += discriminator_loss

    if self._settings.use_vail:
        # KL divergence loss (encourage latent representation to be normal)
        z_variance = self._z_sigma ** 2
        kl_loss = torch.mean(
            -torch.sum(
                1
                + z_variance.log()
                - 0.5 * expert_mu ** 2
                - 0.5 * policy_mu ** 2
                - z_variance,
                dim=1,
            )
        )
        kl_excess = kl_loss - self.mutual_information
        # The loss term uses the value of beta from before the update below.
        vail_loss = self._beta * kl_excess
        with torch.no_grad():
            # Move beta by alpha * (kl - mutual_information), never below zero.
            self._beta.data = torch.max(
                self._beta + self.alpha * kl_excess, torch.tensor(0.0)
            )
        total_loss += vail_loss
        stats_dict["Policy/GAIL Beta"] = self._beta.item()
        stats_dict["Losses/GAIL KL Loss"] = kl_loss.item()

    if self.gradient_penalty_weight > 0.0:
        gradient_magnitude_loss = (
            self.gradient_penalty_weight
            * self.compute_gradient_magnitude(policy_batch, expert_batch)
        )
        stats_dict["Policy/GAIL Grad Mag Loss"] = gradient_magnitude_loss.item()
        total_loss += gradient_magnitude_loss

    return total_loss, stats_dict
def forward(
    self,
    obs_only: List[List[torch.Tensor]],
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encodes the given observations (and actions, where provided) into a single
    encoding with a normalized entity count appended as the last feature.
    If memory is enabled, returns the updated memories as well.
    :param obs_only: Observations to be processed that do not have corresponding
        actions. These are encoded with the obs_encoder.
    :param obs: Observations to be processed that do have corresponding actions.
        After concatenation with actions, these are processed with
        obs_action_encoder.
    :param actions: After concatenation with obs, these are processed with
        obs_action_encoder.
    :param memories: If using memory, a Tensor of initial memories.
    :param sequence_length: If using memory, the sequence length.
    :return: A tuple `(encoding, memories)`. When `self.use_lstm` is false,
        `memories` is returned exactly as it was passed in.
    """
    self_attn_masks = []
    self_attn_inputs = []

    if obs:
        obs_attn_mask = self._get_masks_from_nans(obs)
        obs = self._copy_and_remove_nans_from_obs(obs, obs_attn_mask)
        # One (encoded observation, flattened action) row block per entry.
        concat_f_inp = [
            torch.cat(
                [
                    self.observation_encoder(inputs),
                    action.to_flat(self.action_spec.discrete_branches),
                ],
                dim=1,
            )
            for inputs, action in zip(obs, actions)
        ]
        f_inp = torch.stack(concat_f_inp, dim=1)
        self_attn_masks.append(obs_attn_mask)
        self_attn_inputs.append(self.obs_action_encoder(None, f_inp))

    if obs_only:
        obs_only_attn_mask = self._get_masks_from_nans(obs_only)
        obs_only = self._copy_and_remove_nans_from_obs(
            obs_only, obs_only_attn_mask
        )
        concat_encoded_obs = [
            self.observation_encoder(inputs) for inputs in obs_only
        ]
        g_inp = torch.stack(concat_encoded_obs, dim=1)
        self_attn_masks.append(obs_only_attn_mask)
        self_attn_inputs.append(self.obs_encoder(None, g_inp))

    encoded_entity = torch.cat(self_attn_inputs, dim=1)
    encoded_state = self.self_attn(encoded_entity, self_attn_masks)

    # Inverting the masks and summing along dim 1 gives a per-row count
    # (presumably the number of non-NaN entities -- confirm against
    # _get_masks_from_nans).
    flipped_masks = 1 - torch.cat(self_attn_masks, dim=1)
    num_agents = torch.sum(flipped_masks, dim=1, keepdim=True)

    # Read the batch maximum once; .item() copies the value to a Python number.
    batch_max_agents = torch.max(num_agents).item()
    if batch_max_agents > self._current_max_agents:
        self._current_max_agents = torch.nn.Parameter(
            torch.as_tensor(batch_max_agents), requires_grad=False
        )

    # Rescale the count to 2 * n / current_max - 1, so it equals +1 when the
    # current maximum is reached and decreases towards -1 for smaller counts.
    num_agents = num_agents * 2.0 / self._current_max_agents - 1

    encoding = self.linear_encoder(encoded_state)
    if self.use_lstm:
        # Resize to (batch, sequence length, encoding size)
        encoding = encoding.reshape([-1, sequence_length, self.h_size])
        encoding, memories = self.lstm(encoding, memories)
        encoding = encoding.reshape([-1, self.m_size // 2])

    encoding = torch.cat([encoding, num_agents], dim=1)
    return encoding, memories