예제 #1
 def ppo_value_loss(
     values: Dict[str, torch.Tensor],
     old_values: Dict[str, torch.Tensor],
     returns: Dict[str, torch.Tensor],
     epsilon: float,
     loss_masks: torch.Tensor,
 ) -> torch.Tensor:
     """
     Evaluates value loss for PPO.

     For every value head the loss is the masked mean of the element-wise
     maximum of the unclipped and the clipped squared error. The per-head
     losses are then averaged into a single tensor.

     :param values: Value output of the current network, keyed by head name.
     :param old_values: Value stored with experiences in buffer. Must contain
         every key present in ``values``.
     :param returns: Computed returns. Must contain every key present in
         ``values``.
     :param epsilon: Clipping value for value estimate.
     :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
     :return: Mean of the per-head value losses (0-dim tensor).
     """
     value_losses = []
     for name, head in values.items():
         old_val_tensor = old_values[name]
         returns_tensor = returns[name]
         # Keep the new estimate within +/- epsilon of the stored estimate.
         clipped_value_estimate = old_val_tensor + torch.clamp(
             head - old_val_tensor, -epsilon, epsilon
         )
         v_opt_a = (returns_tensor - head) ** 2
         v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
         # Element-wise max of the two errors, i.e. the pessimistic one.
         # NOTE(review): the mask argument was restored from the parameter
         # list and docstring -- confirm masked_mean's signature.
         value_loss = ModelUtils.masked_mean(
             torch.max(v_opt_a, v_opt_b), loss_masks
         )
         value_losses.append(value_loss)
     return torch.mean(torch.stack(value_losses))
예제 #2
def test_tanh_gaussian_dist_instance():
    """Check that ten samples each have shape (1, act_size) and lie strictly in (-1, 1)."""
    act_size = 4
    expected_shape = (1, act_size)
    zeros = torch.zeros(1, act_size)
    ones = torch.ones(1, act_size)
    dist_instance = TanhGaussianDistInstance(zeros, ones)
    for _ in range(10):
        sampled = dist_instance.sample()
        assert sampled.shape == expected_shape
        # Both bounds are strict: neither -1.0 nor 1.0 may be produced.
        assert torch.max(sampled) < 1.0 and torch.min(sampled) > -1.0
예제 #3
 def compute_loss(
     self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
 ) -> torch.Tensor:
     """
     Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.

     The total loss is the discriminator loss, plus the VAIL term when
     ``self._settings.use_vail`` is set, plus the gradient-magnitude term
     when ``self.gradient_penalty_weight`` is positive. When VAIL is
     enabled, ``self._beta`` is also updated in place as a side effect.

     :param policy_batch: Mini-batch of policy experiences.
     :param expert_batch: Mini-batch of expert experiences.
     :return: A ``(total_loss, stats_dict)`` tuple. ``total_loss`` starts as
         a one-element tensor and ``stats_dict`` maps stat names to Python
         scalars obtained via ``.item()``.
         NOTE(review): the return annotation only mentions the tensor even
         though a tuple is returned.
     """
     total_loss = torch.zeros(1)
     stats_dict: Dict[str, np.ndarray] = {}
     policy_estimate, policy_mu = self.compute_estimate(
         policy_batch, use_vail_noise=True
     )
     expert_estimate, expert_mu = self.compute_estimate(
         expert_batch, use_vail_noise=True
     )
     stats_dict["Policy/GAIL Policy Estimate"] = policy_estimate.mean().item()
     stats_dict["Policy/GAIL Expert Estimate"] = expert_estimate.mean().item()
     # EPSILON keeps both logarithms away from log(0).
     # NOTE(review): the closing `.mean()` reduction was restored (the
     # following `.item()` needs a single-element tensor) -- confirm.
     discriminator_loss = -(
         torch.log(expert_estimate + self.EPSILON)
         + torch.log(1.0 - policy_estimate + self.EPSILON)
     ).mean()
     stats_dict["Losses/GAIL Loss"] = discriminator_loss.item()
     total_loss += discriminator_loss
     if self._settings.use_vail:
         # KL divergence loss (encourage latent representation to be normal)
         # NOTE(review): the `-torch.sum(1 + ..., dim=1)` wrapper was
         # restored around the surviving terms -- confirm sign and axis.
         kl_loss = torch.mean(
             -torch.sum(
                 1
                 + (self._z_sigma ** 2).log()
                 - 0.5 * expert_mu ** 2
                 - 0.5 * policy_mu ** 2
                 - (self._z_sigma ** 2),
                 dim=1,
             )
         )
         vail_loss = self._beta * (kl_loss - self.mutual_information)
         # Update beta outside of autograd.
         # NOTE(review): the second argument to torch.max was restored as a
         # zero floor -- confirm.
         with torch.no_grad():
             self._beta.data = torch.max(
                 self._beta + self.alpha * (kl_loss - self.mutual_information),
                 torch.tensor(0.0),
             )
         total_loss += vail_loss
         stats_dict["Policy/GAIL Beta"] = self._beta.item()
         stats_dict["Losses/GAIL KL Loss"] = kl_loss.item()
     if self.gradient_penalty_weight > 0.0:
         # NOTE(review): the left-hand factor was restored as the penalty
         # weight tested in the condition above -- confirm.
         gradient_magnitude_loss = (
             self.gradient_penalty_weight
             * self.compute_gradient_magnitude(policy_batch, expert_batch)
         )
         stats_dict["Policy/GAIL Grad Mag Loss"] = gradient_magnitude_loss.item()
         total_loss += gradient_magnitude_loss
     return total_loss, stats_dict
