def grad_process_and_td_error_fn(
    policy: Policy, optimizer: "torch.optim.Optimizer", loss: TensorType
) -> Dict[str, TensorType]:
    # Clip grads if configured.
    return apply_grad_clipping(policy, optimizer, loss)
def extra_grad_process(self, local_optimizer, loss):
    return apply_grad_clipping(self, local_optimizer, loss)
def extra_grad_process(
    self, optimizer: torch.optim.Optimizer, loss: TensorType
) -> Dict[str, TensorType]:
    # Clip grads if configured.
    return apply_grad_clipping(self, optimizer, loss)
def extra_grad_process(
    self, optimizer: "torch.optim.Optimizer", loss: TensorType
) -> Dict[str, TensorType]:
    return apply_grad_clipping(self, optimizer, loss)
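# All of the hooks above delegate gradient post-processing to
# apply_grad_clipping. The sketch below is a minimal, hypothetical version of
# such a helper, assuming the policy exposes a config dict with an optional
# "grad_clip" setting; the actual RLlib utility may differ in details.
import torch.nn as nn
from typing import Dict


def apply_grad_clipping_sketch(policy, optimizer, loss) -> Dict[str, float]:
    """Clip gradients by global norm if `grad_clip` is configured (sketch)."""
    info = {}
    grad_clip = policy.config.get("grad_clip")
    if grad_clip is not None:
        for param_group in optimizer.param_groups:
            # Only clip parameters that actually received gradients.
            params = [p for p in param_group["params"] if p.grad is not None]
            if params:
                grad_gnorm = nn.utils.clip_grad_norm_(params, grad_clip)
                info["grad_gnorm"] = float(grad_gnorm)
    return info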
def learn_on_batch(self, samples):
    obs_batch, action_mask, env_global_state = self._unpack_observation(
        samples[SampleBatch.CUR_OBS])
    (
        next_obs_batch,
        next_action_mask,
        next_env_global_state,
    ) = self._unpack_observation(samples[SampleBatch.NEXT_OBS])
    group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

    input_list = [
        group_rewards,
        action_mask,
        next_action_mask,
        samples[SampleBatch.ACTIONS],
        samples[SampleBatch.DONES],
        obs_batch,
        next_obs_batch,
    ]
    if self.has_env_global_state:
        input_list.extend([env_global_state, next_env_global_state])

    output_list, _, seq_lens = chop_into_sequences(
        episode_ids=samples[SampleBatch.EPS_ID],
        unroll_ids=samples[SampleBatch.UNROLL_ID],
        agent_indices=samples[SampleBatch.AGENT_INDEX],
        feature_columns=input_list,
        state_columns=[],  # RNN states not used here
        max_seq_len=self.config["model"]["max_seq_len"],
        dynamic_max=True,
    )
    # These will be padded to shape [B * T, ...]
    if self.has_env_global_state:
        (
            rew,
            action_mask,
            next_action_mask,
            act,
            dones,
            obs,
            next_obs,
            env_global_state,
            next_env_global_state,
        ) = output_list
    else:
        (
            rew,
            action_mask,
            next_action_mask,
            act,
            dones,
            obs,
            next_obs,
        ) = output_list
    B, T = len(seq_lens), max(seq_lens)

    def to_batches(arr, dtype):
        new_shape = [B, T] + list(arr.shape[1:])
        return torch.as_tensor(
            np.reshape(arr, new_shape), dtype=dtype, device=self.device)

    rewards = to_batches(rew, torch.float)
    actions = to_batches(act, torch.long)
    obs = to_batches(obs, torch.float).reshape(
        [B, T, self.n_agents, self.obs_size])
    action_mask = to_batches(action_mask, torch.float)
    next_obs = to_batches(next_obs, torch.float).reshape(
        [B, T, self.n_agents, self.obs_size])
    next_action_mask = to_batches(next_action_mask, torch.float)
    if self.has_env_global_state:
        env_global_state = to_batches(env_global_state, torch.float)
        next_env_global_state = to_batches(next_env_global_state, torch.float)

    # TODO(ekl) this treats group termination as individual termination
    terminated = (
        to_batches(dones, torch.float).unsqueeze(2).expand(B, T, self.n_agents))

    # Create mask for where index is < unpadded sequence length
    filled = np.reshape(
        np.tile(np.arange(T, dtype=np.float32), B),
        [B, T]) < np.expand_dims(seq_lens, 1)
    mask = (
        torch.as_tensor(filled, dtype=torch.float, device=self.device)
        .unsqueeze(2)
        .expand(B, T, self.n_agents))

    # Compute loss
    loss_out, mask, masked_td_error, chosen_action_qvals, targets = self.loss(
        rewards,
        actions,
        terminated,
        mask,
        obs,
        next_obs,
        action_mask,
        next_action_mask,
        env_global_state,
        next_env_global_state,
    )

    # Optimise
    self.rmsprop_optimizer.zero_grad()
    loss_out.backward()
    grad_norm_info = apply_grad_clipping(self, self.rmsprop_optimizer, loss_out)
    self.rmsprop_optimizer.step()

    mask_elems = mask.sum().item()
    stats = {
        "loss": loss_out.item(),
        "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
        "q_taken_mean": (chosen_action_qvals * mask).sum().item() / mask_elems,
        "target_mean": (targets * mask).sum().item() / mask_elems,
    }
    stats.update(grad_norm_info)

    return {LEARNER_STATS_KEY: stats}
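# A small, self-contained illustration of the sequence-padding mask built in
# learn_on_batch above: for B sequences padded to length T, time steps beyond
# each sequence's true length are zeroed out so they contribute nothing to the
# loss or the averaged stats. The shapes and names (seq_lens, n_agents) mirror
# the method above, but the data here is made up for demonstration only.
import numpy as np
import torch

B, T, n_agents = 2, 4, 3
seq_lens = np.array([4, 2])  # second sequence carries 2 padded steps

filled = np.reshape(
    np.tile(np.arange(T, dtype=np.float32), B),
    [B, T]) < np.expand_dims(seq_lens, 1)
mask = (
    torch.as_tensor(filled, dtype=torch.float)
    .unsqueeze(2)
    .expand(B, T, n_agents))

# mask[1, 2:] is all zeros, so padded steps drop out of masked averages:
masked_vals = torch.ones(B, T, n_agents) * mask
mean_over_valid = masked_vals.sum().item() / mask.sum().item()  # == 1.0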