def compute_estimate(self,
                     mini_batch: AgentBuffer,
                     use_vail_noise: bool = False
                     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """
     Given a mini_batch, computes the estimate (How much the discriminator believes
     the data was sampled from the demonstration data).
     :param mini_batch: The AgentBuffer of data
     :param use_vail_noise: Only when using VAIL : If true, will sample the code, if
     false, will return the mean of the code.
     """
     vec_inputs, vis_inputs = self.get_state_inputs(mini_batch)
     if self._settings.use_actions:
         actions = self.get_action_input(mini_batch)
         dones = torch.as_tensor(mini_batch["done"],
                                 dtype=torch.float).unsqueeze(1)
         action_inputs = torch.cat([actions, dones], dim=1)
         hidden, _ = self.encoder(vec_inputs, vis_inputs, action_inputs)
     else:
         hidden, _ = self.encoder(vec_inputs, vis_inputs)
     z_mu: Optional[torch.Tensor] = None
     if self._settings.use_vail:
         z_mu = self._z_mu_layer(hidden)
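         # Sample the latent code from N(z_mu, z_sigma); when use_vail_noise is
         # False the std is multiplied by 0, so the mean is returned deterministically.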
         hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
     estimate = self._estimator(hidden)
     return estimate, z_mu
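
For context, here is a minimal sketch of how a discriminator estimate like the one returned above is typically turned into a GAIL-style reward. The helper name and the EPSILON constant are illustrative assumptions, not part of the example; it only assumes the estimator outputs a probability in (0, 1).

import torch

EPSILON = 1e-7  # assumed small constant to keep the log finite

def gail_reward_from_estimate(estimate: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: map a discriminator estimate of shape (batch, 1),
    # with values in (0, 1), to a GAIL-style reward. More "expert-like"
    # estimates yield a higher reward.
    return -torch.log(1.0 - estimate.squeeze(dim=1) * (1.0 - EPSILON))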
Example #2
def compute_gradient_magnitude(self, policy_batch: AgentBuffer,
                               expert_batch: AgentBuffer) -> torch.Tensor:
     """
     Gradient penalty from https://arxiv.org/pdf/1704.00028 (WGAN-GP). Adds
     stability, especially off-policy. Gradients are computed w.r.t. a random
     interpolation between policy and expert inputs.
     """
     policy_inputs = self.get_state_inputs(policy_batch)
     expert_inputs = self.get_state_inputs(expert_batch)
     interp_inputs = []
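     # Blend each policy observation with the corresponding expert observation
     # using elementwise random weights to build the interpolated inputs.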
     for policy_input, expert_input in zip(policy_inputs, expert_inputs):
         obs_epsilon = torch.rand(policy_input.shape)
         interp_input = obs_epsilon * policy_input + (
             1 - obs_epsilon) * expert_input
         interp_input.requires_grad = True  # For gradient calculation
         interp_inputs.append(interp_input)
     if self._settings.use_actions:
         policy_action = self.get_action_input(policy_batch)
         expert_action = self.get_action_input(expert_batch)
         action_epsilon = torch.rand(policy_action.shape)
         policy_dones = torch.as_tensor(policy_batch["done"],
                                        dtype=torch.float).unsqueeze(1)
         expert_dones = torch.as_tensor(expert_batch["done"],
                                        dtype=torch.float).unsqueeze(1)
         dones_epsilon = torch.rand(policy_dones.shape)
         action_inputs = torch.cat(
             [
                 action_epsilon * policy_action +
                 (1 - action_epsilon) * expert_action,
                 dones_epsilon * policy_dones +
                 (1 - dones_epsilon) * expert_dones,
             ],
             dim=1,
         )
         action_inputs.requires_grad = True
         hidden, _ = self.encoder(interp_inputs, action_inputs)
         encoder_input = tuple(interp_inputs + [action_inputs])
     else:
         hidden, _ = self.encoder(interp_inputs)
         encoder_input = tuple(interp_inputs)
     if self._settings.use_vail:
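         # For the penalty, always sample the noisy code (use_vail_noise stays True).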
         use_vail_noise = True
         z_mu = self._z_mu_layer(hidden)
         hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
     estimate = self._estimator(hidden).squeeze(1).sum()
     gradient = torch.autograd.grad(estimate,
                                    encoder_input,
                                    create_graph=True)[0]
     # Norm's gradient could be NaN at 0. Use our own safe_norm
     safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
     gradient_mag = torch.mean((safe_norm - 1)**2)
     return gradient_mag
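
As a rough illustration of how the penalty is usually consumed, the sketch below folds the returned gradient_mag into a standard GAIL discriminator loss. The function name, the penalty_weight default, and the epsilon value are assumptions for illustration, not taken from the example above.

import torch

def discriminator_loss_with_penalty(policy_estimate: torch.Tensor,
                                    expert_estimate: torch.Tensor,
                                    gradient_mag: torch.Tensor,
                                    penalty_weight: float = 10.0) -> torch.Tensor:
    # Hypothetical sketch: binary cross-entropy discriminator loss (expert
    # labelled 1, policy labelled 0) plus the weighted gradient penalty.
    epsilon = 1e-7  # keeps the logs finite
    expert_loss = -torch.log(expert_estimate + epsilon).mean()
    policy_loss = -torch.log(1.0 - policy_estimate + epsilon).mean()
    return expert_loss + policy_loss + penalty_weight * gradient_mag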