def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
        stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
    """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
    where s is self.key and t is the given indice. The stack_num (4 in this
    example) is set during buffer initialization.
    """
    if stack_num is None:
        stack_num = self.stack_num
    if stack_num == 1:  # the most common case
        if key != 'obs_next' or self._save_s_:
            val = self._meta.__dict__[key]
            try:
                return val[indice]
            except IndexError as e:
                if not (isinstance(val, Batch) and val.is_empty()):
                    raise e  # val != Batch()
                return Batch()
    indice = self._indices[:self._size][indice]
    done = self._meta.__dict__['done']
    if key == 'obs_next' and not self._save_s_:
        indice += 1 - done[indice].astype(int)
        indice[indice == self._size] = 0
        key = 'obs'
    val = self._meta.__dict__[key]
    try:
        if stack_num == 1:
            return val[indice]
        stack = []
        for _ in range(stack_num):
            stack = [val[indice]] + stack
            pre_indice = np.asarray(indice - 1)
            pre_indice[pre_indice == -1] = self._size - 1
            indice = np.asarray(pre_indice + done[pre_indice].astype(int))
            indice[indice == self._size] = 0
        if isinstance(val, Batch):
            stack = Batch.stack(stack, axis=indice.ndim)
        else:
            stack = np.stack(stack, axis=indice.ndim)
        return stack
    except IndexError as e:
        if not (isinstance(val, Batch) and val.is_empty()):
            raise e  # val != Batch()
        return Batch()

def test_async_env(size=10000, num=8, sleep=0.1):
    # simplify the test case: just keep stepping
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        v.seed(None)
        v.reset()
        # for i.i.d. random variables u_1, ..., u_n ~ U[0, 1], let
        # v = max{u_1, u_2, ..., u_n}; then P(v <= x) = x^n (0 <= x <= 1),
        # the pdf of v is n * x^{n-1}, and the expectation of v is n / (n + 1)
        # for a synchronous environment, the following actions should take
        # about 7 * sleep * num / (num + 1) seconds
        # for async simulation the analysis is more involved, but the time
        # cost should be smaller
        action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
        current_idx_start = 0
        action = action_list[:num]
        env_ids = list(range(num))
        o = []
        spent_time = time.time()
        while current_idx_start < len(action_list):
            A, B, C, D = v.step(action=action, id=env_ids)
            b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
            env_ids = b.info.env_id
            o.append(b)
            current_idx_start += len(action)
            # len(action) may be smaller than len(A) at the end
            action = action_list[current_idx_start:current_idx_start + len(A)]
            # truncate env_ids to the first len(action) terms;
            # typically len(env_ids) == len(A) == len(action), except for the
            # last batch when there are not enough actions left
            env_ids = env_ids[:len(action)]
        spent_time = time.time() - spent_time
        Batch.cat(o)
        v.close()
        # assert at least a 1/7 improvement
        if sys.platform != "darwin":  # macOS cannot pass this check
            assert spent_time < 6.0 * sleep * num / (num + 1)

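# --- standalone sketch (not library code) ----------------------------------
# The timing bound in the test above relies on the identity
# E[max(u_1, ..., u_n)] = n / (n + 1) for i.i.d. u_i ~ U[0, 1] derived in the
# comment. A minimal NumPy check of that identity (toy values only):
import numpy as np

rng = np.random.default_rng(0)
n, trials = 8, 100_000
samples = rng.random((trials, n)).max(axis=1)
print(samples.mean(), n / (n + 1))  # both are close to 0.888...
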
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    model: str = "model",
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute action over the given batch data.

    If you need to mask the action, please add a "mask" into batch.obs, for
    example, if we have an environment that has "0/1/2" three actions:
    ::

        batch == Batch(
            obs=Batch(
                obs="original obs, with batch_size=1 for demonstration",
                mask=np.array([[False, True, False]]),
                # action 1 is available
                # action 0 and 2 are unavailable
            ),
            ...
        )

    :param float eps: in [0, 1], for epsilon-greedy exploration method.

    :return: A :class:`~tianshou.data.Batch` which has 3 keys:

        * ``act`` the action.
        * ``logits`` the network's raw output.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    model = getattr(self, model)
    obs = batch[input]
    obs_ = obs.obs if hasattr(obs, "obs") else obs
    q, h = model(obs_, state=state, info=batch.info)
    act: np.ndarray = to_numpy(q.max(dim=1)[1])
    if hasattr(obs, "mask"):
        # some actions are masked, they cannot be selected
        q_: np.ndarray = to_numpy(q)
        q_[~obs.mask] = -np.inf
        act = q_.argmax(axis=1)
    # add eps to act in training or testing phase
    if not self.updating and not np.isclose(self.eps, 0.0):
        for i in range(len(q)):
            if np.random.rand() < self.eps:
                q_ = np.random.rand(*q[i].shape)
                if hasattr(obs, "mask"):
                    q_[~obs.mask[i]] = -np.inf
                act[i] = q_.argmax()
    return Batch(logits=q, act=act, state=h)

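# --- standalone sketch (not library code) ----------------------------------
# The masking branch above boils down to overwriting the Q-values of
# unavailable actions with -inf before taking the argmax. A toy NumPy sketch
# of that step alone, not tied to any policy class:
import numpy as np

q = np.array([[1.0, 0.5, 2.0]])            # raw Q-values, batch_size = 1
mask = np.array([[False, True, False]])    # only action 1 is available
q_masked = q.copy()
q_masked[~mask] = -np.inf                  # unavailable actions can never win
print(q_masked.argmax(axis=1))             # [1]
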
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    weight = batch.pop("weight", 1.0)
    self.optim.zero_grad()
    q = self(batch, eps=0.).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns, q).flatten()
    # Huber loss; the squared-error alternative: c = lambda r, q: (r - q).pow(2)
    c = torch.nn.SmoothL1Loss(reduction='none')
    td = c(r, q)
    loss = (td * weight).mean()
    batch.weight = td  # per-sample TD error for the prioritized buffer
    loss.backward()
    if self.grad_norm_clipping:
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.grad_norm_clipping)
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}

def reset(self) -> None:
    """Reset all related variables in the collector."""
    # use an empty Batch for ``state`` so that ``self.data`` supports slicing;
    # convert the empty Batch to None when passing data to the policy
    self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                      obs_next={}, policy={})
    self.reset_env()
    self.reset_buffer()
    self.reset_stat()
    if self._action_noise is not None:
        self._action_noise.reset()

def forward(self, batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            **kwargs) -> Batch:
    logits, h = self.model(batch.obs, state=state, info=batch.info)
    if self.mode == 'discrete':
        a = logits.max(dim=1)[1]
    else:
        a = logits
    return Batch(logits=logits, act=a, state=h)

def __init__(self, size: int, stack_num: int = 1,
             ignore_obs_next: bool = False,
             save_only_last_obs: bool = False,
             sample_avail: bool = False) -> None:
    super().__init__()
    self._maxsize = size
    self._indices = np.arange(size)
    self._stack = None
    self.stack_num = stack_num
    self._avail = sample_avail and stack_num > 1
    self._avail_index = []
    self._save_s_ = not ignore_obs_next
    self._last_obs = save_only_last_obs
    self._index = 0
    self._size = 0
    self._meta = Batch()
    self.reset()

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._iter % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    act = to_torch(batch.act, dtype=torch.long, device=batch.returns.device)
    q = self(batch).logits
    act_mask = torch.zeros_like(q)
    act_mask = act_mask.scatter_(-1, act.unsqueeze(-1), 1)
    act_q = q * act_mask
    returns = batch.returns
    returns = returns * act_mask
    td_error = returns - act_q
    loss = (td_error.pow(2).sum(-1).mean(-1) * weight).mean()
    batch.weight = td_error.sum(-1).sum(-1)  # prio-buffer
    loss.backward()
    self.optim.step()
    self._iter += 1
    return {"loss": loss.item()}

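# --- standalone sketch (not library code) ----------------------------------
# The scatter_ call above builds a one-hot mask over the last action dimension
# so that only the chosen action's Q-value and return contribute to the TD
# error. A toy sketch of that step with made-up tensors:
import torch

q = torch.arange(6.0).reshape(2, 3)        # Q-values for 2 samples, 3 actions
act = torch.tensor([2, 0])                 # chosen action per sample
act_mask = torch.zeros_like(q).scatter_(-1, act.unsqueeze(-1), 1)
print(act_mask)
# tensor([[0., 0., 1.],
#         [1., 0., 0.]])
print((q * act_mask).sum(-1))              # Q-values of chosen actions: [2., 3.]
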
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indices: np.ndarray) -> Batch:
    """Pre-process the data from the provided replay buffer.

    Used in :meth:`update`. Check out :ref:`process_fn` for more information.
    """
    # update reward
    with torch.no_grad():
        batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten())
    return super().process_fn(batch, buffer, indices)

def reset(self, reset_buffer: bool = True) -> None:
    """Reset the environment, statistics, current data and possibly replay memory.

    :param bool reset_buffer: if true, reset the replay buffer that is
        attached to the collector.
    """
    # use an empty Batch for "state" so that self.data supports slicing;
    # convert the empty Batch to None when passing data to the policy
    self.data = Batch(obs={}, act={}, rew={}, done={}, obs_next={}, info={},
                      policy={})
    self.reset_env()
    if reset_buffer:
        self.reset_buffer()
    self.reset_stat()

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    weight = batch.pop("weight", 1.0)
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]
    r = to_torch_as(batch.returns.flatten(), q)
    td = r - q
    loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    loss.backward()
    # gradient clipping
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
    self.optim.step()
    self._cnt += 1
    for param_group in self.optim.param_groups:
        lr = param_group['lr']
    return {"loss": loss.item(), "lr": lr}

def add(  # type: ignore
    self,
    obs: Any,
    act: Any,
    rew: np.ndarray,
    done: np.ndarray,
    obs_next: Any = Batch(),
    info: Optional[Batch] = Batch(),
    policy: Optional[Batch] = Batch(),
    cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
    **kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray]:
    """Add a batch of data into CachedReplayBuffer.

    Each data item's length (first dimension) must equal the length of
    cached_buffer_ids. By default cached_buffer_ids is
    [0, 1, ..., cached_buffer_num - 1].

    Return the arrays of episode_length and episode_reward with shape
    (len(cached_buffer_ids), ...), where (episode_length[i],
    episode_reward[i]) refers to the episode result of the
    cached_buffer_ids[i]-th cached buffer.
    """
    if cached_buffer_ids is None:
        cached_buffer_ids = np.arange(self.cached_buffer_num)
    else:  # make sure it is np.ndarray
        cached_buffer_ids = np.asarray(cached_buffer_ids)
    # in self.buffers, the first buffer is the main_buffer
    buffer_ids = cached_buffer_ids + 1  # type: ignore
    result = super().add(obs, act, rew, done, obs_next, info,
                         policy, buffer_ids=buffer_ids, **kwargs)
    # find terminated episodes and move their data from cached buffers
    # to the main buffer
    for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]:
        self.main_buffer.update(self.cached_buffers[buffer_idx])
        self.cached_buffers[buffer_idx].reset()
    return result

def get(
    self,
    index: Union[int, List[int], np.ndarray],
    key: str,
    default_value: Any = None,
    stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]:
    """Return the stacked result.

    E.g., if you set ``key = "obs", stack_num = 4, index = t``, it returns
    the stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``.

    :param index: the index for getting stacked data.
    :param str key: the key to get, should be one of the reserved_keys.
    :param default_value: if the given key's data is not found and
        default_value is set, return this default_value.
    :param int stack_num: Default to self.stack_num.
    """
    if key not in self._meta and default_value is not None:
        return default_value
    val = self._meta[key]
    if stack_num is None:
        stack_num = self.stack_num
    try:
        if stack_num == 1:  # the most common case
            return val[index]
        stack: List[Any] = []
        if isinstance(index, list):
            indices = np.array(index)
        else:
            indices = index  # type: ignore
        for _ in range(stack_num):
            stack = [val[indices]] + stack
            indices = self.prev(indices)
        if isinstance(val, Batch):
            return Batch.stack(stack, axis=indices.ndim)
        else:
            return np.stack(stack, axis=indices.ndim)
    except IndexError as e:
        if not (isinstance(val, Batch) and val.is_empty()):
            raise e  # val != Batch()
        return Batch()

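# --- standalone sketch (not library code) ----------------------------------
# A hedged usage sketch of the frame stacking described in the docstring,
# assuming a ReplayBuffer(size, stack_num=...) constructor and the Batch-style
# add() call used elsewhere in this file; names and shapes are illustrative.
import numpy as np
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=20, stack_num=4)
for t in range(6):
    buf.add(Batch(obs=np.array([t]), act=0, rew=0.0, done=False))

# index 5 returns [obs[2], obs[3], obs[4], obs[5]] stacked along a new axis
stacked = buf.get(5, "obs")
print(stacked.shape)  # e.g. (4, 1): 4 stacked frames of a 1-dim observation
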
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    if self._rew_norm:
        mean, std = batch.rew.mean(), batch.rew.std()
        if std > self.__eps:
            batch.rew = (batch.rew - mean) / std
    if self._lambda in [0, 1]:
        return self.compute_episodic_return(
            batch, None, gamma=self._gamma, gae_lambda=self._lambda)
    v_ = []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False):
            v_.append(self.critic(b.obs_next))
    v_ = torch.cat(v_, dim=0).cpu().numpy()
    return self.compute_episodic_return(
        batch, v_, gamma=self._gamma, gae_lambda=self._lambda)

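# --- standalone sketch (not library code) ----------------------------------
# compute_episodic_return is based on generalized advantage estimation. This
# is a sketch of the standard GAE recursion (adv_t = delta_t + gamma * lambda
# * (1 - done_t) * adv_{t+1}) on toy arrays; it is not the library
# implementation, which also handles buffer indexing and batching.
import numpy as np

def gae_sketch(rew, done, v, v_next, gamma=0.99, lam=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    delta = rew + gamma * v_next * (1 - done) - v
    adv = np.zeros(len(rew))
    gae = 0.0
    for t in reversed(range(len(rew))):
        gae = delta[t] + gamma * lam * (1 - done[t]) * gae
        adv[t] = gae
    return adv, adv + v  # advantages and lambda-returns
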
def process_fn(
    self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
    v_s_ = []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            v_s_.append(to_numpy(self.critic(b.obs_next)))
    v_s_ = np.concatenate(v_s_, axis=0)
    if self._rew_norm:  # unnormalize v_s_
        v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
    unnormalized_returns, _ = self.compute_episodic_return(
        batch, buffer, indice, v_s_=v_s_,
        gamma=self._gamma, gae_lambda=self._lambda)
    if self._rew_norm:
        batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
            np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
    else:
        batch.returns = unnormalized_returns
    return batch

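# --- standalone sketch (not library code) ----------------------------------
# The two _rew_norm branches above undo and re-apply the same affine transform
# around the running return statistics. A toy round trip, with mean/var
# standing in for ret_rms.mean/ret_rms.var:
import numpy as np

eps = 1e-8
mean, var = 2.0, 4.0                       # running return statistics
normalized_v = np.array([-0.5, 0.0, 1.0])  # critic output in normalized units

# unnormalize before computing the GAE targets ...
v = normalized_v * np.sqrt(var + eps) + mean           # [1., 2., 4.]
# ... and normalize the resulting returns again before learning
returns = np.array([1.0, 3.0, 5.0])                    # toy unnormalized returns
normalized_returns = (returns - mean) / np.sqrt(var + eps)
print(v, normalized_returns)  # [1. 2. 4.] [-0.5  0.5  1.5]
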
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    weight = batch.pop('weight', 1.)
    current_q = self.critic(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td = current_q - target_q
    critic_loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()
    action = self(batch, explorating=False).act
    actor_loss = -self.critic(batch.obs, action).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    self.sync_weight()
    return {
        'loss/actor': actor_loss.item(),
        'loss/critic': critic_loss.item(),
    }

def forward(self, inputs, last_state=None, deterministic=False):
    obs = inputs.obs
    if not isinstance(obs, torch.Tensor):
        obs = torch.tensor(obs, device=self.device, dtype=torch.float32)
    value, actor_features, self.rnn_hxs = self.base(
        obs, self.rnn_hxs, self.masks)
    dist = self.dist(actor_features)
    if deterministic:
        action = dist.mode()
    else:
        action = dist.sample()
    return Batch(logits=dist.logits, act=action[0], state=None, dist=dist)

def test_batch():
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    batch.update(obs=[1])
    assert batch.obs == [1]
    batch.append(batch)
    assert batch.obs == [1, 1]
    assert batch.np.shape == (6, 4)
    assert batch[0].obs == batch[1].obs
    with pytest.raises(IndexError):
        batch[2]
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, permute=False)):
        assert b.obs == batch[i].obs

def __call__(self, batch, state=None, model='actor'):
    model = getattr(self, model)
    logits, h = model(batch.obs, state=state, info=batch.info)
    if isinstance(logits, tuple):
        dist = self.dist_fn(*logits)
    else:
        dist = self.dist_fn(logits)
    act = dist.sample()
    if self._range:
        act = act.clamp(self._range[0], self._range[1])
    return Batch(logits=logits, act=act, state=h, dist=dist)

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
    weight = batch.pop("weight", 1.0)
    current_q = self.critic(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td = current_q - target_q
    critic_loss = (td.pow(2) * weight).mean()
    batch.weight = td  # prio-buffer
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()
    action = self(batch).act
    actor_loss = -self.critic(batch.obs, action).mean()
    self.actor_optim.zero_grad()
    actor_loss.backward()
    self.actor_optim.step()
    self.sync_weight()
    return {
        "loss/actor": actor_loss.item(),
        "loss/critic": critic_loss.item(),
    }

def __init__(
    self,
    policy: BasePolicy,
    env: Union[gym.Env, BaseVectorEnv],
    buffer: Optional[ReplayBuffer] = None,
    preprocess_fn: Optional[Callable[[Any], Union[dict, Batch]]] = None,
    action_noise: Optional[BaseNoise] = None,
    reward_metric: Optional[Callable[[np.ndarray], float]] = None,
) -> None:
    super().__init__()
    if not isinstance(env, BaseVectorEnv):
        env = DummyVectorEnv([lambda: env])
    self.env = env
    self.env_num = len(env)
    # environments that are available in step();
    # this means all environments in synchronous simulation,
    # but only a subset of environments in asynchronous simulation
    self._ready_env_ids = np.arange(self.env_num)
    # self.is_async is a flag indicating whether this collector works
    # with asynchronous simulation
    self.is_async = env.is_async
    # need cache buffers before storing in the main buffer
    self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
    self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
    self.buffer = buffer
    self.policy = policy
    self.preprocess_fn = preprocess_fn
    self.process_fn = policy.process_fn
    self._action_space = env.action_space
    self._action_noise = action_noise
    self._rew_metric = reward_metric or Collector._default_rew_metric
    # avoid creating attributes outside __init__
    self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                      obs_next={}, policy={})
    self.reset()

def test_fn(size=2560):
    policy = PGPolicy(None, None, None, discount_factor=0.1)
    buf = ReplayBuffer(100)
    buf.add(1, 1, 1, 1, 1)
    fn = policy.process_fn
    # fn = compute_return_base
    batch = Batch(
        done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
        rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
    )
    batch = fn(batch, buf, 0)
    ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 1, 0.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, buf, 0)
    ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 0, 1.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, buf, 0)
    ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
    assert abs(batch.returns - ans).sum() <= 1e-5
    if __name__ == '__main__':
        batch = Batch(
            done=np.random.randint(100, size=size) == 0,
            rew=np.random.random(size),
        )
        cnt = 3000
        t = time.time()
        for _ in range(cnt):
            compute_return_base(batch)
        print(f'vanilla: {(time.time() - t) / cnt}')
        t = time.time()
        for _ in range(cnt):
            policy.process_fn(batch, buf, 0)
        print(f'policy: {(time.time() - t) / cnt}')

def process_fn(self, batch: Batch, buffer: ReplayBuffer,
               indice: np.ndarray) -> Batch:
    r"""Compute the n-step return for Q-learning targets:

    .. math::
        G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
        \gamma^n (1 - d_{t + n}) \max_a Q_{old}(s_{t + n},
        \arg\max_a (Q_{new}(s_{t + n}, a))),

    where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
    and :math:`d_t` is the done flag of step :math:`t`. If there is no target
    network, :math:`Q_{old}` is equal to :math:`Q_{new}`.
    """
    batch = self.compute_nstep_return(
        batch, buffer, indice, self._target_q, self._gamma, self._n_step)
    if isinstance(buffer, PrioritizedReplayBuffer):
        batch.update_weight = buffer.update_weight
        batch.indice = indice
    return batch

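# --- standalone sketch (not library code) ----------------------------------
# A literal transcription of the n-step formula above for a single index t,
# valid only when the window [t, t + n] stays inside one episode; it is not
# compute_nstep_return, which also handles buffer indexing and episode
# boundaries. rew, done, q_target are toy arrays, and q_target[t] stands in
# for max_a Q_old(s_t, argmax_a Q_new(s_t, a)).
import numpy as np

def nstep_target(rew, done, q_target, t, gamma=0.9, n_step=3):
    g = sum(gamma ** (i - t) * (1 - done[i]) * rew[i]
            for i in range(t, t + n_step))
    # bootstrap only if step t + n is not terminal
    return g + gamma ** n_step * (1 - done[t + n_step]) * q_target[t + n_step]

rew = np.array([1.0, 1.0, 1.0, 1.0, 1.0])
done = np.array([0, 0, 0, 0, 1])
q_target = np.array([10.0, 10.0, 10.0, 10.0, 0.0])
print(nstep_target(rew, done, q_target, t=0))  # 1 + 0.9 + 0.81 + 0.729 * 10 = 10.0
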
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    obs = batch[input]
    logits, h = self.actor(obs, state=state, info=batch.info)
    dist = Categorical(logits=logits)
    act = dist.sample()
    return Batch(logits=logits, act=act, state=h, dist=dist)

def test_batch_from_to_numpy_without_copy():
    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    a_mem_addr_orig = batch.a.__array_interface__['data'][0]
    c_mem_addr_orig = batch.b.c.__array_interface__['data'][0]
    batch.to_torch()
    batch.to_numpy()
    a_mem_addr_new = batch.a.__array_interface__['data'][0]
    c_mem_addr_new = batch.b.c.__array_interface__['data'][0]
    assert a_mem_addr_new == a_mem_addr_orig
    assert c_mem_addr_new == c_mem_addr_orig

def learn(  # type: ignore
    self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
    losses, clip_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size, merge_last=True):
            # calculate loss for actor
            dist = self(b).dist
            ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
            ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
            surr1 = ratio * b.adv
            surr2 = ratio.clamp(1.0 - self._eps_clip,
                                1.0 + self._eps_clip) * b.adv
            if self._dual_clip:
                clip_loss = -torch.max(torch.min(surr1, surr2),
                                       self._dual_clip * b.adv).mean()
            else:
                clip_loss = -torch.min(surr1, surr2).mean()
            # calculate loss for critic
            value = self.critic(b.obs).flatten()
            if self._value_clip:
                v_clip = b.v_s + (value - b.v_s).clamp(
                    -self._eps_clip, self._eps_clip)
                vf1 = (b.returns - value).pow(2)
                vf2 = (b.returns - v_clip).pow(2)
                vf_loss = 0.5 * torch.max(vf1, vf2).mean()
            else:
                vf_loss = 0.5 * (b.returns - value).pow(2).mean()
            # calculate regularization and overall loss
            ent_loss = dist.entropy().mean()
            loss = clip_loss + self._weight_vf * vf_loss \
                - self._weight_ent * ent_loss
            self.optim.zero_grad()
            loss.backward()
            if self._grad_norm is not None:  # clip large gradient
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    max_norm=self._grad_norm)
            self.optim.step()
            clip_losses.append(clip_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    # update learning rate if lr_scheduler is given
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return {
        "loss": losses,
        "loss/clip": clip_losses,
        "loss/vf": vf_losses,
        "loss/ent": ent_losses,
    }

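# --- standalone sketch (not library code) ----------------------------------
# A minimal sketch of the clipped surrogate term used above, on toy tensors.
# ratio plays the role of exp(log_prob - logp_old) and adv the advantage;
# both are made-up values, not taken from any batch.
import torch

eps_clip = 0.2
ratio = torch.tensor([0.5, 0.7, 1.5])
adv = torch.tensor([1.0, -1.0, 2.0])

surr1 = ratio * adv
surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
clip_loss = -torch.min(surr1, surr2).mean()
# per-sample objectives: min(0.5, 0.8) = 0.5, min(-0.7, -0.8) = -0.8,
# min(3.0, 2.4) = 2.4, so clip_loss = -(0.5 - 0.8 + 2.4) / 3
print(clip_loss)  # tensor(-0.7000)
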
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    **kwargs: Any,
) -> Batch:
    logits, hidden = self.model(batch.obs, state=state, info=batch.info)
    if self.action_type == "discrete":
        act = logits.max(dim=1)[1]
    else:
        act = logits
    return Batch(logits=logits, act=act, state=hidden)

def __getitem__(
    self, index: Union[slice, int, np.integer, np.ndarray]
) -> Batch:
    return Batch(
        obs=self.get(index, 'obs'),
        act=self.act[index],
        rew=self.rew[index],
        done=self.done[index],
        obs_next=self.get(index, 'obs_next'),
        info=self.get(index, 'info'),
        policy=self.get(index, 'policy'),
        weight=self.weight[index],
    )

def add(
    self,
    obs: Any,
    act: Any,
    rew: Union[Number, np.number, np.ndarray],
    done: Union[Number, np.number, np.bool_],
    obs_next: Any = None,
    info: Optional[Union[dict, Batch]] = {},
    policy: Optional[Union[dict, Batch]] = {},
    **kwargs: Any,
) -> Tuple[int, Union[float, np.ndarray]]:
    """Add a batch of data into the replay buffer.

    Return (episode_length, episode_reward) if one episode is terminated,
    otherwise return (0, 0.0).
    """
    assert isinstance(info, (dict, Batch)), \
        "You should return a dict in the last argument of env.step()."
    if self._save_only_last_obs:
        obs = obs[-1]
    self._add_to_buffer("obs", obs)
    self._add_to_buffer("act", act)
    # make sure the data type of reward is float instead of int,
    # but rew may be np.ndarray, so we cannot simply use float(rew)
    rew = rew * 1.0  # type: ignore
    self._add_to_buffer("rew", rew)
    self._add_to_buffer("done", bool(done))  # done should be a bool scalar
    if self._save_obs_next:
        if obs_next is None:
            obs_next = Batch()
        elif self._save_only_last_obs:
            obs_next = obs_next[-1]
        self._add_to_buffer("obs_next", obs_next)
    self._add_to_buffer("info", info)
    self._add_to_buffer("policy", policy)

    if self.maxsize > 0:
        self._size = min(self._size + 1, self.maxsize)
        self._index = (self._index + 1) % self.maxsize
    else:  # TODO: remove this after deleting ListReplayBuffer
        self._size = self._index = self._size + 1

    self._episode_reward += rew
    self._episode_length += 1

    if done:
        result = self._episode_length, self._episode_reward
        self._episode_length, self._episode_reward = 0, 0.0
        return result
    else:
        return 0, self._episode_reward * 0.0

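# --- standalone sketch (not library code) ----------------------------------
# A tiny illustration of the circular index update used above: _index wraps
# around once maxsize is reached, while _size saturates at maxsize. Toy
# values, no buffer involved.
maxsize = 4
index, size = 0, 0
for step in range(6):
    size = min(size + 1, maxsize)
    index = (index + 1) % maxsize
    print(step, index, size)
# after 6 insertions: size stays at 4, index has wrapped around to 2,
# i.e. the two oldest entries have been overwritten
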
def test_nstep_returns(size=10000):
    buf = ReplayBuffer(10)
    for i in range(12):
        buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3))
    batch, indices = buf.sample(0)
    assert np.allclose(indices, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0,  0]
    # test nstep = 1
    returns = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn, gamma=.1, n_step=1
        ).pop('returns').reshape(-1)
    )
    assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
    r_ = compute_nstep_return_base(1, .1, buf, indices)
    assert np.allclose(returns, r_), (r_, returns)
    returns_multidim = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1
        ).pop('returns')
    )
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 2
    returns = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn, gamma=.1, n_step=2
        ).pop('returns').reshape(-1)
    )
    assert np.allclose(
        returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
    r_ = compute_nstep_return_base(2, .1, buf, indices)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2
        ).pop('returns')
    )
    assert np.allclose(returns_multidim, returns[:, np.newaxis])
    # test nstep = 10
    returns = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn, gamma=.1, n_step=10
        ).pop('returns').reshape(-1)
    )
    assert np.allclose(
        returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
    r_ = compute_nstep_return_base(10, .1, buf, indices)
    assert np.allclose(returns, r_)
    returns_multidim = to_numpy(
        BasePolicy.compute_nstep_return(
            batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10
        ).pop('returns')
    )
    assert np.allclose(returns_multidim, returns[:, np.newaxis])