Ejemplo n.º 1
0
    def __getitem__(self, index: Union[slice, int, np.integer,
                                       np.ndarray]) -> Batch:
        """Return a data batch: self[index].

        The batch always carries ``obs``, ``obs_next``, ``act``, ``rew``,
        ``done``, ``info`` and ``policy`` read at ``index``. For every offset
        ``i`` in ``range(1, self.num)`` it additionally carries look-ahead
        fields read at ``(index + i) % self._size``: ``act``, ``rew``,
        ``obs_next`` and ``done``, each suffixed with ``'_next' * i``; and
        ``done_bk`` suffixed with ``'_next' * (i - 1)``, read at
        ``(index + i - 1) % self._size``.

        When ``self.protect_num > 0`` the requested indices are remapped
        first: any index whose forward distance to ``self._index`` (modulo
        ``self._size``) is at most ``protect_num`` is shifted forward by
        ``protect_num``. If ``self.with_pt`` is truthy, an equally shaped set
        of indices drawn uniformly from the ``protect_num`` slots immediately
        before ``self._index`` is appended along axis 0.

        Optional fields (``log_prob``, ``alpha``, ``mask``, ``erew``,
        ``goal``, ``irew``) are attached only when the corresponding buffer
        attribute / flag is set.

        The original docstring states that if stack_num is set to be > 0,
        the stacked obs and obs_next are returned with shape
        [batch, len, ...]; that stacking would happen inside ``self.get``,
        which is not visible here.

        NOTE(review): the index arithmetic below (``index + i``,
        ``self._index - index``) is not defined for ``slice`` objects even
        though the annotation allows them -- confirm callers only pass
        integers or integer arrays on those paths.
        """

        if self.protect_num > 0:
            # Original author's note on the protected window:
            #   [(self._index - self.protect_num) % self._size, self._index)
            # An index is "qualified" (left untouched) when it lies more than
            # protect_num slots behind the write pointer, modulo buffer size.
            qualified = ((self._index - index) % self._size) > self.protect_num
            index_jump = (index + self.protect_num) % self._size
            # Arithmetic select: keep qualified entries, replace the others
            # with their jumped counterpart. ``np.int`` was removed in
            # NumPy 1.24; it was a plain alias of the builtin ``int``, so
            # ``astype(int)`` yields the same dtype on every NumPy version.
            index = (index * qualified +
                     (1 - qualified) * index_jump).astype(int)

            # Drawn unconditionally (as before) so the global NumPy RNG
            # stream is consumed identically whether or not with_pt is set.
            protect_index = np.random.randint(0,
                                              self.protect_num,
                                              size=index.shape)
            # Map the draw r in [0, protect_num) onto slot _index - 1 - r.
            protect_index = (self._index - protect_index - 1) % self._size

            if self.with_pt:
                index = np.concatenate([index, protect_index], 0)

        # Look-ahead fields for offsets 1 .. self.num - 1.
        kwargs = {}
        for i in range(1, self.num):
            suffix = '_next' * i
            kwargs['act' + suffix] = self.act[(index + i) % self._size]
            kwargs['rew' + suffix] = self.rew[(index + i) % self._size]
            kwargs['obs_next' + suffix] = self.get((index + i) % self._size,
                                                   'obs_next')
            kwargs['done' + suffix] = self.done[(index + i) % self._size]
            # done_bk lags one step behind the other look-ahead fields, so
            # the first entry is the un-suffixed 'done_bk' at ``index``.
            kwargs['done_bk' + '_next' * (i - 1)] = self.done_bk[
                (index + i - 1) % self._size]

        ret = Batch(obs=self.get(index, 'obs'),
                    obs_next=self.get(index, 'obs_next'),
                    act=self.act[index],
                    rew=self.rew[index],
                    done=self.done[index],
                    info=self.get(index, 'info'),
                    policy=self.get(index, 'policy'),
                    **kwargs)

        if self.log_prob is not None:
            ret.log_prob = self.log_prob[index]
        if self.alpha is not None:
            ret.alpha = self.alpha[index]
        # NOTE: an ``int_rew`` field used to be attached here; that code is
        # currently disabled.

        if self._ens_num is not None:
            ret.mask = self.mask[index]
        if self._ngu:
            ret.erew = self.erew[index]
            ret.goal = self.goal[index]
        if self._rand2:
            ret.irew = self.irew[index]
        return ret
