Ejemplo n.º 1
0
    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        """Attach discounted returns and GAE advantages to a rollout batch.

        The parent-class preprocessing runs first. Afterwards the reward
        offset is added to the rewards in place, both option fields are
        passed through ``int2one_hot``, and the value estimated for the last
        next-observation (``obs_[-1]``) is used as the bootstrap value.

        Args:
            BATCH: rollout batch; per the signature comment it is laid out
                as [T, B, *] (presumably time-major -- confirm with caller).

        Returns:
            The same ``BATCH`` with ``discounted_reward`` and ``gae_adv``
            added (the latter computed with ``normalize=True``).
        """
        BATCH = super()._preprocess_BATCH(BATCH)
        BATCH.reward += BATCH.reward_offset

        BATCH.last_options = int2one_hot(BATCH.last_options, self.options_num)
        BATCH.options = int2one_hot(BATCH.options, self.options_num)

        # Value of the final next-observation under the final option; it
        # seeds the return and closes the TD target of the last step.
        bootstrap_value = self._get_value(
            BATCH.obs_[-1], BATCH.options[-1], rnncs=self.rnncs)

        BATCH.discounted_reward = discounted_sum(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            BATCH.begin_mask,
            init_value=bootstrap_value,
        )

        # Next-step values: the stored values shifted by one step, with the
        # bootstrap value appended along the leading axis.
        shifted_value = np.concatenate(
            (BATCH.value[1:], bootstrap_value[np.newaxis, :]), 0)
        td_error = calculate_td_error(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            value=BATCH.value,
            next_value=shifted_value,
        )

        BATCH.gae_adv = discounted_sum(
            td_error,
            self.lambda_ * self.gamma,
            BATCH.done,
            BATCH.begin_mask,
            init_value=0.,
            normalize=True,
        )
        return BATCH
Ejemplo n.º 2
0
 def _preprocess_BATCH(self, BATCH_DICT):  # [B, *] or [T, B, *]
     """Prepare per-agent batches: encode actions and extend observations.

     For every agent whose action space is not continuous, the stored
     action is replaced by the result of ``int2one_hot``. Then, depending
     on ``_obs_with_pre_action`` and ``_obs_with_agent_id``, an ``other``
     entry is written into both ``obs`` and ``obs_`` of each agent:

     * previous action: ``obs`` gets the actions shifted by one step along
       the leading axis (zeros in the first slot), ``obs_`` gets the
       actions themselves;
     * agent id: the ``int2one_hot`` encoding of the agent's position in
       ``self.agent_ids``, concatenated on the last axis when both flags
       are set.

     Args:
         BATCH_DICT: mapping from agent id to that agent's batch; each
             batch exposes ``action``, ``obs`` and ``obs_``.

     Returns:
         The same ``BATCH_DICT``, modified in place.
     """
     for agent_id in self.agent_ids:
         if not self.is_continuouss[agent_id]:
             # [T, B, 1] or [T, B] => [T, B, N]
             BATCH_DICT[agent_id].action = int2one_hot(
                 BATCH_DICT[agent_id].action, self.a_dims[agent_id])

     for i, agent_id in enumerate(self.agent_ids):
         agent_batch = BATCH_DICT[agent_id]
         other, other_ = None, None
         if self._obs_with_pre_action:
             # The first step has no predecessor, so it is padded with zeros.
             other = np.concatenate((
                 np.zeros_like(agent_batch.action[:1]),
                 agent_batch.action[:-1]
             ), 0)
             other_ = agent_batch.action
         if self._obs_with_agent_id:
             _id_onehot = int2one_hot(
                 np.full(agent_batch.action.shape[:-1], i),
                 self.n_agents_percopy)
             if other is not None:
                 other = np.concatenate((other, _id_onehot), -1)
                 other_ = np.concatenate((other_, _id_onehot), -1)
             else:
                 other, other_ = _id_onehot, _id_onehot
         if self._obs_with_pre_action or self._obs_with_agent_id:
             agent_batch.obs.update(other=other)
             agent_batch.obs_.update(other=other_)
     return BATCH_DICT
Ejemplo n.º 3
0
 def _preprocess_obs(self, obs):
     """Optionally attach the previous action to an observation.

     When ``_obs_with_pre_action`` is set, ``obs`` receives an ``other``
     entry holding ``self._pre_act``; for non-continuous action spaces the
     value is first passed through ``int2one_hot`` with ``self.a_dim``.

     Args:
         obs: observation container supporting ``update(other=...)``.

     Returns:
         The same ``obs`` object (updated in place when the flag is set).
     """
     if not self._obs_with_pre_action:
         return obs

     previous_action = (
         self._pre_act
         if self.is_continuous
         else int2one_hot(self._pre_act, self.a_dim)
     )
     obs.update(other=previous_action)
     return obs
Ejemplo n.º 4
0
 def convert_action2one_hot(self, a_counts):
     '''
     Convert the discrete action indices stored in the buffer to one-hot
     form before training. Does nothing when the buffer has no 'a' entry.
     '''
     if 'a' not in self.buffer.keys():
         return

     encoded_actions = []
     for action_indices in self.buffer['a']:
         encoded_actions.append(
             int2one_hot(action_indices.astype(np.int32), a_counts))
     self.buffer['a'] = encoded_actions
