def get_step(self, **kwargs):
    """Update the step / done counters from the ``done`` keyword argument.

    When ``done`` is supplied, ``self.cur_traj_step`` is reset to 0 wherever
    ``done`` is true and incremented by one elsewhere, while
    ``self.cur_done_cnt`` is incremented wherever ``done`` is true.

    :return: ``Batch(step=..., done_cnt=...)`` built from the updated
        counters, where ``step`` is an all-zero array of length ``self.n``
        when ``self.bk_step`` is truthy and ``self.cur_traj_step``
        otherwise. An empty ``Batch`` is returned when no ``done`` keyword
        was given.
    """
    if 'done' not in kwargs:
        return Batch()

    finished = kwargs["done"]
    self.cur_traj_step = np.where(finished, 0, self.cur_traj_step + 1)
    self.cur_done_cnt = np.where(finished, self.cur_done_cnt + 1,
                                 self.cur_done_cnt)

    if self.bk_step:
        step = np.array([0] * self.n)
    else:
        step = self.cur_traj_step
    return Batch(step=step, done_cnt=self.cur_done_cnt)
    def __getitem__(self, index: Union[slice, int, np.integer,
                                       np.ndarray]) -> Batch:
        """Return the transitions stored at ``index`` as a ``Batch``.

        ``obs``, ``obs_next``, ``info`` and ``policy`` are fetched through
        :meth:`get`; ``act``, ``rew``, ``done`` and ``done_bk`` are indexed
        directly. Optional fields are attached afterwards:
        ``log_prob`` / ``alpha`` when the buffer holds them, ``mask`` when
        ``_ens_num`` is set, ``erew`` / ``goal`` when ``_ngu`` is enabled and
        ``irew`` when ``_rand2`` is enabled.
        """
        fields = dict(
            obs=self.get(index, 'obs'),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            done_bk=self.done_bk[index],
            obs_next=self.get(index, 'obs_next'),
            info=self.get(index, 'info'),
            policy=self.get(index, 'policy'),
        )
        batch = Batch(**fields)

        # Collect the names of the optional per-transition arrays to attach.
        extra_keys = []
        if self.log_prob is not None:
            extra_keys.append('log_prob')
        if self.alpha is not None:
            extra_keys.append('alpha')
        if self._ens_num is not None:
            extra_keys.append('mask')
        if self._ngu:
            extra_keys.extend(('erew', 'goal'))
        if self._rand2:
            extra_keys.append('irew')

        for name in extra_keys:
            setattr(batch, name, getattr(self, name)[index])
        return batch
Exemple #3
0
    def sample(self, batch_size: int = 0, importance_sample: bool = True):
        """Draw a sample from the buffer using the stored priority weights.

        With ``batch_size == 0`` every stored entry is returned, starting at
        ``self._index`` and wrapping around. With
        ``0 < batch_size <= self._size`` indices are drawn without
        replacement with probability proportional to ``self.weight``.

        :param batch_size: number of entries to draw, or ``0`` for all.
        :param importance_sample: when true, an ``impt_weight`` field is
            appended to the returned batch.
        :return: ``(batch, indice)`` - the sampled data and the buffer
            indices it was taken from.
        :raises ValueError: for any other ``batch_size`` (negative, or
            larger than the number of stored entries).
        """
        if batch_size == 0:
            head = np.arange(self._index, self._size)
            tail = np.arange(0, self._index)
            indice = np.concatenate([head, tail])
        elif 0 < batch_size <= self._size:
            # Normalise with weight.sum() rather than self._weight_sum: the
            # cached sum is not accurate enough for np.random.choice.
            probs = (self.weight / self.weight.sum())[:self._size]
            # Sampling without replacement: drawing the same entry twice
            # would cause conflicting weight updates.
            indice = np.random.choice(self._size,
                                      batch_size,
                                      p=probs,
                                      replace=False)
        else:
            # A batch larger than len(self) would break the weight update.
            raise ValueError("batch_size should be less than len(self)")

        batch = self[indice]
        if importance_sample:
            scaled_priority = self._size * (batch.weight / self._weight_sum)
            batch.append(
                Batch(impt_weight=1 / np.power(scaled_priority, self._beta)))
        self._check_weight_sum()
        return batch, indice