Ejemplo n.º 2
0
    def __getitem__(self, index: Union[slice, int, np.integer,
                                       np.ndarray]) -> Batch:
        """Return a data batch: self[index].

        The batch holds ``obs``, ``act``, ``rew``, ``done``, ``done_bk``,
        ``obs_next``, ``info`` and ``policy`` read at ``index``; ``obs``,
        ``obs_next``, ``info`` and ``policy`` go through ``self.get`` while
        the rest are indexed directly. Optional fields (``log_prob``,
        ``alpha``, ``mask``, ``erew``, ``goal``, ``irew``) are attached only
        when the corresponding buffer attribute / flag is set.

        The original docstring states that if stack_num is set to be > 0,
        the stacked obs and obs_next are returned with shape
        [batch, len, ...]; that stacking would happen inside ``self.get``,
        which is not visible here.
        """
        # Assemble the mandatory fields first; insertion order matches the
        # keyword order handed to Batch.
        fields = {
            'obs': self.get(index, 'obs'),
            'act': self.act[index],
            'rew': self.rew[index],
            'done': self.done[index],
            'done_bk': self.done_bk[index],
            'obs_next': self.get(index, 'obs_next'),
            'info': self.get(index, 'info'),
            'policy': self.get(index, 'policy'),
        }
        batch = Batch(**fields)

        # Storage that may be absent (None) on this buffer.
        for name in ('log_prob', 'alpha'):
            stored = getattr(self, name)
            if stored is not None:
                setattr(batch, name, stored[index])
        # NOTE: an ``int_rew`` field used to be attached here; that code is
        # currently disabled.

        # Feature-gated extras.
        if self._ens_num is not None:
            batch.mask = self.mask[index]
        if self._ngu:
            batch.erew = self.erew[index]
            batch.goal = self.goal[index]
        if self._rand2:
            batch.irew = self.irew[index]
        return batch
Ejemplo n.º 3
0
    def __getitem__(self, index: Union[slice, int, np.integer,
                                       np.ndarray]) -> Batch:
        """Return a data batch: self[index].

        Alongside the base fields (``obs``, ``obs_next``, ``act``, ``rew``,
        ``done``, ``info``, ``policy``) the batch carries look-ahead fields
        for each offset ``step`` in ``range(1, self.num)``: ``act``, ``rew``,
        ``obs_next`` and ``done`` suffixed with ``'_next' * step`` and read
        at ``(index + step) % self._size``, plus ``done_bk`` suffixed with
        ``'_next' * (step - 1)`` and read at
        ``(index + step - 1) % self._size``. Optional fields (``log_prob``,
        ``alpha``, ``mask``, ``erew``, ``goal``, ``irew``) are attached only
        when the corresponding buffer attribute / flag is set.

        The original docstring states that if stack_num is set to be > 0,
        the stacked obs and obs_next are returned with shape
        [batch, len, ...]; that stacking would happen inside ``self.get``,
        which is not visible here.
        """

        def ahead(offset):
            # Position ``offset`` slots after ``index``, wrapped at the
            # buffer size; recomputed on every call, as the original did.
            return (index + offset) % self._size

        lookahead = {}
        for step in range(1, self.num):
            tag = '_next' * step
            lookahead['act' + tag] = self.act[ahead(step)]
            lookahead['rew' + tag] = self.rew[ahead(step)]
            lookahead['obs_next' + tag] = self.get(ahead(step), 'obs_next')
            lookahead['done' + tag] = self.done[ahead(step)]
            # done_bk trails the other fields by one step, so its first
            # entry is the plain 'done_bk' key.
            lagged_tag = '_next' * (step - 1)
            lookahead['done_bk' + lagged_tag] = self.done_bk[
                (index + step - 1) % self._size]

        batch = Batch(
            obs=self.get(index, 'obs'),
            obs_next=self.get(index, 'obs_next'),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            info=self.get(index, 'info'),
            policy=self.get(index, 'policy'),
            **lookahead,
        )

        # Storage that may be absent (None) on this buffer.
        for name in ('log_prob', 'alpha'):
            stored = getattr(self, name)
            if stored is not None:
                setattr(batch, name, stored[index])
        # NOTE: an ``int_rew`` field used to be attached here; that code is
        # currently disabled.

        # Feature-gated extras.
        if self._ens_num is not None:
            batch.mask = self.mask[index]
        if self._ngu:
            batch.erew = self.erew[index]
            batch.goal = self.goal[index]
        if self._rand2:
            batch.irew = self.irew[index]
        return batch