Example #1
    def test_nested(self):
        eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
        agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
        f = [
            {
                "a": np.array([1, 2, 3, 4, 13, 14, 15, 16]),
                "b": {"ba": np.array([5, 6, 7, 8, 9, 10, 11, 12])},
            }
        ]
        s = [[209, 208, 207, 109, 108, 107, 106, 105]]

        f_pad, s_init, seq_lens = chop_into_sequences(
            episode_ids=eps_ids,
            unroll_ids=np.ones_like(eps_ids),
            agent_indices=agent_ids,
            feature_columns=f,
            state_columns=s,
            max_seq_len=4,
            handle_nested_data=True,
        )
        check(
            f_pad,
            [
                [
                    [1, 2, 3, 0, 4, 13, 14, 15, 16, 0, 0, 0],
                    [5, 6, 7, 0, 8, 9, 10, 11, 12, 0, 0, 0],
                ]
            ],
        )
        self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
        self.assertEqual(seq_lens.tolist(), [3, 4, 1])
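
To make the padded layout explicit: each flat feature row returned above is a concatenation of fixed-size blocks, one block of max_seq_len entries per sequence. A small numpy illustration (not part of the test) using the expected values from this example:

import numpy as np

# Padded feature row and sequence lengths expected by the test above.
f_pad_row = np.array([1, 2, 3, 0, 4, 13, 14, 15, 16, 0, 0, 0])
seq_lens = np.array([3, 4, 1])
max_seq_len = 4

# Reshape to [num_sequences, max_seq_len]; each row is one padded sequence.
print(f_pad_row.reshape(len(seq_lens), max_seq_len))
# [[ 1  2  3  0]
#  [ 4 13 14 15]
#  [16  0  0  0]]
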
Example #2
    def _get_loss_inputs_dict(self,
                              batch,
                              shuffle,
                              cross_policy_obj,
                              policy_id=None):
        """When training, add the required data into the feed_dict."""
        feed_dict = {}

        if hasattr(self, "_AddLossMixin_initialized"):
            assert self._AddLossMixin_initialized
            # parse the cross-policy info and put in feed_dict.
            joint_obs_ph = self._loss_input_dict[JOINT_OBS]
            feed_dict[joint_obs_ph] = cross_policy_obj[JOINT_OBS]

            replay_ph = self._loss_input_dict[PEER_ACTION]
            feed_dict[replay_ph] = np.concatenate([
                act for pid, act in cross_policy_obj[PEER_ACTION].items()
                if pid != policy_id
            ])  # exclude the policy's own actions
        """The below codes are copied from rllib. """
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
                == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True
        # Simple case: not an RNN and no need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            if shuffle:
                batch.shuffle()
            for k, ph in self._loss_inputs:
                # Note: unlike the original RLlib code, feed only the keys
                # that are actually present in the batch.
                if k in batch:
                    feed_dict[ph] = batch[k]
            return feed_dict
        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False
        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens
        return feed_dict
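
The PEER_ACTION handling above concatenates the actions collected from every other policy while excluding the policy currently being optimized. A standalone sketch with made-up data (cross_policy_actions here stands in for cross_policy_obj[PEER_ACTION]; the keys are hypothetical policy ids):

import numpy as np

cross_policy_actions = {
    "policy_0": np.array([[0], [1]]),
    "policy_1": np.array([[1], [1]]),
    "policy_2": np.array([[0], [0]]),
}
policy_id = "policy_1"

# Concatenate the actions of every policy except the one being optimized.
peer_actions = np.concatenate([
    act for pid, act in cross_policy_actions.items() if pid != policy_id
])
assert peer_actions.shape == (4, 1)
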
Example #3
 def testBatchId(self):
     eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
     batch_ids = [1, 1, 2, 2, 3, 3, 4, 4]
     agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
     f = [[101, 102, 103, 201, 202, 203, 204, 205],
          [[101], [102], [103], [201], [202], [203], [204], [205]]]
     s = [[209, 208, 207, 109, 108, 107, 106, 105]]
     _, _, seq_lens = chop_into_sequences(eps_ids, batch_ids, agent_ids, f,
                                          s, 4)
     self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2])
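
The expected seq_lens in these tests follow a simple rule: a new sequence starts whenever the (episode_id, unroll/batch_id, agent_id) triple changes, or when the current sequence reaches max_seq_len. Below is a simplified, standalone sketch of just that bookkeeping (it ignores the padding, state collection, and shuffling that the real chop_into_sequences also performs):

def sequence_lengths(eps_ids, unroll_ids, agent_ids, max_seq_len):
    """Return the per-sequence lengths implied by the id columns."""
    seq_lens, length, prev = [], 0, None
    for key in zip(eps_ids, unroll_ids, agent_ids):
        if key != prev or length == max_seq_len:
            if length:
                seq_lens.append(length)
            length, prev = 0, key
        length += 1
    seq_lens.append(length)
    return seq_lens

# Reproduces the expectation of the test above.
assert sequence_lengths(
    [1, 1, 1, 5, 5, 5, 5, 5],   # eps_ids
    [1, 1, 2, 2, 3, 3, 4, 4],   # batch_ids
    [1, 1, 1, 1, 1, 1, 1, 1],   # agent_ids
    4) == [2, 1, 1, 2, 2]
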
Example #4
    def test_chop_into_sequences_long_seq(self):
        """Test pad_batch where episodes are longer than max_seq_len. The long
        seq should be split into two smaller seqs that are less than max_seq_len"""
        max_seq_len = 2
        # Input seq lens, corresponding to ep_ids, unroll_ids, etc.
        seq_lens = [2, 3, 1]  # noqa: F841
        ep_ids = [0, 0, 1, 1, 1, 2]
        unroll_ids = [2, 2, 3, 3, 3, 4]
        feats = [[1, 1, 2, 2, 2, 3]]
        # Input states, i.e. states[3] is the input state at t = 3 and the
        # output state at t = 2.
        states = [[1, 2, 3, 4, 5, 6]]
        agent = [0, 0, 0, 0, 0, 0]
        f_pad, s_init, s_lens = chop_into_sequences(
            feature_columns=feats,
            state_columns=states,
            max_seq_len=max_seq_len,
            episode_ids=ep_ids,
            unroll_ids=unroll_ids,
            agent_indices=agent,
            dynamic_max=False,
        )
        expected_f_pad = [[1, 1, 2, 2, 2, 0, 3, 0]]
        expected_seq_lens = [2, 2, 1, 1]
        expected_states = [[1, 3, 5, 6]]
        check(f_pad, expected_f_pad)
        check(s_lens, expected_seq_lens)
        check(s_init, expected_states)

        # Try again with dynamic max
        f_pad, s_init, s_lens = chop_into_sequences(
            feature_columns=feats,
            state_columns=states,
            max_seq_len=max_seq_len,
            episode_ids=ep_ids,
            unroll_ids=unroll_ids,
            agent_indices=agent,
            dynamic_max=True,
        )
        check(f_pad, expected_f_pad)
        check(s_lens, expected_seq_lens)
        check(s_init, expected_states)
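
The expected_states above can also be derived directly from the expected sequence lengths: each chunk's initial state is the input state at the timestep where that chunk begins. A small numpy check (illustrative only):

import numpy as np

states = np.array([1, 2, 3, 4, 5, 6])
seq_lens = np.array([2, 2, 1, 1])

# Start index of each chunk = cumulative offsets of the sequence lengths.
starts = np.concatenate([[0], np.cumsum(seq_lens)[:-1]])  # [0, 2, 4, 5]
assert states[starts].tolist() == [1, 3, 5, 6]  # == expected_states[0]
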
Example #5
 def testDynamicMaxLen(self):
     eps_ids = [5, 2, 2]
     agent_ids = [2, 2, 2]
     f = [[1, 1, 1]]
     s = [[1, 1, 1]]
     f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
                                                   np.ones_like(eps_ids),
                                                   agent_ids, f, s, 4)
     self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
     self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
     self.assertEqual(seq_lens.tolist(), [1, 2])
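
Note that the padded row has length 4 rather than 2 * max_seq_len: with dynamic max padding (the default behaviour this test appears to rely on), each sequence is padded only to the longest sequence actually present in the batch. Illustration:

import numpy as np

f_pad_row = np.array([1, 0, 1, 1])
seq_lens = np.array([1, 2])
dynamic_max_len = int(seq_lens.max())  # 2, not max_seq_len = 4

print(f_pad_row.reshape(len(seq_lens), dynamic_max_len))
# [[1 0]
#  [1 1]]
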
Example #6
 def test_batch_id(self):
     eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
     batch_ids = [1, 1, 2, 2, 3, 3, 4, 4]
     agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
     f = [[101, 102, 103, 201, 202, 203, 204, 205],
          [[101], [102], [103], [201], [202], [203], [204], [205]]]
     s = [[209, 208, 207, 109, 108, 107, 106, 105]]
     _, _, seq_lens = chop_into_sequences(episode_ids=eps_ids,
                                          unroll_ids=batch_ids,
                                          agent_indices=agent_ids,
                                          feature_columns=f,
                                          state_columns=s,
                                          max_seq_len=4)
     self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2])
Example #7
 def testMultiDim(self):
     eps_ids = [1, 1, 1]
     agent_ids = [1, 1, 1]
     obs = np.ones((84, 84, 4))
     f = [[obs, obs * 2, obs * 3]]
     s = [[209, 208, 207]]
     f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
                                                   np.ones_like(eps_ids),
                                                   agent_ids, f, s, 4)
     self.assertEqual([f.tolist() for f in f_pad], [
         np.array([obs, obs * 2, obs * 3]).tolist(),
     ])
     self.assertEqual([s.tolist() for s in s_init], [[209]])
     self.assertEqual(seq_lens.tolist(), [3])
Example #8
 def test_dynamic_max_len(self):
     eps_ids = [5, 2, 2]
     agent_ids = [2, 2, 2]
     f = [[1, 1, 1]]
     s = [[1, 1, 1]]
     f_pad, s_init, seq_lens = chop_into_sequences(
         episode_ids=eps_ids,
         unroll_ids=np.ones_like(eps_ids),
         agent_indices=agent_ids,
         feature_columns=f,
         state_columns=s,
         max_seq_len=4)
     self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
     self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
     self.assertEqual(seq_lens.tolist(), [1, 2])
Example #9
 def testMultiAgent(self):
     eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
     agent_ids = [1, 1, 2, 1, 1, 2, 2, 3]
     f = [[101, 102, 103, 201, 202, 203, 204, 205],
          [[101], [102], [103], [201], [202], [203], [204], [205]]]
     s = [[209, 208, 207, 109, 108, 107, 106, 105]]
     f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
                                                   np.ones_like(eps_ids),
                                                   agent_ids,
                                                   f,
                                                   s,
                                                   4,
                                                   dynamic_max=False)
     self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1])
     self.assertEqual(len(f_pad[0]), 20)
     self.assertEqual(len(s_init[0]), 5)
Example #10
 def testBasic(self):
     eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
     agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
     f = [[101, 102, 103, 201, 202, 203, 204, 205],
          [[101], [102], [103], [201], [202], [203], [204], [205]]]
     s = [[209, 208, 207, 109, 108, 107, 106, 105]]
     f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
                                                   np.ones_like(eps_ids),
                                                   agent_ids, f, s, 4)
     self.assertEqual([f.tolist() for f in f_pad], [
         [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
         [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
          [0], [0]],
     ])
     self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
     self.assertEqual(seq_lens.tolist(), [3, 4, 1])
Example #11
 def test_multi_agent(self):
     eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
     agent_ids = [1, 1, 2, 1, 1, 2, 2, 3]
     f = [[101, 102, 103, 201, 202, 203, 204, 205],
          [[101], [102], [103], [201], [202], [203], [204], [205]]]
     s = [[209, 208, 207, 109, 108, 107, 106, 105]]
     f_pad, s_init, seq_lens = chop_into_sequences(
         episode_ids=eps_ids,
         unroll_ids=np.ones_like(eps_ids),
         agent_indices=agent_ids,
         feature_columns=f,
         state_columns=s,
         max_seq_len=4,
         dynamic_max=False)
     self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1])
     self.assertEqual(len(f_pad[0]), 20)
     self.assertEqual(len(s_init[0]), 5)
Example #12
 def test_multi_dim(self):
     eps_ids = [1, 1, 1]
     agent_ids = [1, 1, 1]
     obs = np.ones((84, 84, 4))
     f = [[obs, obs * 2, obs * 3]]
     s = [[209, 208, 207]]
     f_pad, s_init, seq_lens = chop_into_sequences(
         episode_ids=eps_ids,
         unroll_ids=np.ones_like(eps_ids),
         agent_indices=agent_ids,
         feature_columns=f,
         state_columns=s,
         max_seq_len=4)
     self.assertEqual([f.tolist() for f in f_pad], [
         np.array([obs, obs * 2, obs * 3]).tolist(),
     ])
     self.assertEqual([s.tolist() for s in s_init], [[209]])
     self.assertEqual(seq_lens.tolist(), [3])
Example #13
 def test_basic(self):
     eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
     agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
     f = [[101, 102, 103, 201, 202, 203, 204, 205],
          [[101], [102], [103], [201], [202], [203], [204], [205]]]
     s = [[209, 208, 207, 109, 108, 107, 106, 105]]
     f_pad, s_init, seq_lens = chop_into_sequences(
         episode_ids=eps_ids,
         unroll_ids=np.ones_like(eps_ids),
         agent_indices=agent_ids,
         feature_columns=f,
         state_columns=s,
         max_seq_len=4)
     self.assertEqual([f.tolist() for f in f_pad], [
         [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
         [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
          [0], [0]],
     ])
     self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
     self.assertEqual(seq_lens.tolist(), [3, 4, 1])
Example #14
    def _get_loss_inputs_dict(self, batch, shuffle):
        """Return a feed dict from a batch.

        Arguments:
            batch (SampleBatch): batch of data to derive inputs from
            shuffle (bool): whether to shuffle batch sequences. Shuffle may
                be done in-place. This only makes sense if you're further
                applying minibatch SGD after getting the outputs.

        Returns:
            feed dict of data
        """

        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
                == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        # Simple case: not an RNN and no need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            if shuffle:
                batch.shuffle()
            for k, ph in self._loss_inputs:
                feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
Example #15
    def learn_on_batch(self, samples):
        obs_batch, action_mask = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        next_obs_batch, next_action_mask = self._unpack_observation(
            samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        # These will be padded to shape [B * T, ...]
        [rew, action_mask, next_action_mask, act, dones, obs, next_obs], \
            initial_states, seq_lens = \
            chop_into_sequences(
                samples[SampleBatch.EPS_ID],
                samples[SampleBatch.UNROLL_ID],
                samples[SampleBatch.AGENT_INDEX], [
                    group_rewards, action_mask, next_action_mask,
                    samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
                    obs_batch, next_obs_batch
                ],
                [samples["state_in_{}".format(k)]
                 for k in range(len(self.get_initial_state()))],
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True)
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr):
            new_shape = [B, T] + list(arr.shape[1:])
            return th.from_numpy(np.reshape(arr, new_shape))

        rewards = to_batches(rew).float()
        actions = to_batches(act).long()
        obs = to_batches(obs).reshape([B, T, self.n_agents,
                                       self.obs_size]).float()
        action_mask = to_batches(action_mask)
        next_obs = to_batches(next_obs).reshape(
            [B, T, self.n_agents, self.obs_size]).float()
        next_action_mask = to_batches(next_action_mask)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
            B, T, self.n_agents)

        # Create mask for where index is < unpadded sequence length
        filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                  np.expand_dims(seq_lens, 1)).astype(np.float32)
        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, self.n_agents)

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
            self.loss(rewards, actions, terminated, mask, obs,
                      next_obs, action_mask, next_action_mask)

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}
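
For reference, the filled mask above marks which of the B * T padded slots correspond to real timesteps. A standalone numpy illustration with made-up sequence lengths:

import numpy as np

seq_lens = np.array([3, 1])      # two sequences, padded to T = 3 steps
B, T = len(seq_lens), int(seq_lens.max())

filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
          np.expand_dims(seq_lens, 1)).astype(np.float32)
# [[1. 1. 1.]
#  [1. 0. 0.]]
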
Example #16
    def learn_on_batch(self, samples):
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        (next_obs_batch, next_action_mask,
         next_env_global_state) = self._unpack_observation(
             samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        input_list = [
            group_rewards, action_mask, next_action_mask,
            samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
            obs_batch, next_obs_batch
        ]
        if self.has_env_global_state:
            input_list.extend([env_global_state, next_env_global_state])

        output_list, _, seq_lens = \
            chop_into_sequences(
                samples[SampleBatch.EPS_ID],
                samples[SampleBatch.UNROLL_ID],
                samples[SampleBatch.AGENT_INDEX],
                input_list,
                [],  # RNN states not used here
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True)
        # These will be padded to shape [B * T, ...]
        if self.has_env_global_state:
            (rew, action_mask, next_action_mask, act, dones, obs, next_obs,
             env_global_state, next_env_global_state) = output_list
        else:
            (rew, action_mask, next_action_mask, act, dones, obs,
             next_obs) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            new_shape = [B, T] + list(arr.shape[1:])
            return torch.as_tensor(
                np.reshape(arr, new_shape), dtype=dtype, device=self.device)

        rewards = to_batches(rew, torch.float)
        actions = to_batches(act, torch.long)
        obs = to_batches(obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        action_mask = to_batches(action_mask, torch.float)
        next_obs = to_batches(next_obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        next_action_mask = to_batches(next_action_mask, torch.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, torch.float)
            next_env_global_state = to_batches(next_env_global_state,
                                               torch.float)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones, torch.float).unsqueeze(2).expand(
            B, T, self.n_agents)

        # Create mask for where index is < unpadded sequence length
        filled = np.reshape(
            np.tile(np.arange(T, dtype=np.float32), B),
            [B, T]) < np.expand_dims(seq_lens, 1)
        mask = torch.as_tensor(
            filled, dtype=torch.float, device=self.device).unsqueeze(2).expand(
                B, T, self.n_agents)

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = (
            self.loss(rewards, actions, terminated, mask, obs, next_obs,
                      action_mask, next_action_mask, env_global_state,
                      next_env_global_state))

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}
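
The to_batches helper above only reshapes the flat [B * T, ...] arrays returned by chop_into_sequences into [B, T, ...] tensors on the right device. A minimal standalone sketch of that step (shapes and values are made up, and the device argument is omitted):

import numpy as np
import torch

B, T, feat = 2, 3, 4
arr = np.arange(B * T * feat, dtype=np.float32).reshape(B * T, feat)

new_shape = [B, T] + list(arr.shape[1:])
batched = torch.as_tensor(np.reshape(arr, new_shape), dtype=torch.float)
assert batched.shape == (B, T, feat)
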
Example #17
    def _get_loss_inputs_dict(self, batch, neighbor_batch_dic, shuffle):
        """Return a feed dict from a batch.

        Arguments:
            batch (SampleBatch): batch of data to derive inputs from
            neighbor_batch_dic (dict of SampleBatch): batches of data for the
                neighbors of the main policy
            shuffle (bool): whether to shuffle batch sequences. Shuffle may
                be done in-place. This only makes sense if you're further
                applying minibatch SGD after getting the outputs.

        Returns:
            feed dict of data
        """

        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
                == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        neighbor_list = [None] * 5
        neighbor_count = 0
        for k in neighbor_batch_dic:
            neighbor_list[neighbor_count] = k
            neighbor_count += 1

        tmp_dic = {}
        # Simple case: not an RNN and no need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            if shuffle:
                batch.shuffle()
            for k, ph in self._loss_inputs:
                # For attention: this is the replay buffer used for the Q
                # target; assemble the neighbor_obs information from the
                # SampleBatch of each neighbor.
                if 'neighbor_obs' in k:
                    tmp_dic[ph] = {}
                    feed_dict[ph] = []
                    # neighbor_id = neighbor_list[int(k.split('_')[2])]
                    # if neighbor_id is None:
                    #     continue
                    # else:
                    for neighbor_id in neighbor_list:
                        tmp_dic[ph][neighbor_id] = neighbor_batch_dic[
                            neighbor_id]['obs']
                    neighbor_id = neighbor_list[0]
                    # [neighbor, batch, feature] -> [batch, neighbor, feature]
                    for batch_item in range(len(tmp_dic[ph][neighbor_id])):
                        feed_dict[ph].append([])
                        for neighbor_id in neighbor_list:
                            feed_dict[ph][batch_item].append(
                                tmp_dic[ph][neighbor_id][batch_item])
                    feed_dict[ph] = np.array(feed_dict[ph])
                    # feed_dict[ph] = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 1)))(tmp_dic[ph])
                # ----------------------------------------------------------------
                else:
                    feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max,
            shuffle=shuffle)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
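
The nested loops in the 'neighbor_obs' branch above rearrange the neighbor observations from [neighbor, batch, feature] to [batch, neighbor, feature]. An equivalent vectorized sketch with made-up shapes (shown only as an illustration, not as a drop-in replacement):

import numpy as np

num_neighbors, batch_size, feature_dim = 5, 8, 3
neighbor_obs = {
    nid: np.random.rand(batch_size, feature_dim)
    for nid in range(num_neighbors)
}

stacked = np.stack([neighbor_obs[nid] for nid in neighbor_obs], axis=0)
# stacked has shape [neighbor, batch, feature]; swap the first two axes.
rearranged = np.transpose(stacked, (1, 0, 2))
assert rearranged.shape == (batch_size, num_neighbors, feature_dim)
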
Example #18
def build_q_losses_recurrent(policy, model, dist_class, samples):
    # Observations
    obs_batch, action_mask, _, _ = unpack_train_observations(
        policy, samples[SampleBatch.CUR_OBS], policy.device)
    next_obs_batch, next_action_mask, _, _ = unpack_train_observations(
        policy, samples[SampleBatch.NEXT_OBS], policy.device)
    rewards = samples[SampleBatch.REWARDS]

    # Obtain sequences
    input_list = [
        rewards, action_mask, next_action_mask, samples[SampleBatch.ACTIONS],
        samples[SampleBatch.DONES], obs_batch, next_obs_batch
    ]
    output_list, _, seq_lens = \
        chop_into_sequences(
            episode_ids=samples[SampleBatch.EPS_ID],
            unroll_ids=samples[SampleBatch.UNROLL_ID],
            agent_indices=samples[SampleBatch.AGENT_INDEX],
            feature_columns=input_list,
            state_columns=[],  # RNN states not used here
            max_seq_len=policy.config["model"]["max_seq_len"],
            dynamic_max=True)
    (rew, action_mask, next_action_mask, act, dones, obs,
     next_obs) = output_list

    B, T = len(seq_lens), max(seq_lens)

    def to_batches(arr, dtype):
        new_shape = [B, T] + list(arr.shape[1:])
        return torch.as_tensor(np.reshape(arr, new_shape),
                               dtype=dtype,
                               device=policy.device)

    rewards = to_batches(rew, torch.float)
    actions = to_batches(act, torch.long)
    obs = to_batches(obs, torch.float).reshape([B, T, -1])
    action_mask = to_batches(action_mask, torch.float)
    next_obs = to_batches(next_obs, torch.float).reshape([B, T, -1])
    next_action_mask = to_batches(next_action_mask, torch.float)
    terminated = to_batches(dones, torch.float).unsqueeze(2).expand(B, T, 1)
    filled = np.reshape(np.tile(np.arange(T, dtype=np.float32), B),
                        [B, T]) < np.expand_dims(seq_lens, 1)
    mask = torch.as_tensor(filled, dtype=torch.float,
                           device=policy.device).unsqueeze(2).expand(B, T, 1)

    q_t = compute_sequence_q_values(policy,
                                    policy.q_model,
                                    obs,
                                    explore=False,
                                    is_training=True)
    chosen_action_qvals = torch.gather(q_t,
                                       dim=-1,
                                       index=actions.unsqueeze(-1))

    q_tp1 = compute_sequence_q_values(policy,
                                      policy.target_q_model,
                                      next_obs,
                                      explore=False,
                                      is_training=True)
    ignore_action_tp1 = (next_action_mask == 0) & (mask == 1)
    q_tp1[ignore_action_tp1] = -np.inf
    target_max_qvals = q_tp1.max(dim=-1)[0]
    targets = rewards.squeeze(-1) + policy.config["gamma"] * (
        1 - terminated.squeeze(-1)) * target_max_qvals
    td_error = (chosen_action_qvals.squeeze(-1) - targets.detach())
    mask = mask.squeeze(-1).expand_as(td_error)
    masked_td_error = td_error * mask
    loss = (masked_td_error**2).sum() / mask.sum()
    policy.td_error = td_error
    return loss
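
The final loss is the mean of squared TD errors over the valid (unpadded) timesteps only, which the mask enforces. A tiny worked example with made-up numbers:

import torch

td_error = torch.tensor([[0.5, -1.0, 2.0],
                         [1.0,  3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 1.0],
                     [1.0, 0.0, 0.0]])  # second sequence has length 1

masked_td_error = td_error * mask
loss = (masked_td_error ** 2).sum() / mask.sum()
# (0.25 + 1.0 + 4.0 + 1.0) / 4 = 1.5625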