Exemple #4
0
def to_torch(
    x: Any,
    dtype: Optional[torch.dtype] = None,
    device: Union[str, int, torch.device] = "cpu",
) -> Union[Batch, torch.Tensor]:
    """Convert ``x`` into torch tensors placed on ``device``.

    Handled inputs, checked in this order: numeric / boolean ``np.ndarray``,
    ``torch.Tensor``, numeric or boolean scalars, ``dict`` / ``Batch``
    (converted on a copy via its ``to_torch`` method) and ``list`` /
    ``tuple`` (first passed through ``_parse_value``).

    :param dtype: if given, tensors are cast with ``Tensor.type(dtype)``.
    :param device: target device for the resulting tensors.
    :raises TypeError: if ``x`` is none of the supported kinds.
    """
    # Most common case: a numeric or boolean numpy array.
    if isinstance(x, np.ndarray) and issubclass(
            x.dtype.type, (np.bool_, np.number)):
        tensor = torch.from_numpy(x).to(device)  # type: ignore
        return tensor if dtype is None else tensor.type(dtype)

    # Second most common case: already a tensor.
    if isinstance(x, torch.Tensor):
        casted = x if dtype is None else x.type(dtype)
        return casted.to(device)  # type: ignore

    if isinstance(x, (np.number, np.bool_, Number)):
        return to_torch(np.asanyarray(x), dtype, device)

    if isinstance(x, (dict, Batch)):
        # Convert a copy so the caller's container is left untouched.
        if isinstance(x, dict):
            container = Batch(x, copy=True)
        else:
            container = deepcopy(x)
        container.to_torch(dtype, device)
        return container

    if isinstance(x, (list, tuple)):
        return to_torch(_parse_value(x), dtype, device)

    raise TypeError(f"object {x} cannot be converted to torch.")
    def __getitem__(self, index: Union[slice, int, np.integer,
                                       np.ndarray]) -> Batch:
        """Return the transitions at ``index`` together with their successors.

        When ``self.protect_num > 0`` the requested indices are first
        adjusted: any index whose backward distance to the write pointer
        ``(self._index - index) % self._size`` is not greater than
        ``protect_num`` is moved forward by ``protect_num`` (modulo
        ``self._size``). A set of random indices drawn from the
        ``protect_num`` slots just behind ``self._index`` is also generated
        and, if ``self.with_pt`` is truthy, concatenated to ``index``.

        For every ``i`` in ``1 .. self.num - 1`` the fields ``act``, ``rew``,
        ``obs_next`` and ``done`` at ``index + i`` are added under the same
        name suffixed with ``i`` copies of ``_next``; ``done_bk`` at
        ``index + i - 1`` is added with ``i - 1`` copies of the suffix.

        Optional fields (``log_prob``, ``alpha``, ``mask``, ``erew``,
        ``goal``, ``irew``) are attached when the corresponding buffer
        feature is enabled.
        """
        if self.protect_num > 0:
            qualified = ((self._index - index) % self._size) > self.protect_num
            index_jump = (index + self.protect_num) % self._size
            # ``np.int`` was removed in NumPy 1.24; it was only an alias of
            # the builtin ``int``, so the resulting dtype is unchanged.
            index = (index * qualified
                     + (1 - qualified) * index_jump).astype(int)

            # Always drawn (even when unused) so the global RNG stream does
            # not depend on ``with_pt``.
            protect_index = np.random.randint(0,
                                              self.protect_num,
                                              size=index.shape)
            protect_index = (self._index - protect_index - 1) % self._size

            if self.with_pt:
                index = np.concatenate([index, protect_index], 0)

        kwargs = dict()
        for i in range(1, self.num):
            suffix = '_next' * i
            shifted = (index + i) % self._size
            kwargs['act' + suffix] = self.act[shifted]
            kwargs['rew' + suffix] = self.rew[shifted]
            kwargs['obs_next' + suffix] = self.get(shifted, 'obs_next')
            kwargs['done' + suffix] = self.done[shifted]
            # ``done_bk`` lags one step behind: the first entry has no suffix.
            kwargs['done_bk' + '_next' * (i - 1)] = self.done_bk[
                (index + i - 1) % self._size]

        ret = Batch(obs=self.get(index, 'obs'),
                    obs_next=self.get(index, 'obs_next'),
                    act=self.act[index],
                    rew=self.rew[index],
                    done=self.done[index],
                    info=self.get(index, 'info'),
                    policy=self.get(index, 'policy'),
                    **kwargs)

        if self.log_prob is not None:
            ret.log_prob = self.log_prob[index]
        if self.alpha is not None:
            ret.alpha = self.alpha[index]

        if self._ens_num is not None:
            ret.mask = self.mask[index]
        if self._ngu:
            ret.erew = self.erew[index]
            ret.goal = self.goal[index]
        if self._rand2:
            ret.irew = self.irew[index]
        return ret