Ejemplo n.º 5
0
 def convert_action2one_hot(self, a_counts):
     '''
     Convert the discrete action indices stored in the data buffer to
     one-hot form before training. The buffer must contain an 'action'
     entry; otherwise the assertion below fails.
     '''
     has_actions = 'action' in self.data_buffer.keys()
     assert has_actions, "assert 'action' in self.data_buffer.keys()"

     stored_actions = self.data_buffer['action']
     self.data_buffer['action'] = list(
         int2one_hot(action_indices.astype(np.int32), a_counts)
         for action_indices in stored_actions
     )
Ejemplo n.º 6
0
 def _data_process2dict(self, exps: BatchExperiences) -> BatchExperiences:
     """Encode discrete actions and convert a sampled batch of experiences.

     For non-continuous action spaces the ``action`` field is cast to
     int32 and replaced by its ``int2one_hot`` encoding. The batch must
     carry ``obs`` and ``obs_`` fields. The result is whatever
     ``NamedTupleStaticClass.data_convert`` returns for ``self.data_convert``
     and the (possibly updated) batch.

     Args:
         exps: namedtuple-like batch (exposes ``_fields`` and ``_replace``).

     Returns:
         The converted batch.
     """
     # TODO: optimize
     if not self.is_continuous:
         assert 'action' in exps._fields, "assert 'action' in exps._fields"
         one_hot_action = int2one_hot(exps.action.astype(np.int32), self.a_dim)
         exps = exps._replace(action=one_hot_action)

     has_both_obs = 'obs' in exps._fields and 'obs_' in exps._fields
     assert has_both_obs, "'obs' in exps._fields and 'obs_' in exps._fields"

     # NOTE: vector-observation normalization is currently disabled:
     # exps = exps._replace(
     #     obs=exps.obs._replace(vector=self.normalize_vector_obs()),
     #     obs_=exps.obs_._replace(vector=self.normalize_vector_obs()))
     converted = NamedTupleStaticClass.data_convert(self.data_convert, exps)
     return converted
Ejemplo n.º 7
0
 def _preprocess_obs(self, obs: Dict):
     """Attach previous-action and/or agent-id features to each agent's obs.

     For every agent in ``self.agent_ids`` an ``other`` entry is written
     into ``obs[agent]`` when at least one of ``_obs_with_pre_action`` /
     ``_obs_with_agent_id`` is set. The previous action is one-hot encoded
     for non-continuous agents; the agent-id feature is the ``int2one_hot``
     encoding of the agent's position, built for ``self.n_copies`` entries.
     When both are present they are concatenated on the last axis.

     Args:
         obs: mapping from agent id to that agent's observation container.

     Returns:
         The same ``obs`` mapping, updated in place.
     """
     for agent_index, agent in enumerate(self.agent_ids):
         extra = None
         if self._obs_with_pre_action:
             extra = (
                 self._pre_acts[agent]
                 if self.is_continuouss[agent]
                 else int2one_hot(self._pre_acts[agent], self.a_dims[agent])
             )
         if self._obs_with_agent_id:
             agent_tag = int2one_hot(
                 np.full(self.n_copies, agent_index), self.n_agents_percopy)
             if extra is None:
                 extra = agent_tag
             else:
                 extra = np.concatenate((extra, agent_tag), -1)
         if self._obs_with_pre_action or self._obs_with_agent_id:
             obs[agent].update(other=extra)
     return obs
Ejemplo n.º 8
0
 def _preprocess_BATCH(self, BATCH):  # [T, B, *]
     """Encode discrete actions and optionally add previous-action features.

     For non-continuous action spaces ``BATCH.action`` is replaced by its
     ``int2one_hot`` encoding. When ``_obs_with_pre_action`` is set,
     ``BATCH.obs`` gets an ``other`` entry with the actions shifted by one
     step along the leading axis (zeros in the first slot) and
     ``BATCH.obs_`` gets the actions themselves.

     Args:
         BATCH: batch exposing ``action``, ``obs`` and ``obs_``.

     Returns:
         The same ``BATCH``, modified in place.
     """
     if not self.is_continuous:
         # [T, B, 1] or [T, B] => [T, B, N]
         BATCH.action = int2one_hot(BATCH.action, self.a_dim)

     if not self._obs_with_pre_action:
         return BATCH

     # The first step has no predecessor, so pad it with zeros.
     first_step_pad = np.zeros_like(BATCH.action[:1])  # TODO: improve
     previous_actions = np.concatenate(
         (first_step_pad, BATCH.action[:-1]), 0)
     BATCH.obs.update(other=previous_actions)
     BATCH.obs_.update(other=BATCH.action)
     return BATCH
