Exemplo n.º 1
0
    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        """Extend the parent preprocessing with option encodings and returns.

        After the parent class has preprocessed ``BATCH`` this method:

        * adds ``BATCH.reward_offset`` to ``BATCH.reward`` in place,
        * replaces ``BATCH.last_options`` and ``BATCH.options`` with the
          output of ``int2one_hot(..., self.options_num)``,
        * stores ``BATCH.discounted_reward`` and ``BATCH.gae_adv``.

        The signature comment marks ``BATCH`` as ``[T, B, *]``; presumably
        the leading axis is time, which is why index ``-1`` is used for the
        bootstrap inputs -- TODO confirm against the caller.

        Returns the same ``BATCH`` object with the fields above updated.
        """
        BATCH = super()._preprocess_BATCH(BATCH)
        BATCH.reward += BATCH.reward_offset

        n_options = self.options_num
        BATCH.last_options = int2one_hot(BATCH.last_options, n_options)
        BATCH.options = int2one_hot(BATCH.options, n_options)

        # Value estimate for the final entry; it seeds the discounted sum and
        # closes the shifted value array used for the TD error below.
        bootstrap = self._get_value(BATCH.obs_[-1],
                                    BATCH.options[-1],
                                    rnncs=self.rnncs)

        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=bootstrap)

        # value[t + 1] for every t: drop the first entry and append the
        # bootstrap value along axis 0.
        shifted_values = np.concatenate(
            (BATCH.value[1:], bootstrap[np.newaxis, :]), 0)
        td_error = calculate_td_error(BATCH.reward,
                                      self.gamma,
                                      BATCH.done,
                                      value=BATCH.value,
                                      next_value=shifted_values)

        # Accumulate the TD errors with factor lambda_ * gamma, starting from
        # zero; normalization is delegated to discounted_sum (normalize=True).
        BATCH.gae_adv = discounted_sum(td_error,
                                       self.lambda_ * self.gamma,
                                       BATCH.done,
                                       BATCH.begin_mask,
                                       init_value=0.,
                                       normalize=True)
        return BATCH
Exemplo n.º 2
0
 def cal_td_error(self, gamma, init_value):
     '''
     Compute the TD error and store it under data_buffer['td_error'].
     TD = r + gamma * (1 - done) * v(s') - v(s)

     The arithmetic itself is delegated to calculate_td_error; this method
     only builds the next-value sequence by dropping the first stored value
     and appending init_value at the end.
     '''
     buffer = self.data_buffer
     assert 'value' in buffer.keys(
     ), "assert 'value' in self.data_buffer.keys()"

     rewards = buffer['reward']
     dones = buffer['done']
     values = buffer['value']
     # v(s') for each step: the stored values shifted by one, closed with
     # init_value (list concatenation, so 'value' is expected to be a list).
     next_values = values[1:] + [init_value]

     buffer['td_error'] = list(
         calculate_td_error(rewards, gamma, dones, values, next_values))