Example #1
def grad_process_and_td_error_fn(policy: Policy,
                                 optimizer: "torch.optim.Optimizer",
                                 loss: TensorType) -> Dict[str, TensorType]:
    # Clip grads if configured.
    return apply_grad_clipping(policy, optimizer, loss)
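Each of the examples on this page delegates the actual work to RLlib's apply_grad_clipping helper. For readers who want the gist without opening the RLlib source, the snippet below is a minimal, self-contained sketch of the general pattern such a helper follows: loop over the optimizer's parameter groups and apply torch.nn.utils.clip_grad_norm_ when a clip value is configured. The helper name clip_gradients_sketch, the grad_clip argument, and the returned stats keys are made up for illustration; this is not the RLlib implementation.

from typing import Dict, Optional

import torch


def clip_gradients_sketch(optimizer: torch.optim.Optimizer,
                          grad_clip: Optional[float]) -> Dict[str, float]:
    # Hypothetical helper, loosely modeled on the pattern above: clip the
    # global grad norm of each param group if a clip value is configured,
    # and report the pre-clip norms as stats.
    info = {}
    if grad_clip is not None:
        for i, param_group in enumerate(optimizer.param_groups):
            params = [p for p in param_group["params"] if p.grad is not None]
            if params:
                norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
                info["grad_gnorm_group_{}".format(i)] = float(norm)
    return info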
Example #2
def extra_grad_process(self, local_optimizer, loss):
    # Clip grads if configured.
    return apply_grad_clipping(self, local_optimizer, loss)
Example #3
def extra_grad_process(
    self, optimizer: torch.optim.Optimizer, loss: TensorType
) -> Dict[str, TensorType]:
    # Clip grads if configured.
    return apply_grad_clipping(self, optimizer, loss)
Example #4
def extra_grad_process(self, optimizer: "torch.optim.Optimizer",
                       loss: TensorType) -> Dict[str, TensorType]:
    # Clip grads if configured.
    return apply_grad_clipping(self, optimizer, loss)
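In every variant above the hook has the same shape: it receives the policy's torch optimizer and the already-computed loss, and returns a dict of gradient stats. The sketch below shows where such a hook typically sits in a policy's update step; training_step_sketch is a hypothetical outline for illustration, not the actual RLlib TorchPolicy loop.

import torch


def training_step_sketch(policy, optimizer: torch.optim.Optimizer,
                         loss: torch.Tensor) -> dict:
    # Hypothetical outline: gradients are computed first, the hook may
    # clip them and return stats, and only then is the update applied.
    optimizer.zero_grad()
    loss.backward()
    grad_info = policy.extra_grad_process(optimizer, loss)
    optimizer.step()
    return grad_info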
Example #5
    def learn_on_batch(self, samples):
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        (
            next_obs_batch,
            next_action_mask,
            next_env_global_state,
        ) = self._unpack_observation(samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        # The order here must match the unpacking of output_list below.
        input_list = [
            group_rewards,
            action_mask,
            next_action_mask,
            samples[SampleBatch.ACTIONS],
            samples[SampleBatch.DONES],
            obs_batch,
            next_obs_batch,
        ]
        if self.has_env_global_state:
            input_list.extend([env_global_state, next_env_global_state])

        output_list, _, seq_lens = chop_into_sequences(
            episode_ids=samples[SampleBatch.EPS_ID],
            unroll_ids=samples[SampleBatch.UNROLL_ID],
            agent_indices=samples[SampleBatch.AGENT_INDEX],
            feature_columns=input_list,
            state_columns=[],  # RNN states not used here
            max_seq_len=self.config["model"]["max_seq_len"],
            dynamic_max=True,
        )
        # These will be padded to shape [B * T, ...]
        if self.has_env_global_state:
            (
                rew,
                action_mask,
                next_action_mask,
                act,
                dones,
                obs,
                next_obs,
                env_global_state,
                next_env_global_state,
            ) = output_list
        else:
            (
                rew,
                action_mask,
                next_action_mask,
                act,
                dones,
                obs,
                next_obs,
            ) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            # Fold the flat [B * T, ...] arrays back into [B, T, ...]
            # tensors on the policy's device.
            new_shape = [B, T] + list(arr.shape[1:])
            return torch.as_tensor(np.reshape(arr, new_shape),
                                   dtype=dtype,
                                   device=self.device)

        rewards = to_batches(rew, torch.float)
        actions = to_batches(act, torch.long)
        obs = to_batches(obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        action_mask = to_batches(action_mask, torch.float)
        next_obs = to_batches(next_obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        next_action_mask = to_batches(next_action_mask, torch.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, torch.float)
            next_env_global_state = to_batches(next_env_global_state,
                                               torch.float)

        # TODO(ekl) this treats group termination as individual termination
        terminated = (to_batches(dones, torch.float).unsqueeze(2).expand(
            B, T, self.n_agents))

        # Create mask for where index is < unpadded sequence length
        filled = np.reshape(np.tile(np.arange(T, dtype=np.float32), B),
                            [B, T]) < np.expand_dims(seq_lens, 1)
        mask = (torch.as_tensor(filled, dtype=torch.float,
                                device=self.device).unsqueeze(2).expand(
                                    B, T, self.n_agents))

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = self.loss(
            rewards,
            actions,
            terminated,
            mask,
            obs,
            next_obs,
            action_mask,
            next_action_mask,
            env_global_state,
            next_env_global_state,
        )

        # Optimise
        self.rmsprop_optimizer.zero_grad()
        loss_out.backward()
        grad_norm_info = apply_grad_clipping(self, self.rmsprop_optimizer,
                                             loss_out)
        self.rmsprop_optimizer.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean":
            (chosen_action_qvals * mask).sum().item() / mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        stats.update(grad_norm_info)

        return {LEARNER_STATS_KEY: stats}
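The sequence mask built near the end of learn_on_batch is compact, so here is a toy illustration (with assumed sequence lengths of 3 and 1) of what that expression produces: a [B, T] tensor that is 1.0 for real timesteps and 0.0 for padding, before it is expanded over the agent dimension.

import numpy as np
import torch

# Assumed toy sequence lengths; two episodes padded to T = 3.
seq_lens = np.array([3, 1])
B, T = len(seq_lens), int(max(seq_lens))
filled = np.reshape(np.tile(np.arange(T, dtype=np.float32), B),
                    [B, T]) < np.expand_dims(seq_lens, 1)
mask = torch.as_tensor(filled, dtype=torch.float)
# mask == tensor([[1., 1., 1.],
#                 [1., 0., 0.]])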