Exemple #6
0
 def __getitem__(self, index):
     """Return the entries stored at ``index`` as a ``Batch``.

     Each of ``obs``, ``act``, ``rew``, ``done``, ``obs_next`` and ``info``
     is indexed directly with ``index``.
     """
     names = ('obs', 'act', 'rew', 'done', 'obs_next', 'info')
     return Batch(**{name: getattr(self, name)[index] for name in names})
    def add(self,
            obs: Union[dict, Batch, np.ndarray],
            act: Union[np.ndarray, float],
            rew: Union[int, float],
            done: bool,
            obs_next: Optional[Union[dict, Batch, np.ndarray]] = None,
            info: dict = {},
            policy: Optional[Union[dict, Batch]] = {},
            **kwargs) -> None:
        """Store one transition at the current write position.

        ``obs``, ``act``, ``rew``, ``done``, ``info`` and ``policy`` are
        always stored; ``done_bk`` is taken from ``kwargs`` and falls back to
        ``done``. ``obs_next`` is stored only when ``self._save_s_`` is
        truthy (an empty ``Batch`` is used if it is ``None``). Depending on
        the buffer configuration a random binary ``mask`` (``_ens_num``),
        zero-initialised ``erew`` / ``goal`` (``_ngu``) and ``irew``
        (``_rand2``) are stored as well. ``log_prob``, ``alpha`` and
        ``int_rew`` are stored when passed as keyword arguments.

        Afterwards the list of indices usable for frame-stack sampling is
        updated (when ``self._avail`` is set) and the write pointer / size
        are advanced.

        NOTE(review): ``info`` and ``policy`` default to shared mutable
        dicts; this is only safe as long as ``_add_to_buffer`` never mutates
        or retains them - confirm.
        """
        assert isinstance(info, (dict, Batch)), \
            'You should return a dict in the last argument of env.step().'

        core_fields = (
            ('obs', obs),
            ('act', act),
            ('rew', rew),
            ('done', done),
            ('done_bk', kwargs.get('done_bk', done)),
        )
        for name, value in core_fields:
            self._add_to_buffer(name, value)

        if self._save_s_:
            self._add_to_buffer('obs_next',
                                Batch() if obs_next is None else obs_next)
        self._add_to_buffer('info', info)
        self._add_to_buffer('policy', policy)

        if self._ens_num is not None:
            self._add_to_buffer('mask',
                                np.random.randint(2, size=self._ens_num))
        if self._ngu:
            for name in ('erew', 'goal'):
                self._add_to_buffer(name, 0)
        if self._rand2:
            self._add_to_buffer('irew', 0)

        for name in ('log_prob', 'alpha', 'int_rew'):
            if name in kwargs:
                self._add_to_buffer(name, kwargs[name])

        # maintain available index for frame-stack sampling
        if self._avail:
            # The current frame is usable only if none of the preceding
            # ``_stack - 1`` frames is terminal and enough frames exist.
            window = range(self._index - self._stack + 1, self._index)
            usable = sum(self.done[i] for i in window) == 0
            if self._size < self._stack - 1:
                usable = False
            listed = self._index in self._avail_index
            if usable and not listed:
                self._avail_index.append(self._index)
            elif not usable and listed:
                self._avail_index.remove(self._index)
            # The frame ``_stack - 1`` slots ahead would now stack across
            # the freshly overwritten entry, so it is no longer available.
            broken = (self._index + self._stack - 1) % self._maxsize
            if broken in self._avail_index:
                self._avail_index.remove(broken)

        if self._maxsize > 0:
            self._size = min(self._size + 1, self._maxsize)
            self._index = (self._index + 1) % self._maxsize
        else:
            # Unbounded buffer: size and write pointer grow together.
            self._size = self._index = self._index + 1
