# Imports assumed by this excerpt (module paths follow the ML-Agents layout
# these methods appear to come from):
from typing import Optional, Tuple

import torch

from mlagents.trainers.buffer import AgentBuffer, BufferKey


def compute_estimate(
    self, mini_batch: AgentBuffer, use_vail_noise: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Given a mini_batch, computes the estimate (how much the discriminator
    believes the data was sampled from the demonstration data).
    :param mini_batch: The AgentBuffer of data
    :param use_vail_noise: Only used with VAIL. If true, samples the code;
        if false, uses the mean of the code.
    :return: The estimate and, when using VAIL, the mean of the code
        (None otherwise).
    """
    inputs = self.get_state_inputs(mini_batch)
    if self._settings.use_actions:
        actions = self.get_action_input(mini_batch)
        dones = torch.as_tensor(
            mini_batch[BufferKey.DONE], dtype=torch.float
        ).unsqueeze(1)
        action_inputs = torch.cat([actions, dones], dim=1)
        hidden, _ = self.encoder(inputs, action_inputs)
    else:
        hidden, _ = self.encoder(inputs)
    z_mu: Optional[torch.Tensor] = None
    if self._settings.use_vail:
        z_mu = self._z_mu_layer(hidden)
        # Reparameterized sample from the VAIL code distribution; multiplying
        # the noise by the bool use_vail_noise zeroes it out when False,
        # which reduces the sample to the mean of the code.
        hidden = z_mu + torch.randn_like(z_mu) * self._z_sigma * use_vail_noise
    estimate = self._estimator(hidden)
    return estimate, z_mu
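# For context, a minimal sketch of how an estimate in (0, 1) like the one above
# is typically turned into a GAIL-style reward. The helper name gail_reward and
# the epsilon value are illustrative assumptions, not part of this class:


def gail_reward(estimate: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:
    # Higher estimate -> the data looks more like the demonstrations, so the
    # agent receives a larger reward. epsilon keeps the log() finite when the
    # estimate saturates at 1.
    return -torch.log(1.0 - estimate * (1.0 - epsilon))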
def compute_gradient_magnitude(
    self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> torch.Tensor:
    """
    Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability,
    especially for off-policy. Computes the gradient of the estimate with
    respect to randomly interpolated inputs.
    """
    policy_inputs = self.get_state_inputs(policy_batch)
    expert_inputs = self.get_state_inputs(expert_batch)
    # Interpolate each observation between a policy sample and an expert
    # sample with a uniform random weight.
    interp_inputs = []
    for policy_input, expert_input in zip(policy_inputs, expert_inputs):
        obs_epsilon = torch.rand(policy_input.shape)
        interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
        interp_input.requires_grad = True  # For gradient calculation
        interp_inputs.append(interp_input)
    if self._settings.use_actions:
        policy_action = self.get_action_input(policy_batch)
        expert_action = self.get_action_input(expert_batch)
        action_epsilon = torch.rand(policy_action.shape)
        policy_dones = torch.as_tensor(
            policy_batch[BufferKey.DONE], dtype=torch.float
        ).unsqueeze(1)
        expert_dones = torch.as_tensor(
            expert_batch[BufferKey.DONE], dtype=torch.float
        ).unsqueeze(1)
        dones_epsilon = torch.rand(policy_dones.shape)
        action_inputs = torch.cat(
            [
                action_epsilon * policy_action
                + (1 - action_epsilon) * expert_action,
                dones_epsilon * policy_dones + (1 - dones_epsilon) * expert_dones,
            ],
            dim=1,
        )
        action_inputs.requires_grad = True
        hidden, _ = self.encoder(interp_inputs, action_inputs)
        encoder_input = tuple(interp_inputs + [action_inputs])
    else:
        hidden, _ = self.encoder(interp_inputs)
        encoder_input = tuple(interp_inputs)
    if self._settings.use_vail:
        use_vail_noise = True
        z_mu = self._z_mu_layer(hidden)
        hidden = z_mu + torch.randn_like(z_mu) * self._z_sigma * use_vail_noise
    estimate = self._estimator(hidden).squeeze(1).sum()
    gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
    # Norm's gradient could be NaN at 0. Use our own safe_norm instead.
    safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
    gradient_mag = torch.mean((safe_norm - 1) ** 2)
    return gradient_mag
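# For contrast with the method above, a self-contained sketch of the same
# WGAN-GP-style penalty on a generic discriminator module. All names below are
# illustrative assumptions; the class method above is the version actually
# used here:


def gradient_penalty_sketch(
    disc: torch.nn.Module, policy_x: torch.Tensor, expert_x: torch.Tensor
) -> torch.Tensor:
    # Interpolate uniformly between policy and expert samples, as in
    # https://arxiv.org/pdf/1704.00028.
    epsilon = torch.rand(policy_x.shape[0], 1)
    interp = epsilon * policy_x + (1 - epsilon) * expert_x
    interp.requires_grad_(True)
    estimate = disc(interp).sum()
    gradient = torch.autograd.grad(estimate, interp, create_graph=True)[0]
    # Same safe-norm trick as above: .norm() has a NaN gradient at exactly 0.
    safe_norm = (torch.sum(gradient ** 2, dim=1) + 1e-7).sqrt()
    return torch.mean((safe_norm - 1) ** 2)


# Typical usage: add coefficient * gradient_penalty_sketch(...) to the
# discriminator loss (the WGAN-GP paper uses a coefficient of 10).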
def sample(self) -> torch.Tensor:
    # Reparameterized sample: mean + eps * std with eps ~ N(0, I), so the
    # result stays differentiable with respect to self.mean and self.std.
    sample = self.mean + torch.randn_like(self.mean) * self.std
    return sample
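# sample() above is the standard reparameterization trick: writing the sample
# as mean + eps * std keeps it differentiable with respect to the distribution
# parameters. A standalone illustration (the helper and shapes below are
# hypothetical):


def reparameterization_check() -> None:
    mean = torch.zeros(4, requires_grad=True)
    std = torch.ones(4, requires_grad=True)
    out = mean + torch.randn_like(mean) * std
    # Gradients flow back through the sample to both parameters.
    out.sum().backward()
    assert mean.grad is not None and std.grad is not None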