Ejemplo n.º 9
0
    def _data_process2dict(self, data, data_name_list):
        """Post-process sampled fields and return them keyed by name.

        ``data`` and ``data_name_list`` are parallel sequences. The 'a'
        field is one-hot encoded for non-continuous action spaces, the 's'
        and 's_' fields are passed through ``self.normalize_vector_obs``
        (all written back into ``data`` in place), and every field is then
        run through ``self.data_convert``.

        Args:
            data: mutable sequence of sampled fields.
            data_name_list: names matching ``data`` position by position.

        Returns:
            dict mapping each name to its converted field.
        """
        if not self.is_continuous and 'a' in data_name_list:
            action_pos = data_name_list.index('a')
            data[action_pos] = int2one_hot(
                data[action_pos].astype(np.int32), self.a_dim)

        for obs_name in ('s', 's_'):
            if obs_name in data_name_list:
                obs_pos = data_name_list.index(obs_name)
                data[obs_pos] = self.normalize_vector_obs(data[obs_pos])

        # Convert every field eagerly before pairing it with its name.
        converted = [self.data_convert(field) for field in data]
        return dict(zip(data_name_list, converted))
Ejemplo n.º 10
0
 def get_transitions(self,
                     databuffer,
                     data_name_list=['s', 'a', 'r', 's_', 'done']):
     '''
     Sample a batch of experiences from `databuffer` and convert it.

     For non-continuous action spaces the `action` field is cast to int32,
     flattened, one-hot encoded with `self.a_dim`, and reshaped back to
     the original shape plus a trailing one-hot axis. The batch is then
     handed to `NamedTupleStaticClass.data_convert` together with
     `self.data_convert`.

     `data_name_list` is accepted but not read by this implementation.
     '''
     batch = databuffer.sample()  # draw a batch from the replay buffer
     if not self.is_continuous:
         assert 'action' in batch._fields, "assert 'action' in exps._fields"
         action_indices = batch.action.astype(np.int32)
         flat_one_hot = int2one_hot(action_indices.reshape(-1), self.a_dim)
         # Restore the leading dimensions; the last axis is the one-hot axis.
         one_hot_action = flat_one_hot.reshape(action_indices.shape + (-1, ))
         batch = batch._replace(action=one_hot_action)
     return NamedTupleStaticClass.data_convert(self.data_convert, batch)
Ejemplo n.º 11
0
 def get_transitions(self,
                     databuffer,
                     data_name_list=['s', 'a', 'r', 's_', 'done']):
     '''
     Sample a batch from `databuffer` and return it as a dict keyed by the
     entries of `data_name_list`.

     The sampled fields are matched to `data_name_list` by position. For
     non-continuous action spaces the 'a' field is cast to int32,
     flattened, one-hot encoded with `self.a_dim`, and reshaped back to
     the original shape plus a trailing one-hot axis. Every field is then
     passed through `self.data_convert`.
     '''
     samples = databuffer.sample()  # draw a batch from the replay buffer
     if not self.is_continuous and 'a' in data_name_list:
         action_pos = data_name_list.index('a')
         action_indices = samples[action_pos].astype(np.int32)
         flat_one_hot = int2one_hot(action_indices.reshape(-1), self.a_dim)
         # Restore the leading dimensions; the last axis is the one-hot axis.
         samples[action_pos] = flat_one_hot.reshape(
             action_indices.shape + (-1, ))

     # Convert every field eagerly before pairing it with its name.
     converted = [self.data_convert(field) for field in samples]
     return dict(zip(data_name_list, converted))
Ejemplo n.º 12
0
        def get_transitions(
            self,
            data_name_list: List[str] = [
                's', 'visual_s', 'a', 'r', 's_', 'visual_s_', 'done'
            ]
        ) -> Dict:
            '''
            Sample a batch from `self.data` and return it as a dict keyed
            by the entries of `data_name_list`.

            The sampled fields are matched to `data_name_list` by position.
            The 'a' field is one-hot encoded for non-continuous action
            spaces, the 's' and 's_' fields are passed through
            `self.normalize_vector_obs`, and every field is finally run
            through `self.data_convert`.
            '''
            samples = self.data.sample()  # draw a batch from the replay buffer
            if not self.is_continuous and 'a' in data_name_list:
                action_pos = data_name_list.index('a')
                samples[action_pos] = int2one_hot(
                    samples[action_pos].astype(np.int32), self.a_dim)

            for obs_name in ('s', 's_'):
                if obs_name in data_name_list:
                    obs_pos = data_name_list.index(obs_name)
                    samples[obs_pos] = self.normalize_vector_obs(
                        samples[obs_pos])

            # Convert every field eagerly before pairing it with its name.
            converted = [self.data_convert(field) for field in samples]
            return dict(zip(data_name_list, converted))
Ejemplo n.º 13
0
 def _preprocess_BATCH(self, BATCH):  # [T, B, *]
     """Run the parent preprocessing, then one-hot encode the option fields.

     ``last_options`` and ``options`` are each replaced by the result of
     ``int2one_hot`` with ``self.options_num``.

     Args:
         BATCH: batch exposing ``last_options`` and ``options``.

     Returns:
         The batch returned by the parent class, with both fields encoded.
     """
     BATCH = super()._preprocess_BATCH(BATCH)
     for option_field in ('last_options', 'options'):
         encoded = int2one_hot(getattr(BATCH, option_field), self.options_num)
         setattr(BATCH, option_field, encoded)
     return BATCH