Exemple #8
0
 def __getitem__(self, index):
     """Return a data batch: self[index].

     The batch holds ``obs``, ``act``, ``rew``, ``done``, ``obs_next`` and
     ``info``, each indexed directly with ``index``.
     """
     obs = self.obs[index]
     act = self.act[index]
     rew = self.rew[index]
     done = self.done[index]
     obs_next = self.obs_next[index]
     info = self.info[index]
     return Batch(obs=obs, act=act, rew=rew, done=done,
                  obs_next=obs_next, info=info)
Exemple #9
0
 def __getitem__(self, index):
     """Return the entries at ``index``, including their ``weight``.

     ``obs`` and ``obs_next`` are fetched through :meth:`get`; ``act``,
     ``rew``, ``done``, ``info`` and ``weight`` are indexed directly.
     """
     fields = {
         'obs': self.get(index, 'obs'),
         'act': self.act[index],
         'rew': self.rew[index],
         'done': self.done[index],
         'obs_next': self.get(index, 'obs_next'),
         'info': self.info[index],
         'weight': self.weight[index],
     }
     return Batch(**fields)
Exemple #10
0
 def get(self, indice, key, stack_num=None):
     """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
     where s is self.key, t is indice. The stack_num (here equals to 4) is
     given from buffer initialization procedure.

     :param indice: a slice, a scalar, or an integer array of buffer
         positions. Slices follow normal Python semantics relative to the
         current buffer size (negative bounds count from the end).
     :param key: name of the stored field. ``'obs_next'`` is derived from
         ``'obs'`` at the following index when ``self._save_s_`` is false.
     :param stack_num: number of frames to stack; defaults to
         ``self._stack``. ``0`` disables stacking.
     :return: for ``stack_num == 0`` the plain indexed data (a dict of
         arrays for keys listed in ``self._meta``); otherwise the frames
         stacked on axis 1 (axis 0 for a scalar ``indice``), wrapped in a
         ``Batch`` for keys listed in ``self._meta``.

     The caller's ``indice`` is never modified. The last written frame is
     temporarily marked as done while the indices are computed and is
     always restored, even if an exception is raised.
     """
     if stack_num is None:
         stack_num = self._stack
     if isinstance(indice, slice):
         # slice.indices resolves None / negative bounds and the step.
         indice = np.arange(*indice.indices(self._size))
     else:
         # Private copy: the index arithmetic below works in place and must
         # not leak into the array the caller passed in.
         indice = np.array(indice, copy=True)

     # set last frame done to True
     last_index = (self._index - 1 + self._size) % self._size
     last_done, self.done[last_index] = self.done[last_index], True
     try:
         if key == 'obs_next' and not self._save_s_:
             # Next observation = observation one slot later, unless the
             # transition is terminal (then stay on the same slot).
             indice += 1 - self.done[indice].astype(int)
             indice[indice == self._size] = 0
             key = 'obs'
         many_keys = self._meta[key] if key in self._meta else None

         if stack_num == 0:
             # Restore before reading so that key == 'done' returns the
             # real flag of the last frame.
             self.done[last_index] = last_done
             if many_keys is not None:
                 return {
                     k: self.__dict__['_' + key + '@' + k][indice]
                     for k in many_keys
                 }
             return self.__dict__[key][indice]

         if many_keys is not None:
             stack = {k: [] for k in many_keys}
         else:
             stack = []
         for _ in range(stack_num):
             # Frames are collected newest-first, hence the prepend.
             if many_keys is not None:
                 for k_ in many_keys:
                     k__ = '_' + key + '@' + k_
                     stack[k_] = [self.__dict__[k__][indice]] + stack[k_]
             else:
                 stack = [self.__dict__[key][indice]] + stack
             # Step one frame back, but never across an episode boundary.
             pre_indice = np.asarray(indice - 1)
             pre_indice[pre_indice == -1] = self._size - 1
             indice = np.asarray(pre_indice +
                                 self.done[pre_indice].astype(int))
             indice[indice == self._size] = 0
     finally:
         self.done[last_index] = last_done

     # A scalar index has no batch axis, so its frames are stacked on axis 0.
     axis = min(indice.ndim, 1)
     if many_keys is not None:
         for k in stack:
             stack[k] = np.stack(stack[k], axis=axis)
         stack = Batch(**stack)
     else:
         stack = np.stack(stack, axis=axis)
     return stack
