Example #1
 def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
         stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
     """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
     where s is self.key, t is indice. The stack_num (here equals to 4) is
     given from buffer initialization procedure.
     """
     if stack_num is None:
         stack_num = self.stack_num
     if stack_num == 1:  # the most common case
         if key != 'obs_next' or self._save_s_:
             val = self._meta.__dict__[key]
             try:
                 return val[indice]
             except IndexError as e:
                 if not (isinstance(val, Batch) and val.is_empty()):
                     raise e  # val != Batch()
                 return Batch()
     indice = self._indices[:self._size][indice]
     done = self._meta.__dict__['done']
     if key == 'obs_next' and not self._save_s_:
         indice += 1 - done[indice].astype(int)
         indice[indice == self._size] = 0
         key = 'obs'
     val = self._meta.__dict__[key]
     try:
         if stack_num == 1:
             return val[indice]
         stack = []
         for _ in range(stack_num):
             stack = [val[indice]] + stack
             pre_indice = np.asarray(indice - 1)
             pre_indice[pre_indice == -1] = self._size - 1
             indice = np.asarray(
                 pre_indice + done[pre_indice].astype(int))
             indice[indice == self._size] = 0
         if isinstance(val, Batch):
             stack = Batch.stack(stack, axis=indice.ndim)
         else:
             stack = np.stack(stack, axis=indice.ndim)
         return stack
     except IndexError as e:
         if not (isinstance(val, Batch) and val.is_empty()):
             raise e  # val != Batch()
         return Batch()
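The stacking behaviour described in the docstring can be illustrated without the buffer itself. A minimal plain-numpy sketch (a made-up obs array and a scalar index; it ignores the episode-boundary and ring-buffer wrap-around handling that the method performs with done):

import numpy as np

# ten made-up observations, each a 3-dimensional vector
obs = np.arange(10 * 3).reshape(10, 3)
stack_num, t = 4, 6
# oldest frame first, newest last: [obs[t-3], obs[t-2], obs[t-1], obs[t]]
frames = [obs[t - i] for i in reversed(range(stack_num))]
print(np.stack(frames, axis=0).shape)  # (4, 3)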
Example #2
def test_async_env(size=10000, num=8, sleep=0.1):
    # simplify the test case, just keep stepping
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        v.seed(None)
        v.reset()
        # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
        # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
        # expectation of v is n / (n + 1)
        # for a synchronous environment, the following actions should take
        # about 7 * sleep * num / (num + 1) seconds
        # for async simulation, the analysis is complicated, but the time cost
        # should be smaller
        action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
        current_idx_start = 0
        action = action_list[:num]
        env_ids = list(range(num))
        o = []
        spent_time = time.time()
        while current_idx_start < len(action_list):
            A, B, C, D = v.step(action=action, id=env_ids)
            b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
            env_ids = b.info.env_id
            o.append(b)
            current_idx_start += len(action)
            # len of action may be smaller than len(A) in the end
            action = action_list[current_idx_start:current_idx_start + len(A)]
            # truncate env_ids to the first len(action) entries
            # typically len(env_ids) == len(A) == len(action), except for the
            # last batch, when there are not enough actions left
            env_ids = env_ids[:len(action)]
        spent_time = time.time() - spent_time
        Batch.cat(o)
        v.close()
        # ensure at least a 1/7 speed improvement over the synchronous estimate
        if sys.platform != "darwin":  # macOS cannot pass this check
            assert spent_time < 6.0 * sleep * num / (num + 1)
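The timing bound in the comments above relies on the fact that the expected maximum of n i.i.d. U[0, 1] variables is n / (n + 1). A quick standalone numeric check of that claim (not part of the test suite):

import numpy as np

n, trials = 8, 100000
v = np.random.rand(trials, n).max(axis=1)  # maximum of n uniform draws, repeated
print(v.mean(), n / (n + 1))  # the empirical mean should be close to 8 / 9 ~= 0.889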
Example #3
    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        If you need to mask the action, please add a "mask" into batch.obs, for
        example, if we have an environment that has "0/1/2" three actions:
        ::

            batch == Batch(
                obs=Batch(
                    obs="original obs, with batch_size=1 for demonstration",
                    mask=np.array([[False, True, False]]),
                    # action 1 is available
                    # actions 0 and 2 are unavailable
                ),
                ...
            )

        :param float eps: in [0, 1], for epsilon-greedy exploration method.

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        obs_ = obs.obs if hasattr(obs, "obs") else obs
        q, h = model(obs_, state=state, info=batch.info)
        act: np.ndarray = to_numpy(q.max(dim=1)[1])
        if hasattr(obs, "mask"):
            # some actions are masked and therefore cannot be selected
            q_: np.ndarray = to_numpy(q)
            q_[~obs.mask] = -np.inf
            act = q_.argmax(axis=1)
        # epsilon-greedy exploration during training/testing (skipped while updating)
        if not self.updating and not np.isclose(self.eps, 0.0):
            for i in range(len(q)):
                if np.random.rand() < self.eps:
                    q_ = np.random.rand(*q[i].shape)
                    if hasattr(obs, "mask"):
                        q_[~obs.mask[i]] = -np.inf
                    act[i] = q_.argmax()
        return Batch(logits=q, act=act, state=h)
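The masking branch above can be reproduced in isolation with plain numpy (made-up q-values and mask, not the policy's real tensors): disallowed actions are set to -inf so that argmax can only return an allowed action.

import numpy as np

q = np.array([[1.0, 0.5, 2.0]])          # one sample, three actions
mask = np.array([[False, True, False]])  # only action 1 is available
q_masked = np.where(mask, q, -np.inf)    # same effect as q_[~mask] = -np.inf
print(q_masked.argmax(axis=1))           # [1]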
Example #4
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     weight = batch.pop("weight", 1.0)
     self.optim.zero_grad()
     q = self(batch, eps=0.).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns, q).flatten()
     c = torch.nn.SmoothL1Loss(reduction='none')
     # c = lambda r, q: (r - q).pow(2)
     td = c(r, q)
     loss = (td * weight).mean()
     batch.weight = td  # per-sample TD error for the prioritized buffer
     loss.backward()
     if self.grad_norm_clipping:
         torch.nn.utils.clip_grad_norm_(
             self.model.parameters(), self.grad_norm_clipping)
     self.optim.step()
     self._cnt += 1
     return {'loss': loss.item()}
Example #5
 def reset(self) -> None:
     """Reset all related variables in the collector."""
     # use empty Batch for ``state`` so that ``self.data`` supports slicing
     # convert empty Batch to None when passing data to policy
     self.data = Batch(state={},
                       obs={},
                       act={},
                       rew={},
                       done={},
                       info={},
                       obs_next={},
                       policy={})
     # print("before reset env")
     self.reset_env()
     # print("after reset env")
     self.reset_buffer()
     self.reset_stat()
     if self._action_noise is not None:
         self._action_noise.reset()
Example #6
 def forward(self,
             batch: Batch,
             state: Optional[Union[dict, Batch, np.ndarray]] = None,
             **kwargs) -> Batch:
     logits, h = self.model(batch.obs, state=state, info=batch.info)
     if self.mode == 'discrete':
         a = logits.max(dim=1)[1]
     else:
         a = logits
     return Batch(logits=logits, act=a, state=h)
Example #7
 def __init__(self,
              size: int,
              stack_num: int = 1,
              ignore_obs_next: bool = False,
              save_only_last_obs: bool = False,
              sample_avail: bool = False) -> None:
     super().__init__()
     self._maxsize = size
     self._indices = np.arange(size)
     self._stack = None
     self.stack_num = stack_num
     self._avail = sample_avail and stack_num > 1
     self._avail_index = []
     self._save_s_ = not ignore_obs_next
     self._last_obs = save_only_last_obs
     self._index = 0
     self._size = 0
     self._meta = Batch()
     self.reset()
Example #8
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._iter % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device)
     q = self(batch).logits
     act_mask = torch.zeros_like(q)
     act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1)
     act_q = q * act_mask
     returns = batch.returns
     returns = returns * act_mask
     td_error = returns - act_q
     loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean()
     batch.weight = td_error.sum(-1).sum(-1)  # prio-buffer
     loss.backward()
     self.optim.step()
     self._iter += 1
     return {"loss": loss.item()}
Example #9
    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indices: np.ndarray) -> Batch:
        """Pre-process the data from the provided replay buffer.

        Used in :meth:`update`. Check out :ref:`process_fn` for more information.
        """
        # update reward
        with torch.no_grad():
            batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten())
        return super().process_fn(batch, buffer, indices)
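The reward used above, -logsigmoid(-D(s, a)), is mathematically identical to softplus(D(s, a)) = log(1 + exp(D(s, a))). A quick standalone check of that identity with made-up logits:

import torch
import torch.nn.functional as F

x = torch.randn(5)             # made-up discriminator outputs
r1 = -F.logsigmoid(-x)         # the reward computed in process_fn above
r2 = F.softplus(x)             # closed form: log(1 + exp(x))
print(torch.allclose(r1, r2))  # True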
Example #10
    def reset(self, reset_buffer: bool = True) -> None:
        """Reset the environment, statistics, current data and possibly replay memory.

        :param bool reset_buffer: if True, reset the replay buffer that is attached
            to the collector.
        """
        # use empty Batch for "state" so that self.data supports slicing
        # convert empty Batch to None when passing data to policy
        self.data = Batch(obs={},
                          act={},
                          rew={},
                          done={},
                          obs_next={},
                          info={},
                          policy={})
        self.reset_env()
        if reset_buffer:
            self.reset_buffer()
        self.reset_stat()
Example #11
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     if self._target and self._cnt % self._freq == 0:
         self.sync_weight()
     self.optim.zero_grad()
     weight = batch.pop("weight", 1.0)
     q = self(batch).logits
     q = q[np.arange(len(q)), batch.act]
     r = to_torch_as(batch.returns.flatten(), q)
     td = r - q
     loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     loss.backward()
     # gradient clipping
     torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
     self.optim.step()
     self._cnt += 1
     for param_group in self.optim.param_groups:
         lr = param_group['lr']
     return {"loss": loss.item(), "lr": lr}
Example #12
    def add(  # type: ignore
        self,
        obs: Any,
        act: Any,
        rew: np.ndarray,
        done: np.ndarray,
        obs_next: Any = Batch(),
        info: Optional[Batch] = Batch(),
        policy: Optional[Batch] = Batch(),
        cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
        **kwargs: Any,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Add a batch of data into CachedReplayBuffer.

        The length (first dimension) of each piece of data must equal the
        length of cached_buffer_ids. By default, cached_buffer_ids is
        [0, 1, ..., cached_buffer_num - 1].

        Return the array of episode_length and episode_reward with shape
        (len(cached_buffer_ids), ...), where (episode_length[i],
        episode_reward[i]) refers to the cached_buffer_ids[i]th cached buffer's
        corresponding episode result.
        """
        if cached_buffer_ids is None:
            cached_buffer_ids = np.arange(self.cached_buffer_num)
        else:  # make sure it is np.ndarray
            cached_buffer_ids = np.asarray(cached_buffer_ids)
        # in self.buffers, the first buffer is main_buffer
        buffer_ids = cached_buffer_ids + 1  # type: ignore
        result = super().add(obs,
                             act,
                             rew,
                             done,
                             obs_next,
                             info,
                             policy,
                             buffer_ids=buffer_ids,
                             **kwargs)
        # find the terminated episode, move data from cached buf to main buf
        for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]:
            self.main_buffer.update(self.cached_buffers[buffer_idx])
            self.cached_buffers[buffer_idx].reset()
        return result
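The id shift above can be seen with a tiny numpy sketch (made-up ids): cached buffer i is stored at self.buffers[i + 1], because slot 0 of self.buffers is the main buffer.

import numpy as np

cached_buffer_ids = np.asarray([0, 2])  # made-up subset of cached buffers
buffer_ids = cached_buffer_ids + 1      # offset past the main buffer at slot 0
print(buffer_ids)                       # [1 3]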
Example #13
    def get(
        self,
        index: Union[int, List[int], np.ndarray],
        key: str,
        default_value: Any = None,
        stack_num: Optional[int] = None,
    ) -> Union[Batch, np.ndarray]:
        """Return the stacked result.

        E.g., if you set ``key = "obs", stack_num = 4, index = t``, it returns the
        stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``.

        :param index: the index for getting stacked data.
        :param str key: the key to get, should be one of the reserved_keys.
        :param default_value: if the given key's data is not found and default_value is
            set, return this default_value.
        :param int stack_num: Defaults to self.stack_num.
        """
        if key not in self._meta and default_value is not None:
            return default_value
        val = self._meta[key]
        if stack_num is None:
            stack_num = self.stack_num
        try:
            if stack_num == 1:  # the most common case
                return val[index]
            stack: List[Any] = []
            if isinstance(index, list):
                indices = np.array(index)
            else:
                indices = index  # type: ignore
            for _ in range(stack_num):
                stack = [val[indices]] + stack
                indices = self.prev(indices)
            if isinstance(val, Batch):
                return Batch.stack(stack, axis=indices.ndim)
            else:
                return np.stack(stack, axis=indices.ndim)
        except IndexError as e:
            if not (isinstance(val, Batch) and val.is_empty()):
                raise e  # val != Batch()
            return Batch()
Example #14
 def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                indice: np.ndarray) -> Batch:
     if self._rew_norm:
         mean, std = batch.rew.mean(), batch.rew.std()
         if std > self.__eps:
             batch.rew = (batch.rew - mean) / std
     if self._lambda in [0, 1]:
         return self.compute_episodic_return(batch,
                                             None,
                                             gamma=self._gamma,
                                             gae_lambda=self._lambda)
     v_ = []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False):
             v_.append(self.critic(b.obs_next))
     v_ = torch.cat(v_, dim=0).cpu().numpy()
     return self.compute_episodic_return(batch,
                                         v_,
                                         gamma=self._gamma,
                                         gae_lambda=self._lambda)
Example #15
 def process_fn(
     self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
 ) -> Batch:
     v_s_ = []
     with torch.no_grad():
         for b in batch.split(self._batch, shuffle=False, merge_last=True):
             v_s_.append(to_numpy(self.critic(b.obs_next)))
     v_s_ = np.concatenate(v_s_, axis=0)
     if self._rew_norm:  # unnormalize v_s_
         v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
     unnormalized_returns, _ = self.compute_episodic_return(
         batch, buffer, indice, v_s_=v_s_,
         gamma=self._gamma, gae_lambda=self._lambda)
     if self._rew_norm:
         batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
             np.sqrt(self.ret_rms.var + self._eps)
         self.ret_rms.update(unnormalized_returns)
     else:
         batch.returns = unnormalized_returns
     return batch
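The two _rew_norm branches above are inverses of each other: v_s_ is mapped back to the unnormalized return scale before computing the episodic return, and the resulting returns are normalized again before training. A standalone round-trip check with made-up statistics standing in for self.ret_rms:

import numpy as np

mean, var, eps = 2.0, 4.0, 1e-8
g = np.array([1.0, 3.0, 5.0])                # unnormalized returns
g_norm = (g - mean) / np.sqrt(var + eps)     # what is stored in batch.returns
g_back = g_norm * np.sqrt(var + eps) + mean  # the inverse applied to v_s_
print(np.allclose(g_back, g))                # True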
Example #16
 def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
     weight = batch.pop('weight', 1.)
     current_q = self.critic(batch.obs, batch.act).flatten()
     target_q = batch.returns.flatten()
     td = current_q - target_q
     critic_loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     self.critic_optim.zero_grad()
     critic_loss.backward()
     self.critic_optim.step()
     action = self(batch, explorating=False).act
     actor_loss = -self.critic(batch.obs, action).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     return {
         'loss/actor': actor_loss.item(),
         'loss/critic': critic_loss.item(),
     }
Example #17
 def forward(self, inputs, last_state=None, deterministic=False):
     obs = inputs.obs
     if not isinstance(obs, torch.Tensor):
         obs = torch.tensor(obs, device=self.device, dtype=torch.float32)
     value, actor_features, self.rnn_hxs = self.base(obs, self.rnn_hxs, self.masks)
     dist = self.dist(actor_features)
     if deterministic:
         action = dist.mode()
     else:
         action = dist.sample()
     return Batch(logits=dist.logits, act=action[0], state=None, dist=dist)
Example #18
def test_batch():
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    batch.update(obs=[1])
    assert batch.obs == [1]
    batch.append(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    with pytest.raises(IndexError):
        batch[2]
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, permute=False)):
        assert b.obs == batch[i].obs
Example #19
 def __call__(self, batch, state=None, model='actor'):
     model = getattr(self, model)
     logits, h = model(batch.obs, state=state, info=batch.info)
     if isinstance(logits, tuple):
         dist = self.dist_fn(*logits)
     else:
         dist = self.dist_fn(logits)
     act = dist.sample()
     if self._range:
         act = act.clamp(self._range[0], self._range[1])
     return Batch(logits=logits, act=act, state=h, dist=dist)
Example #20
 def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
     weight = batch.pop("weight", 1.0)
     current_q = self.critic(batch.obs, batch.act).flatten()
     target_q = batch.returns.flatten()
     td = current_q - target_q
     critic_loss = (td.pow(2) * weight).mean()
     batch.weight = td  # prio-buffer
     self.critic_optim.zero_grad()
     critic_loss.backward()
     self.critic_optim.step()
     action = self(batch).act
     actor_loss = -self.critic(batch.obs, action).mean()
     self.actor_optim.zero_grad()
     actor_loss.backward()
     self.actor_optim.step()
     self.sync_weight()
     return {
         "loss/actor": actor_loss.item(),
         "loss/critic": critic_loss.item(),
     }
Example #21
 def __init__(
     self,
     policy: BasePolicy,
     env: Union[gym.Env, BaseVectorEnv],
     buffer: Optional[ReplayBuffer] = None,
     preprocess_fn: Optional[Callable[[Any], Union[dict, Batch]]] = None,
     action_noise: Optional[BaseNoise] = None,
     reward_metric: Optional[Callable[[np.ndarray], float]] = None,
 ) -> None:
     super().__init__()
     if not isinstance(env, BaseVectorEnv):
         env = DummyVectorEnv([lambda: env])
     self.env = env
     self.env_num = len(env)
     # environments that are available in step()
     # this means all environments in synchronous simulation
     # but only a subset of environments in asynchronous simulation
     self._ready_env_ids = np.arange(self.env_num)
     # self.async is a flag to indicate whether this collector works
     # with asynchronous simulation
     self.is_async = env.is_async
     # need cache buffers before storing in the main buffer
     self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
     self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
     self.buffer = buffer
     self.policy = policy
     self.preprocess_fn = preprocess_fn
     self.process_fn = policy.process_fn
     self._action_space = env.action_space
     self._action_noise = action_noise
     self._rew_metric = reward_metric or Collector._default_rew_metric
     # avoid creating attribute outside __init__
     self.data = Batch(state={},
                       obs={},
                       act={},
                       rew={},
                       done={},
                       info={},
                       obs_next={},
                       policy={})
     self.reset()
Example #22
def test_fn(size=2560):
    policy = PGPolicy(None, None, None, discount_factor=0.1)
    buf = ReplayBuffer(100)
    buf.add(1, 1, 1, 1, 1)
    fn = policy.process_fn
    # fn = compute_return_base
    batch = Batch(
        done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
        rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
    )
    batch = fn(batch, buf, 0)
    ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 1, 0.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, buf, 0)
    ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 0, 1.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, buf, 0)
    ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
    assert abs(batch.returns - ans).sum() <= 1e-5
    if __name__ == '__main__':
        batch = Batch(
            done=np.random.randint(100, size=size) == 0,
            rew=np.random.random(size),
        )
        cnt = 3000
        t = time.time()
        for _ in range(cnt):
            compute_return_base(batch)
        print(f'vanilla: {(time.time() - t) / cnt}')
        t = time.time()
        for _ in range(cnt):
            policy.process_fn(batch, buf, 0)
        print(f'policy: {(time.time() - t) / cnt}')
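The first expected array can be reproduced by hand. A standalone sketch (plain numpy, not PGPolicy) that accumulates the discounted return backwards with discount_factor = 0.1 and resets it at every done flag:

import numpy as np

gamma = 0.1
rew = np.array([0, 1, 2, 3, 4, 5, 6, 7.])
done = np.array([1, 0, 0, 1, 0, 1, 0, 1.])
returns = np.zeros_like(rew)
running = 0.0
for t in reversed(range(len(rew))):
    if done[t]:
        running = 0.0  # episode ends at t, nothing is bootstrapped from t + 1
    running = rew[t] + gamma * running
    returns[t] = running
print(returns)  # -> [0., 1.23, 2.3, 3., 4.5, 5., 6.7, 7.]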
Example #23
    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        r"""Compute the n-step return for Q-learning targets:

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) \max_a Q_{old}(s_{t + n}, \arg\max_a
            (Q_{new}(s_{t + n}, a)))

        where :math:`\gamma` is the discount factor,
        :math:`\gamma \in [0, 1]`, and :math:`d_t` is the done flag of step
        :math:`t`. If there is no target network, :math:`Q_{old}` is equal
        to :math:`Q_{new}`.
        """
        batch = self.compute_nstep_return(batch, buffer, indice,
                                          self._target_q, self._gamma,
                                          self._n_step)
        if isinstance(buffer, PrioritizedReplayBuffer):
            batch.update_weight = buffer.update_weight
            batch.indice = indice
        return batch
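A worked instance of the formula above with made-up numbers (gamma = 0.9, n = 2, no terminal steps, and a bootstrap value of 5 standing in for the max_a Q_old(s_{t+n}, argmax_a Q_new(s_{t+n}, a)) term):

gamma, rewards, bootstrap = 0.9, [1.0, 2.0], 5.0
g_t = sum(gamma ** i * r for i, r in enumerate(rewards))
g_t += gamma ** len(rewards) * bootstrap
print(g_t)  # 1 + 0.9 * 2 + 0.81 * 5 = 6.85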
Example #24
 def forward(  # type: ignore
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     input: str = "obs",
     **kwargs: Any,
 ) -> Batch:
     obs = batch[input]
     logits, h = self.actor(obs, state=state, info=batch.info)
     dist = Categorical(logits=logits)
     act = dist.sample()
     return Batch(logits=logits, act=act, state=h, dist=dist)
Example #25
def test_batch_from_to_numpy_without_copy():
    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    a_mem_addr_orig = batch.a.__array_interface__['data'][0]
    c_mem_addr_orig = batch.b.c.__array_interface__['data'][0]
    batch.to_torch()
    batch.to_numpy()
    a_mem_addr_new = batch.a.__array_interface__['data'][0]
    c_mem_addr_new = batch.b.c.__array_interface__['data'][0]
    assert a_mem_addr_new == a_mem_addr_orig
    assert c_mem_addr_new == c_mem_addr_orig
Example #26
    def learn(  # type: ignore
            self, batch: Batch, batch_size: int, repeat: int,
            **kwargs: Any) -> Dict[str, List[float]]:
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        for _ in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                # calculate loss for actor
                dist = self(b).dist
                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * b.adv
                surr2 = ratio.clamp(1.0 - self._eps_clip,
                                    1.0 + self._eps_clip) * b.adv
                if self._dual_clip:
                    clip_loss = -torch.max(torch.min(surr1, surr2),
                                           self._dual_clip * b.adv).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()
                # calculate loss for critic
                value = self.critic(b.obs).flatten()
                if self._value_clip:
                    v_clip = b.v_s + (value - b.v_s).clamp(
                        -self._eps_clip, self._eps_clip)
                    vf1 = (b.returns - value).pow(2)
                    vf2 = (b.returns - v_clip).pow(2)
                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
                else:
                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
                # calculate regularization and overall loss
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm is not None:  # clip large gradient
                    nn.utils.clip_grad_norm_(list(self.actor.parameters()) +
                                             list(self.critic.parameters()),
                                             max_norm=self._grad_norm)
                self.optim.step()
                clip_losses.append(clip_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/clip": clip_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
        }
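For reference, clip_loss above is the negative of the standard PPO clipped surrogate objective (standard notation, not code from this repository):

.. math::
    L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta) \hat{A}_t,\;
    \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t\right)\right],
    \quad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{old}}(a_t \mid s_t)}

In the code, :math:`r_t(\theta)` is computed as ``exp(log_prob - logp_old)``, :math:`\hat{A}_t` is ``b.adv``, and :math:`\epsilon` is ``self._eps_clip``; the dual-clip branch additionally bounds the surrogate from below by ``self._dual_clip * b.adv``.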
Example #27
 def forward(
     self,
     batch: Batch,
     state: Optional[Union[dict, Batch, np.ndarray]] = None,
     **kwargs: Any,
 ) -> Batch:
     logits, hidden = self.model(batch.obs, state=state, info=batch.info)
     if self.action_type == "discrete":
         act = logits.max(dim=1)[1]
     else:
         act = logits
     return Batch(logits=logits, act=act, state=hidden)
Example #28
 def __getitem__(self, index: Union[slice, int, np.integer,
                                    np.ndarray]) -> Batch:
     return Batch(
         obs=self.get(index, 'obs'),
         act=self.act[index],
         rew=self.rew[index],
         done=self.done[index],
         obs_next=self.get(index, 'obs_next'),
         info=self.get(index, 'info'),
         policy=self.get(index, 'policy'),
         weight=self.weight[index],
     )
Example #29
    def add(
        self,
        obs: Any,
        act: Any,
        rew: Union[Number, np.number, np.ndarray],
        done: Union[Number, np.number, np.bool_],
        obs_next: Any = None,
        info: Optional[Union[dict, Batch]] = {},
        policy: Optional[Union[dict, Batch]] = {},
        **kwargs: Any,
    ) -> Tuple[int, Union[float, np.ndarray]]:
        """Add a batch of data into replay buffer.

        Return (episode_length, episode_reward) if one episode is terminated,
        otherwise return (0, 0.0).
        """
        assert isinstance(
            info,
            (dict, Batch
             )), "You should return a dict in the last argument of env.step()."
        if self._save_only_last_obs:
            obs = obs[-1]
        self._add_to_buffer("obs", obs)
        self._add_to_buffer("act", act)
        # make sure the data type of reward is float instead of int
        # but rew may be np.ndarray, so that we cannot use float(rew)
        rew = rew * 1.0  # type: ignore
        self._add_to_buffer("rew", rew)
        self._add_to_buffer("done", bool(done))  # done should be a bool scalar
        if self._save_obs_next:
            if obs_next is None:
                obs_next = Batch()
            elif self._save_only_last_obs:
                obs_next = obs_next[-1]
            self._add_to_buffer("obs_next", obs_next)
        self._add_to_buffer("info", info)
        self._add_to_buffer("policy", policy)

        if self.maxsize > 0:
            self._size = min(self._size + 1, self.maxsize)
            self._index = (self._index + 1) % self.maxsize
        else:  # TODO: remove this after deleting ListReplayBuffer
            self._size = self._index = self._size + 1

        self._episode_reward += rew
        self._episode_length += 1

        if done:
            result = self._episode_length, self._episode_reward
            self._episode_length, self._episode_reward = 0, 0.0
            return result
        else:
            return 0, self._episode_reward * 0.0
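A minimal usage sketch of the return convention documented above (assuming this ReplayBuffer variant is importable as tianshou.data.ReplayBuffer; the add signature differs in other versions of the library):

from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=100)
print(buf.add(obs=0, act=0, rew=1.0, done=False))  # (0, 0.0): episode still running
print(buf.add(obs=1, act=0, rew=2.0, done=True))   # (2, 3.0): length 2, total reward 3.0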
Example #30
def test_nstep_returns(size=10000):
    buf = ReplayBuffer(10)
    for i in range(12):
        buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3))
    batch, indices = buf.sample(0)
    assert np.allclose(indices, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0, 0]
    # test nstep = 1
    returns = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn, gamma=.1, n_step=1
        ).pop('returns').reshape(-1)
    )
    assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
    r_ = compute_nstep_return_base(1, .1, buf, indices)
    assert np.allclose(returns, r_), (r_, returns)
    returns_multidim = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1
        ).pop('returns')
    )
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 2
    returns = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn, gamma=.1, n_step=2
        ).pop('returns').reshape(-1)
    )
    assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
    r_ = compute_nstep_return_base(2, .1, buf, indices)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2
        ).pop('returns')
    )
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 10
    returns = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn, gamma=.1, n_step=10
        ).pop('returns').reshape(-1)
    )
    assert np.allclose(returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
    r_ = compute_nstep_return_base(10, .1, buf, indices)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10
        ).pop('returns')
    )
    assert np.allclose(returns_multidim, returns[:, np.newaxis])