예제 #4
    def forward(
        self,
        obs_only: List[List[torch.Tensor]],
        obs: List[List[torch.Tensor]],
        actions: List[AgentAction],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes the observations (and observation/action pairs) of a group of
        agents into a single encoding with a normalized agent count appended.
        If memory is enabled, return the memories as well.

        :param obs_only: Observations to be processed that do not have corresponding actions.
            These are encoded with the obs_encoder.
        :param obs: Observations to be processed that do have corresponding actions.
            After concatenation with actions, these are processed with obs_action_encoder.
        :param actions: After concatenation with obs, these are processed with obs_action_encoder.
        :param memories: If using memory, a Tensor of initial memories.
        :param sequence_length: If using memory, the sequence length.
        :return: A ``(encoding, memories)`` tuple. ``memories`` is returned
            unchanged when ``self.use_lstm`` is false.
        """
        self_attn_masks = []
        self_attn_inputs = []
        concat_f_inp = []
        if obs:
            obs_attn_mask = self._get_masks_from_nans(obs)
            obs = self._copy_and_remove_nans_from_obs(obs, obs_attn_mask)
            for inputs, action in zip(obs, actions):
                encoded = self.observation_encoder(inputs)
                # NOTE(review): the list contents were missing and have been
                # restored as the encoded observation plus the flattened
                # action -- confirm the exact action-flattening call.
                cat_encodes = [
                    encoded,
                    action.to_flat(self.action_spec.discrete_branches),
                ]
                concat_f_inp.append(torch.cat(cat_encodes, dim=1))
            f_inp = torch.stack(concat_f_inp, dim=1)
            # NOTE(review): mask bookkeeping restored; the masks list is
            # concatenated below and would otherwise always be empty.
            self_attn_masks.append(obs_attn_mask)
            self_attn_inputs.append(self.obs_action_encoder(None, f_inp))

        concat_encoded_obs = []
        if obs_only:
            obs_only_attn_mask = self._get_masks_from_nans(obs_only)
            obs_only = self._copy_and_remove_nans_from_obs(
                obs_only, obs_only_attn_mask
            )
            for inputs in obs_only:
                encoded = self.observation_encoder(inputs)
                concat_encoded_obs.append(encoded)
            g_inp = torch.stack(concat_encoded_obs, dim=1)
            self_attn_masks.append(obs_only_attn_mask)
            self_attn_inputs.append(self.obs_encoder(None, g_inp))

        encoded_entity = torch.cat(self_attn_inputs, dim=1)
        encoded_state = self.self_attn(encoded_entity, self_attn_masks)

        # Inverting the masks and summing gives a per-row agent count;
        # presumably a mask value of 1 marks an absent agent -- verify
        # against _get_masks_from_nans.
        flipped_masks = 1 - torch.cat(self_attn_masks, dim=1)
        num_agents = torch.sum(flipped_masks, dim=1, keepdim=True)
        batch_max_agents = torch.max(num_agents).item()
        if batch_max_agents > self._current_max_agents:
            # NOTE(review): the Parameter arguments were truncated and have
            # been restored as the new maximum with gradients disabled --
            # confirm.
            self._current_max_agents = torch.nn.Parameter(
                torch.as_tensor(batch_max_agents), requires_grad=False
            )

        # Rescale the count: it is +1 when the current maximum is reached and
        # approaches -1 as the count becomes small relative to that maximum.
        num_agents = num_agents * 2.0 / self._current_max_agents - 1

        encoding = self.linear_encoder(encoded_state)
        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        encoding = torch.cat([encoding, num_agents], dim=1)
        return encoding, memories