Exemple #11
0
 def __getitem__(self, index):
     """Return a data batch: self[index].

     ``obs`` and ``obs_next`` are fetched through :meth:`get` (which applies
     the buffer's frame stacking); ``act``, ``rew``, ``done`` and ``info``
     are indexed directly.
     """
     obs = self.get(index, 'obs')
     act, rew, done = self.act[index], self.rew[index], self.done[index]
     obs_next = self.get(index, 'obs_next')
     info = self.info[index]
     return Batch(obs=obs, act=act, rew=rew, done=done,
                  obs_next=obs_next, info=info)
Exemple #12
0
 def __getattr__(self, key: str) -> Union[Batch, np.ndarray]:
     """Return self.key.

     Keys listed in ``self._meta`` are reassembled into a ``Batch`` from the
     flattened instance attributes named ``'_<key>@<subkey>'``. Any other
     key is looked up in the instance ``__dict__``.

     :raises AttributeError: if ``key`` is neither a ``_meta`` entry nor an
         instance attribute.
     """
     # ``_meta`` must be read without going through ``self._meta``: when it
     # has not been set yet (e.g. while copy / pickle probe a fresh
     # instance) that access would re-enter __getattr__ and recurse forever.
     try:
         meta = object.__getattribute__(self, '_meta')
     except AttributeError:
         meta = None
     if meta is None or key not in meta:
         try:
             return self.__dict__[key]
         except KeyError:
             raise AttributeError(key) from None
     d = {}
     for k_ in meta[key]:
         k__ = '_' + key + '@' + k_
         d[k_] = self.__dict__[k__]
     return Batch(**d)
Exemple #13
0
 def __getitem__(self, index):
     """Return the entries at ``index`` with their ``weight`` and ``policy``.

     ``obs``, ``obs_next`` and ``policy`` are fetched through :meth:`get`;
     ``act``, ``rew``, ``done``, ``info`` and ``weight`` are indexed
     directly.

     NOTE: a stacked ``act_`` field (for RNN use) is currently disabled.
     """
     obs = self.get(index, 'obs')
     act = self.act[index]
     rew = self.rew[index]
     done = self.done[index]
     obs_next = self.get(index, 'obs_next')
     info = self.info[index]
     weight = self.weight[index]
     policy = self.get(index, 'policy')
     return Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next,
                  info=info, weight=weight, policy=policy)
