def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
        """Optionally clear the done flags and normalize the batch rewards.

        When ``self._rm_done`` is truthy, ``batch.done`` is multiplied by
        zero. When ``self._rew_norm`` is truthy, ``batch.rew`` is either
        passed through ``self.norm_func`` (if one is set) or standardized
        with the mean and standard deviation of at most the first 1000
        rewards stored in ``buffer``.

        :param batch: the sampled data; its ``done`` / ``rew`` attributes are
            reassigned in place.
        :param buffer: the replay buffer providing the reward statistics.
        :param indice: not used by this method.
        :return: the same ``batch`` object.
        """
        if self._rm_done:
            batch.done = batch.done * 0.

        if not self._rew_norm:
            return batch

        if self.norm_func is not None:
            # a user-supplied normalizer takes precedence over buffer stats
            batch.rew = self.norm_func(batch.rew)
            return batch

        # only look at a bounded prefix of the buffer to keep this cheap
        sample = buffer.rew[:min(len(buffer), 1000)]
        rew_mean, rew_std = sample.mean(), sample.std()
        if np.isclose(rew_std, 0):
            # degenerate statistics: fall back to the identity transform
            rew_mean, rew_std = 0, 1
        batch.rew = (batch.rew - rew_mean) / rew_std
        return batch
Exemple #2
0
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     """Attach old values, old log-probs, returns and advantages to ``batch``.

     The batch is walked in chunks of ``self._batch`` (without shuffling)
     under ``torch.no_grad()`` to evaluate the critic on ``obs_next`` and
     ``obs`` and to record the log-probability of the stored actions. The
     results are stored on the batch as ``v``, ``act``, ``logp_old``,
     ``returns`` and ``adv``. When ``self._rew_norm`` is truthy, both the
     rewards and the advantages are standardized unless their standard
     deviation is close to zero.

     :param batch: the sampled data to be augmented.
     :param buffer: not used by this method.
     :param indice: not used by this method.
     :return: the batch returned by ``compute_episodic_return`` with the
         extra fields set.
     """
     if self._rew_norm:
         rew_mean = batch.rew.mean()
         rew_std = batch.rew.std()
         # NOTE(review): 1e-2 is np.isclose's ``rtol``; it scales the
         # reference value, which is 0 here, so only the default ``atol``
         # takes effect - confirm whether ``atol=1e-2`` was intended.
         if not np.isclose(rew_std, 0, rtol=1e-2):
             batch.rew = (batch.rew - rew_mean) / rew_std
     values, next_values, log_probs = [], [], []
     with torch.no_grad():
         for minibatch in batch.split(self._batch, shuffle=False):
             next_values.append(self.critic(minibatch.obs_next))
             values.append(self.critic(minibatch.obs))
             dist = self(minibatch).dist
             actions = to_torch_as(minibatch.act, values[0])
             log_probs.append(dist.log_prob(actions))
     next_values = to_numpy(torch.cat(next_values, dim=0))
     batch = self.compute_episodic_return(
         batch, next_values, gamma=self._gamma, gae_lambda=self._lambda,
         rew_norm=self._rew_norm)
     batch.v = torch.cat(values, dim=0).flatten()  # old value
     # the first critic output serves as the template for conversions
     template = values[0]
     batch.act = to_torch_as(batch.act, template)
     batch.logp_old = torch.cat(log_probs, dim=0)
     batch.returns = to_torch_as(batch.returns, template)
     batch.adv = batch.returns - batch.v
     if self._rew_norm:
         adv_mean = batch.adv.mean()
         adv_std = batch.adv.std()
         # same rtol-against-zero caveat as for the reward check above
         if not np.isclose(adv_std.item(), 0, rtol=1e-2):
             batch.adv = (batch.adv - adv_mean) / adv_std
     return batch
Exemple #3
0
    def post_process_fn(self, batch: Batch, buffer: ReplayBuffer,
                        indices: np.ndarray) -> None:
        """Run the wrapped policy's post-processing, then restore the reward.

        The call is forwarded to ``self.policy.post_process_fn``; afterwards
        ``batch.rew`` is overwritten with ``batch.policy.orig_rew``.

        Typical usage is to update the sampling weight in prioritized
        experience replay. Used in :meth:`update`.

        :param batch: the data batch whose reward is restored in place.
        :param buffer: the replay buffer, passed through to the wrapped policy.
        :param indices: the sampled indices, passed through to the wrapped
            policy.
        """
        self.policy.post_process_fn(batch, buffer, indices)
        # restore original reward (read only after the wrapped call returns)
        original_reward = batch.policy.orig_rew
        batch.rew = original_reward
Exemple #4
0
    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indices: np.ndarray) -> Batch:
        """Replace the batch reward with a discriminator-based one.

        The reward is computed as ``-logsigmoid(-self.disc(batch))``,
        flattened, converted with ``to_numpy`` and stored in ``batch.rew``
        before the parent class's ``process_fn`` is invoked.

        Used in :meth:`update`. Check out :ref:`process_fn` for more information.

        :param batch: the sampled data; ``batch.rew`` is overwritten.
        :param buffer: the replay buffer, forwarded to the parent class.
        :param indices: the sampled indices, forwarded to the parent class.
        :return: whatever the parent class's ``process_fn`` returns.
        """
        # update reward
        with torch.no_grad():
            disc_output = self.disc(batch)
            reward = -F.logsigmoid(-disc_output)
            batch.rew = to_numpy(reward.flatten())
        return super().process_fn(batch, buffer, indices)
Exemple #5
0
    def process_fn(self, batch: Batch, buffer=None,
                   indice=None) -> Batch:
        """Set discriminator rewards on ``batch`` and compute its returns.

        ``self.get_reward_by_discriminator`` supplies both the reward and the
        value estimate handed to ``compute_episodic_return``. When
        ``self._rew_norm`` is truthy the reward is standardized unless its
        standard deviation is close to zero.

        :param batch: the sampled data; ``batch.rew`` is overwritten.
        :param buffer: not used by this method.
        :param indice: not used by this method.
        :return: the batch produced by ``compute_episodic_return``.
        """
        disc_reward, next_values = self.get_reward_by_discriminator(batch)
        batch.rew = disc_reward

        if self._rew_norm:
            rew_mean = batch.rew.mean()
            rew_std = batch.rew.std()
            # presumably the reward is a torch tensor, since ``.cpu()`` is
            # called on its std - verify against get_reward_by_discriminator
            if not np.isclose(rew_std.cpu().numpy(), 0):
                batch.rew = (batch.rew - rew_mean) / rew_std

        if self._lambda in (0, 1):
            # NOTE(review): this branch skips the to_numpy()/to_torch()
            # round-trip performed below - confirm that is intentional.
            return self.compute_episodic_return(
                batch, None, gamma=self._gamma, gae_lambda=self._lambda)

        batch.to_numpy()
        batch = self.compute_episodic_return(
            batch, next_values, gamma=self._gamma, gae_lambda=self._lambda)
        batch.to_torch()
        return batch
Exemple #6
0
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     """Optionally standardize the batch rewards and clear the done flags.

     When ``self._rew_norm`` is truthy, ``batch.rew`` is standardized with
     the mean and standard deviation of at most the first 1000 rewards in
     ``buffer``, provided that standard deviation exceeds ``self.__eps``.
     When ``self._rm_done`` is truthy, ``batch.done`` is multiplied by zero.

     :param batch: the sampled data; ``rew`` / ``done`` are reassigned.
     :param buffer: the replay buffer providing the reward statistics.
     :param indice: not used by this method.
     :return: the same ``batch`` object.
     """
     if self._rew_norm:
         window = min(len(buffer), 1000)  # avoid large buffer
         sample = buffer.rew[:window]
         rew_mean = sample.mean()
         rew_std = sample.std()
         if rew_std > self.__eps:
             batch.rew = (batch.rew - rew_mean) / rew_std
     if self._rm_done:
         batch.done = batch.done * 0.
     return batch