Exemple #14
0
 def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
     """Return the entries at ``index`` with their ``weight`` and ``policy``.

     ``obs``, ``obs_next``, ``info`` and ``policy`` are fetched through
     :meth:`get`; ``act``, ``rew``, ``done`` and ``weight`` are indexed
     directly.

     NOTE: a stacked ``act_`` field (for RNN use) is currently disabled.
     """
     fields = {
         'obs': self.get(index, 'obs'),
         'act': self.act[index],
         'rew': self.rew[index],
         'done': self.done[index],
         'obs_next': self.get(index, 'obs_next'),
         'info': self.get(index, 'info'),
         'weight': self.weight[index],
         'policy': self.get(index, 'policy'),
     }
     return Batch(**fields)
 def __getitem__(self, index: Union[slice, int, np.integer,
                                    np.ndarray]) -> Batch:
     """Return the entries at ``index`` with their ``weight`` and ``policy``.

     ``obs``, ``obs_next``, ``info`` and ``policy`` are fetched through
     :meth:`get`; ``act``, ``rew``, ``done`` and ``weight`` are indexed
     directly.
     """
     obs = self.get(index, 'obs')
     act, rew, done = self.act[index], self.rew[index], self.done[index]
     obs_next = self.get(index, 'obs_next')
     info = self.get(index, 'info')
     weight = self.weight[index]
     policy = self.get(index, 'policy')
     return Batch(obs=obs, act=act, rew=rew, done=done, obs_next=obs_next,
                  info=info, weight=weight, policy=policy)
Exemple #16
0
 def __getitem__(self, index: Union[slice, int, np.integer,
                                    np.ndarray]) -> Batch:
     """Return a data batch: self[index].

     ``obs``, ``obs_next``, ``info`` and ``policy`` are fetched through
     :meth:`get` (which applies the buffer's frame stacking); ``act``,
     ``rew`` and ``done`` are indexed directly.
     """
     fields = dict(
         obs=self.get(index, 'obs'),
         act=self.act[index],
         rew=self.rew[index],
         done=self.done[index],
         obs_next=self.get(index, 'obs_next'),
         info=self.get(index, 'info'),
         policy=self.get(index, 'policy'),
     )
     return Batch(**fields)
Exemple #17
0
 def __init__(self, size: int, stack_num: Optional[int] = 0,
              ignore_obs_next: bool = False,
              sample_avail: bool = False, **kwargs) -> None:
     """Set up an empty buffer.

     :param size: stored as the maximum size ``_maxsize``.
     :param stack_num: frame-stack length stored as ``_stack``; the value
         ``1`` is rejected.
     :param ignore_obs_next: when true, ``obs_next`` is not saved
         (``_save_s_`` becomes false).
     :param sample_avail: enables available-index tracking, which is only
         active when ``stack_num > 1``.
     :param kwargs: accepted but not used here.
     """
     super().__init__()
     self._maxsize, self._stack = size, stack_num
     assert stack_num != 1, 'stack_num should greater than 1'
     # Available-index bookkeeping only makes sense with real stacking.
     self._avail = sample_avail and stack_num > 1
     self._avail_index = []
     self._save_s_ = not ignore_obs_next
     # Write pointer and current number of stored entries.
     self._index = self._size = 0
     self._meta = Batch()
     self.reset()