Exemple #7
0
    def add(
        self,
        batch: Batch,
        buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add a batch of data into replay buffer.

        :param Batch batch: the input data batch. Its keys must belong to the 7
            reserved keys, and "obs", "act", "rew", "done" is required.
        :param buffer_ids: to make consistent with other buffer's add function; if it
            is not None, we assume the input batch's first dimension is always 1.

        Return (current_index, episode_reward, episode_length, episode_start_index). If
        the episode is not finished, the return value of episode_length and
        episode_reward is 0.
        """
        # keep only the reserved keys that are present in the input
        filtered = Batch()
        for name in set(self._reserved_keys).intersection(batch.keys()):
            filtered.__dict__[name] = batch[name]
        batch = filtered
        assert {"obs", "act", "rew", "done"}.issubset(batch.keys())
        is_stacked = buffer_ids is not None
        if is_stacked:
            assert len(batch) == 1
        if self._save_only_last_obs:
            batch.obs = batch.obs[:, -1] if is_stacked else batch.obs[-1]
        if self._save_obs_next:
            if self._save_only_last_obs:
                batch.obs_next = (
                    batch.obs_next[:, -1] if is_stacked else batch.obs_next[-1]
                )
        else:
            batch.pop("obs_next", None)
        # a stacked batch carries a leading dimension of size one
        rew = batch.rew[0] if is_stacked else batch.rew
        done = batch.done[0] if is_stacked else batch.done
        ptr, ep_rew, ep_len, ep_idx = [
            np.array([item]) for item in self._add_index(rew, done)
        ]
        try:
            self._meta[ptr] = batch
        except ValueError:
            # storage is missing or lacks some keys: allocate, then retry
            needs_stack = not is_stacked
            batch.rew = batch.rew.astype(float)
            batch.done = batch.done.astype(bool)
            if self._meta.is_empty():
                self._meta = _create_value(  # type: ignore
                    batch, self.maxsize, needs_stack)
            else:  # dynamic key pops up in batch
                _alloc_by_keys_diff(
                    self._meta, batch, self.maxsize, needs_stack)
            self._meta[ptr] = batch
        return ptr, ep_rew, ep_len, ep_idx
Exemple #8
0
    def add(
        self,
        batch: Batch,
        buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add a batch of data into ReplayBufferManager.

        Each of the data's length (first dimension) must equal to the length of
        buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].

        Return (current_index, episode_reward, episode_length, episode_start_index). If
        the episode is not finished, the return value of episode_length and
        episode_reward is 0.
        """
        # keep only the reserved keys that are present in the input
        filtered = Batch()
        for name in set(self._reserved_keys).intersection(batch.keys()):
            filtered.__dict__[name] = batch[name]
        batch = filtered
        assert {"obs", "act", "rew", "done"}.issubset(batch.keys())
        if self._save_only_last_obs:
            batch.obs = batch.obs[:, -1]
        if self._save_obs_next:
            if self._save_only_last_obs:
                batch.obs_next = batch.obs_next[:, -1]
        else:
            batch.pop("obs_next", None)
        # resolve the target sub-buffers and their write positions
        if buffer_ids is None:
            buffer_ids = np.arange(self.buffer_num)
        global_ptrs, rewards, lengths, start_idxs = [], [], [], []
        for row, buffer_id in enumerate(buffer_ids):
            sub_buffer = self.buffers[buffer_id]
            ptr, ep_rew, ep_len, ep_idx = sub_buffer._add_index(
                batch.rew[row], batch.done[row]
            )
            # translate sub-buffer-local positions by the per-buffer offset
            offset = self._offset[buffer_id]
            global_ptr = ptr + offset
            global_ptrs.append(global_ptr)
            lengths.append(ep_len)
            rewards.append(ep_rew)
            start_idxs.append(ep_idx + offset)
            self.last_index[buffer_id] = global_ptr
            self._lengths[buffer_id] = len(sub_buffer)
        global_ptrs = np.array(global_ptrs)
        try:
            self._meta[global_ptrs] = batch
        except ValueError:
            # storage is missing or lacks some keys: allocate, then retry
            batch.rew = batch.rew.astype(float)
            batch.done = batch.done.astype(bool)
            if self._meta.is_empty():
                self._meta = _create_value(  # type: ignore
                    batch, self.maxsize, stack=False)
            else:  # dynamic key pops up in batch
                _alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
            self._set_batch_for_children()
            self._meta[global_ptrs] = batch
        return (
            global_ptrs,
            np.array(rewards),
            np.array(lengths),
            np.array(start_idxs),
        )
Exemple #9
0
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     """Optionally standardize rewards, then compute episodic returns.

     When ``self._rew_norm`` is truthy, ``batch.rew`` is standardized with
     its own mean and standard deviation, provided the latter exceeds
     ``self.__eps``. If ``self._lambda`` is 0 or 1 the returns are computed
     without value estimates; otherwise the critic is evaluated on
     ``obs_next`` in chunks of ``self._batch`` under ``torch.no_grad()``.

     :param batch: the sampled data; ``batch.rew`` may be reassigned.
     :param buffer: not used by this method.
     :param indice: not used by this method.
     :return: the result of ``compute_episodic_return``.
     """
     if self._rew_norm:
         rew_mean = batch.rew.mean()
         rew_std = batch.rew.std()
         if rew_std > self.__eps:
             batch.rew = (batch.rew - rew_mean) / rew_std
     next_values = None
     if self._lambda not in (0, 1):
         with torch.no_grad():
             per_chunk = [
                 self.critic(chunk.obs_next)
                 for chunk in batch.split(self._batch, shuffle=False)
             ]
         next_values = torch.cat(per_chunk, dim=0).cpu().numpy()
     return self.compute_episodic_return(
         batch, next_values, gamma=self._gamma, gae_lambda=self._lambda)