Exemple #18
0
 def sample(self, batch_size):
     """Return ``(batch, indice)`` sampled from the buffer.

     For ``batch_size > 0`` the indices are drawn uniformly (with
     replacement) from ``range(self._size)``. Otherwise every stored entry
     is returned, starting at ``self._index`` and wrapping around.
     """
     if batch_size > 0:
         indice = np.random.choice(self._size, batch_size)
     else:
         from_pointer = np.arange(self._index, self._size)
         before_pointer = np.arange(0, self._index)
         indice = np.concatenate([from_pointer, before_pointer])
     names = ('obs', 'act', 'rew', 'done', 'obs_next', 'info')
     batch = Batch(**{name: getattr(self, name)[indice] for name in names})
     return batch, indice
 def get(self,
         indice: Union[slice, int, np.integer, np.ndarray],
         key: str,
         stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
     """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
     where s is self.key, t is indice. The stack_num (here equals to 4) is
     given from buffer initialization procedure.

     :param indice: a slice, a scalar, or an integer array of buffer
         positions. Slices follow normal Python semantics relative to the
         current buffer size (negative bounds count from the end). The
         caller's object is never modified.
     :param key: name of the field stored in ``self._meta``. ``'obs_next'``
         is derived from ``'obs'`` at the following index when it is not
         saved separately.
     :param stack_num: number of frames to stack; defaults to
         ``self._stack``. Values ``<= 0`` disable stacking.
     :return: the indexed (and possibly stacked) data, or an empty
         ``Batch`` if the field is an empty ``Batch`` placeholder.
     :raises IndexError: if indexing fails on a non-empty field.

     The last written frame is temporarily marked as done while the data is
     gathered; the original flag is restored even if an exception is raised.
     """
     if stack_num is None:
         stack_num = self._stack
     if isinstance(indice, slice):
         # slice.indices resolves None / negative bounds and the step.
         indice = np.arange(*indice.indices(self._size))
     else:
         indice = np.array(indice, copy=True)
     # set last frame done to True
     last_index = (self._index - 1 + self._size) % self._size
     last_done, self.done[last_index] = self.done[last_index], True
     try:
         if key == 'obs_next' and (not self._save_s_
                                   or self.obs_next is None):
             # Next observation = observation one slot later, unless the
             # transition is terminal (then stay on the same slot).
             # ``np.int`` was removed in NumPy 1.24; ``int`` is equivalent.
             indice += 1 - self.done[indice].astype(int)
             indice[indice == self._size] = 0
             key = 'obs'
         val = self._meta.__dict__[key]
         try:
             if stack_num > 0:
                 stack = []
                 for _ in range(stack_num):
                     # Frames are collected newest-first, hence the prepend.
                     stack = [val[indice]] + stack
                     # Step back one frame without crossing an episode end.
                     pre_indice = np.asarray(indice - 1)
                     pre_indice[pre_indice == -1] = self._size - 1
                     indice = np.asarray(
                         pre_indice + self.done[pre_indice].astype(int))
                     indice[indice == self._size] = 0
                 if isinstance(val, Batch):
                     stack = Batch.stack(stack, axis=indice.ndim)
                 else:
                     stack = np.stack(stack, axis=indice.ndim)
             else:
                 stack = val[indice]
         except IndexError:
             # An empty Batch placeholder cannot be indexed; report it as an
             # empty Batch. Anything else is a genuine error.
             if not isinstance(val, Batch) or len(val.__dict__) > 0:
                 raise
             stack = Batch()
     finally:
         self.done[last_index] = last_done
     return stack
Exemple #20
0
def from_hdf5(x: h5py.Group, device: Optional[str] = None) -> Hdf5ConvertibleValues:
    """Restore object from HDF5 group.

    Datasets are decoded according to their ``__data_type__`` attribute:
    ``"ndarray"`` becomes a numpy array, ``"Tensor"`` becomes a torch tensor
    on ``device``, and anything else is unpickled. Groups are rebuilt
    recursively into a ``dict`` of their attributes and members, wrapped in
    a ``Batch`` when the group's ``__data_type__`` attribute is ``"Batch"``.

    NOTE(review): the fallback uses ``pickle.loads``, which can execute
    arbitrary code - only load files from trusted sources.
    """
    if not isinstance(x, h5py.Dataset):
        # Group representing a dict or a Batch.
        restored = dict(x.attrs.items())
        data_type = restored.pop("__data_type__", None)
        restored.update({k: from_hdf5(v, device) for k, v in x.items()})
        if data_type == "Batch":
            return Batch(restored)
        return restored

    # Leaf dataset.
    stored_as = x.attrs["__data_type__"]
    if stored_as == "ndarray":
        return np.array(x)
    if stored_as == "Tensor":
        return torch.tensor(x, device=device)
    return pickle.loads(x[()])
Exemple #21
0
def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
    """Convert ``x`` into an object that contains no ``torch.Tensor``.

    Tensors are detached, moved to CPU and converted; numpy arrays are
    returned as-is; scalars are wrapped with ``np.asanyarray``; ``None``
    becomes a 0-d object array; ``dict`` / ``Batch`` are converted on a new
    ``Batch`` / deep copy via its ``to_numpy`` method; ``list`` / ``tuple``
    are first passed through ``_parse_value``. Anything else falls back to
    ``np.asanyarray``.
    """
    # Most common case.
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    # Second most common case.
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, (np.number, np.bool_, Number)):
        return np.asanyarray(x)
    if x is None:
        return np.array(None, dtype=object)
    if isinstance(x, (dict, Batch)):
        if isinstance(x, dict):
            container = Batch(x)
        else:
            container = deepcopy(x)
        container.to_numpy()
        return container
    if isinstance(x, (list, tuple)):
        return to_numpy(_parse_value(x))
    # Fallback: let numpy decide.
    return np.asanyarray(x)
    def __getitem__(self, index: Union[slice, int, np.integer,
                                       np.ndarray]) -> Batch:
        """Return the transitions at ``index`` together with their successors.

        For every ``i`` in ``1 .. self.num - 1`` the fields ``act``, ``rew``,
        ``obs_next`` and ``done`` at ``(index + i) % self._size`` are added
        under the same name suffixed with ``i`` copies of ``_next``;
        ``done_bk`` at ``index + i - 1`` is added with ``i - 1`` copies of
        the suffix.

        Optional fields (``log_prob``, ``alpha``, ``mask``, ``erew``,
        ``goal``, ``irew``) are attached when the corresponding buffer
        feature is enabled.
        """
        lookahead = {}
        for step in range(1, self.num):
            suffix = '_next' * step
            ahead = (index + step) % self._size
            lookahead['act' + suffix] = self.act[ahead]
            lookahead['rew' + suffix] = self.rew[ahead]
            lookahead['obs_next' + suffix] = self.get(ahead, 'obs_next')
            lookahead['done' + suffix] = self.done[ahead]
            # ``done_bk`` lags one step behind: the first entry has no suffix.
            lookahead['done_bk' + '_next' * (step - 1)] = self.done_bk[
                (index + step - 1) % self._size]

        batch = Batch(obs=self.get(index, 'obs'),
                      obs_next=self.get(index, 'obs_next'),
                      act=self.act[index],
                      rew=self.rew[index],
                      done=self.done[index],
                      info=self.get(index, 'info'),
                      policy=self.get(index, 'policy'),
                      **lookahead)

        # Collect the names of the optional per-transition arrays to attach.
        extra_keys = []
        if self.log_prob is not None:
            extra_keys.append('log_prob')
        if self.alpha is not None:
            extra_keys.append('alpha')
        if self._ens_num is not None:
            extra_keys.append('mask')
        if self._ngu:
            extra_keys.extend(('erew', 'goal'))
        if self._rand2:
            extra_keys.append('irew')

        for name in extra_keys:
            setattr(batch, name, getattr(self, name)[index])
        return batch
 def __init__(self,
              size: int,
              stack_num: Optional[int] = 0,
              ignore_obs_next: bool = False,
              sample_avail: bool = False,
              **kwargs) -> None:
     """Set up an empty buffer.

     :param size: stored as the maximum size ``_maxsize``.
     :param stack_num: frame-stack length stored as ``_stack``; the value
         ``1`` is rejected.
     :param ignore_obs_next: when true, ``obs_next`` is not saved
         (``_save_s_`` becomes false).
     :param sample_avail: enables available-index tracking, which is only
         active when ``stack_num > 1``.
     :param kwargs: optional settings read here: ``max_ep_len`` (default
         ``None``), ``ens_num`` (default ``None``), ``ngu`` (default
         ``False``), ``rand2`` (default ``False``) and ``non_episodic``
         (default ``False``).
     """
     super().__init__()
     self._maxsize, self._stack = size, stack_num
     assert stack_num != 1, 'stack_num should greater than 1'
     # Available-index bookkeeping only makes sense with real stacking.
     self._avail = sample_avail and stack_num > 1
     self._avail_index = []
     self._save_s_ = not ignore_obs_next
     # Write pointer and current number of stored entries.
     self._index = self._size = 0
     self._meta = Batch()
     # Used to identify whether the end is terminated by time limit.
     self._max_ep_len = kwargs.get('max_ep_len')
     # If not None, a mask is added for every stored transition.
     self._ens_num = kwargs.get('ens_num')
     self._ngu = kwargs.get('ngu', False)
     self._rand2 = kwargs.get('rand2', False)
     self.non_episodic = kwargs.get('non_episodic', False)
     self.reset()