def get(
    self,
    index: Union[int, np.integer, np.ndarray],
    key: str,
    stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]:
    """Return the frame-stacked value stored under ``key`` at ``index``.

    E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
    index.

    :param index: position(s) in the buffer to read.
    :param key: name of the entry of ``self._meta`` to read from.
    :param stack_num: number of frames to stack; falls back to
        ``self.stack_num`` when ``None``.
    :return: the stacked data, or an empty ``Batch`` when the stored value
        is an empty ``Batch`` and indexing it raised ``IndexError``.
    """
    num_frames = self.stack_num if stack_num is None else stack_num
    val = self._meta[key]
    try:
        if num_frames == 1:  # the most often case
            return val[index]
        cursor = np.asarray(index)
        frames: List[Any] = []
        # Walk backwards in time, collecting newest-first ...
        for _ in range(num_frames):
            frames.append(val[cursor])
            cursor = self.prev(cursor)
        # ... then flip so that the oldest frame comes first.
        frames.reverse()
        stack_fn = Batch.stack if isinstance(val, Batch) else np.stack
        return stack_fn(frames, axis=cursor.ndim)
    except IndexError as e:
        # Only an empty Batch is allowed to fail silently.
        if isinstance(val, Batch) and val.is_empty():
            return Batch()
        raise e  # val != Batch()
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    **kwargs: Any,
) -> Batch:
    """Compute the random action over the given batch data.

    The input should contain a mask in batch.obs, with "True" to be
    available and "False" to be unavailable. For example,
    ``batch.obs.mask == np.array([[False, True, False]])`` means with batch
    size 1, action "1" is available but action "0" and "2" are unavailable.

    :return: A :class:`~tianshou.data.Batch` with "act" key, containing
        the random action.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    available = batch.obs.mask
    # One random score per action; the best-scoring available one wins.
    scores = np.random.rand(*available.shape)
    # Unavailable actions get -inf so argmax can never select them.
    scores[~available] = -np.inf
    chosen = scores.argmax(axis=-1)
    return Batch(act=chosen)
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    **kwargs: Any,
) -> Batch:
    """Compute action over the given batch data.

    :return: A :class:`~tianshou.data.Batch` which has 4 keys:

        * ``act`` the action.
        * ``logits`` the network's raw output.
        * ``dist`` the action distribution.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    logits, hidden = self.actor(batch.obs, state=state)
    # A tuple of logits is unpacked into the distribution's parameters.
    dist = (
        self.dist_fn(*logits) if isinstance(logits, tuple)
        else self.dist_fn(logits)
    )
    greedy = self._deterministic_eval and not self.training
    if not greedy:
        act = dist.sample()
    elif self.action_type == "discrete":
        act = logits.argmax(-1)
    elif self.action_type == "continuous":
        # First element of the logits tuple is used as the action.
        act = logits[0]
    return Batch(logits=logits, act=act, state=hidden, dist=dist)
def rollout_render(self, env, T, high_len):
    """
    Roll out the hierarchical agent on env and render every step.

    The high-level actor picks an option, which is then executed by the
    matching low-level policy for up to high_len environment steps. The
    rollout stops once the episode is done or the step counter reaches T
    (checked only between high-level decisions).

    Returns a np array of the rendered images, with the last axis of each
    rendered frame moved to position 1 (time, channels, then the two
    remaining frame axes).
    This does not store information in memory.
    """
    obs = self.rms.filter(env.reset())
    frames = [env.render(mode="rgb_array")]
    steps = 0
    finished = False
    while steps < T and not finished:
        high_action, _, _ = self.high.actor.action(obs)
        for _ in range(high_len):
            if self.disc:
                # Discrete low-level policies consume a Batch of shape (1, -1).
                obs = np.reshape(obs, (1, -1))
                low_action = self.low[high_action].act(
                    Batch(obs=obs, info={}), str(high_action), False)
            else:
                low_action, _, _ = self.low[high_action].actor.action(obs)
            obs, _, finished, _ = env.step(low_action)
            obs = self.rms.filter(obs)
            frames.append(env.render(mode="rgb_array"))
            steps += 1
            if finished:
                break
    stacked = np.array(frames)
    return np.transpose(stacked, (0, 3, 1, 2))
def sample(self, batch_size: int) -> Batch:
    """Sample a data batch from the internal replay buffer.

    It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
    returning the final batch data.

    :param int batch_size: ``0`` means it will extract all the data from
        the buffer, otherwise it will extract the data with the given
        batch_size.
    """
    if not self._multi_buf:
        batch_data, indice = self.buffer.sample(batch_size)
        return self.process_fn(batch_data, self.buffer, indice)

    if batch_size > 0:
        # Choose a source buffer for every sample, weighted by buffer length.
        lens = [len(buf) for buf in self.buffer]
        total = sum(lens)
        batch_index = np.random.choice(
            len(self.buffer), batch_size, p=np.array(lens) / total)
    else:
        batch_index = np.array([])

    batch_data = Batch()
    for buf_id, buf in enumerate(self.buffer):
        quota = (batch_index == buf_id).sum()
        # A non-positive batch_size reads every buffer (with quota 0);
        # otherwise buffers that were not drawn are skipped.
        if batch_size <= 0 or quota:
            batch, indice = buf.sample(quota)
            batch = self.process_fn(batch, buf, indice)
            batch_data.append(batch)
    return batch_data
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Sample a tanh-squashed Gaussian action and its log-probability.

    :param batch: data holding the observation under ``input`` and ``info``.
    :param state: hidden state forwarded to the actor.
    :param input: key of ``batch`` that is fed to the actor.
    :return: A Batch with ``logits``, ``act``, ``state``, ``dist`` and
        ``log_prob``.
    """
    obs = batch[input]
    logits, hidden = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    pre_tanh = dist.rsample()
    squashed = torch.tanh(pre_tanh)
    act = squashed * self._action_scale + self._action_bias
    # Change-of-variables term of the squashing; __eps keeps the log finite.
    jacobian = self._action_scale * (1 - squashed.pow(2)) + self.__eps
    gaussian_log_prob = dist.log_prob(pre_tanh).unsqueeze(-1)
    log_prob = gaussian_log_prob - torch.log(jacobian).sum(-1, keepdim=True)
    if self._noise is not None and not self.updating:
        act += to_torch_as(self._noise(act.shape), act)
    # Clamp every action dimension to its own (low, high) bounds; this
    # stands in for a single clamp with self._range[0] / self._range[1].
    for dim, (low, high) in enumerate(zip(self._range[0], self._range[1])):
        act[:, dim] = act.clone()[:, dim].clamp(low, high)
    return Batch(
        logits=logits, act=act, state=hidden, dist=dist, log_prob=log_prob)
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute a tanh-squashed action and its log-probability.

    :param batch: data holding the observation under ``input`` and ``info``.
    :param state: hidden state forwarded to the actor.
    :param input: key of ``batch`` that is fed to the actor.
    :return: A Batch with ``logits``, ``act`` (the squashed action),
        ``state``, ``dist`` and ``log_prob``.
    """
    obs = batch[input]
    logits, hidden = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    use_mean = self._deterministic_eval and not self.training
    # In deterministic evaluation the first logits element is the action.
    act = logits[0] if use_mean else dist.rsample()
    gaussian_log_prob = dist.log_prob(act).unsqueeze(-1)
    # apply correction for Tanh squashing when computing logprob from Gaussian
    # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
    # in appendix C to get some understanding of this equation.
    if self.action_scaling and self.action_space is not None:
        half_range = (self.action_space.high - self.action_space.low) / 2.0
        action_scale = to_torch_as(half_range, act)
    else:
        action_scale = 1.0  # type: ignore
    squashed_action = torch.tanh(act)
    correction = torch.log(
        action_scale * (1 - squashed_action.pow(2)) + self.__eps
    ).sum(-1, keepdim=True)
    log_prob = gaussian_log_prob - correction
    return Batch(
        logits=logits,
        act=squashed_action,
        state=hidden,
        dist=dist,
        log_prob=log_prob,
    )
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute a scaled, tanh-squashed action and its log-probability.

    :param batch: data holding the observation under ``input`` and ``info``.
    :param state: hidden state forwarded to the actor.
    :param input: key of ``batch`` that is fed to the actor.
    :return: A Batch with ``logits``, ``act``, ``state``, ``dist`` and
        ``log_prob``.
    """
    obs = batch[input]
    logits, hidden = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = Independent(Normal(*logits), 1)
    use_mean = self._deterministic_eval and not self.training
    # In deterministic evaluation the first logits element replaces a sample.
    pre_tanh = logits[0] if use_mean else dist.rsample()
    squashed = torch.tanh(pre_tanh)
    act = squashed * self._action_scale + self._action_bias
    # __eps is used to avoid log of zero/negative number.
    jacobian = self._action_scale * (1 - squashed.pow(2)) + self.__eps
    # Compute logprob from Gaussian, and then apply correction for Tanh
    # squashing. You can check out the original SAC paper
    # (arXiv 1801.01290): Eq 21. in appendix C to get some understanding
    # of this equation.
    gaussian_log_prob = dist.log_prob(pre_tanh).unsqueeze(-1)
    log_prob = gaussian_log_prob - torch.log(jacobian).sum(-1, keepdim=True)
    return Batch(
        logits=logits, act=act, state=hidden, dist=dist, log_prob=log_prob)
def preprocess_fn(self, **kwargs):
    """Modify info before it is added into the buffer, and log to tfb.

    If only obs exists the call is a reset and an empty Batch is returned.
    If obs/act/rew/done/... exist it is a normal step: every info entry
    gets the matching reward, the mean of ``info['key']`` (when present) is
    written to ``self.writer``, and ``self.cnt`` is incremented.
    """
    if 'rew' not in kwargs:
        return Batch()

    num_envs = len(kwargs['obs'])
    info = kwargs['info']
    for env_idx in range(num_envs):
        info[env_idx].update(rew=kwargs['rew'][env_idx])
    if 'key' in info.keys():
        mean_key = np.mean(info['key'])
        self.writer.add_scalar('key', mean_key, global_step=self.cnt)
    self.cnt += 1
    return Batch(info=info)
def _make_tianshou_batch(
    o_t: tf.Tensor,
    a_t: tf.Tensor,
    r_t: tf.Tensor,
    d_t: tf.Tensor,
    o_tp1: tf.Tensor,
    a_tp1: tf.Tensor,
) -> Batch:
    """Create Tianshou batch with offline data.

    Args:
        o_t: Observation at time t.
        a_t: Action at time t.
        r_t: Reward at time t.
        d_t: Discount at time t.
        o_tp1: Observation at time t+1.
        a_tp1: Action at time t+1 (accepted but not stored in the batch).

    Returns:
        A tianshou.data.Batch object with obs, act, rew, done and obs_next,
        where done is computed as ``1 - d_t``.
    """
    fields = {
        "obs": o_t.numpy(),
        "act": a_t.numpy(),
        "rew": r_t.numpy(),
        # The done flag is stored as the complement of the discount.
        "done": 1 - d_t.numpy(),
        "obs_next": o_tp1.numpy(),
    }
    return Batch(**fields)
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    **kwargs: Any,
) -> Batch:
    """Compute a stochastic action over the given batch data.

    The policy network output parameterizes the action distribution built
    by ``self.dist_fn``; the action is sampled from that distribution.

    :param batch: data holding ``obs`` and ``info``.
    :param state: hidden state forwarded to the policy network.
    :return: A :class:`~tianshou.data.Batch` which has 4 keys:

        * ``act`` the action.
        * ``logits`` the network's raw output.
        * ``dist`` the action distribution.
        * ``state`` the hidden state.
    """
    # NOTE: the annotation used to read ``np.adarray``, which does not exist
    # and raised AttributeError as soon as this function was defined.
    logits, h = self.policy_net(batch.obs, state=state, info=batch.info)
    # A tuple of logits is unpacked into the distribution's parameters.
    if isinstance(logits, tuple):
        dist = self.dist_fn(*logits)
    else:
        dist = self.dist_fn(logits)
    act = dist.sample()
    return Batch(logits=logits, act=act, state=h, dist=dist)
def forward(  # type: ignore
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    input: str = "obs",
    eps: Optional[float] = None,
    **kwargs: Any,
) -> Batch:
    """Pick the best-valued action among those the imitator allows.

    An action is allowed when its imitation logit, relative to the largest
    imitation logit of the same sample, is not below ``self._log_tau``.
    With probability ``eps`` the chosen action is replaced by a uniformly
    drawn action index.

    :param batch: data holding the observation under ``input`` and ``info``.
    :param state: hidden state forwarded to the model; the state returned
        by the model is then forwarded to the imitator.
    :param input: key of ``batch`` that is fed to the networks.
    :param eps: exploration rate; falls back to ``self._eps`` when ``None``.
    :return: A Batch with ``act``, ``state``, ``q_value`` and
        ``imitation_logits``.
    """
    if eps is None:
        eps = self._eps
    obs = batch[input]
    q_value, state = self.model(obs, state=state, info=batch.info)
    imitation_logits, _ = self.imitator(obs, state=state, info=batch.info)

    # mask actions for argmax
    ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
    disallowed = ratio < self._log_tau
    # Fill disallowed entries with -inf directly. The former
    # ``q_value - np.inf * mask`` evaluated ``inf * 0 = NaN`` for every
    # allowed action, so the argmax was no longer driven by the Q-values.
    # assumes q_value and imitation_logits share one shape — TODO confirm
    action = q_value.masked_fill(disallowed, -np.inf).argmax(dim=-1)

    # add eps to act
    if not np.isclose(eps, 0.0):
        bsz, action_num = q_value.shape
        explore = np.random.rand(bsz) < eps
        action_rand = torch.randint(
            action_num, size=[bsz], device=action.device)
        action[explore] = action_rand[explore]

    return Batch(
        act=action,
        state=state,
        q_value=q_value,
        imitation_logits=imitation_logits,
    )
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    **kwargs: Any,
) -> Batch:
    """Compute action over the given batch data.

    For each observation, ``self.forward_sampled_times`` candidate actions
    are generated (VAE decode followed by the actor) and the candidate with
    the highest ``critic1`` value is kept.

    :return: A Batch whose ``act`` is a numpy array with one flattened
        action per observation.
    """
    # There is "obs" in the Batch
    # obs_group: several groups. Each group has a state.
    obs_group: torch.Tensor = to_torch(batch.obs, device=self.device)
    best_actions = []
    for single_obs in obs_group:
        # single_obs is (state_dim); tile it to
        # (forward_sampled_times, state_dim)
        tiled = single_obs.reshape(1, -1).repeat(self.forward_sampled_times, 1)
        # decode(obs) generates action and actor perturbs it;
        # candidates is (forward_sampled_times, action_dim)
        candidates = self.actor(tiled, self.vae.decode(tiled))
        # scores is (forward_sampled_times, 1)
        scores = self.critic1(tiled, candidates)
        best = scores.argmax(0)
        best_actions.append(candidates[best].cpu().data.numpy().flatten())
    return Batch(act=np.array(best_actions))
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    model: str = "model",
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute the greedy action from sampled quantile outputs.

    The number of samples requested from the network depends on the caller:
    the target sample size for ``"model_old"``, the online sample size while
    training, and the default sample size otherwise.

    :param batch: data holding the observation under ``input`` and ``info``.
    :param state: hidden state forwarded to the network.
    :param model: name of the attribute of ``self`` holding the network.
    :param input: key of ``batch`` that is fed to the network.
    :return: A Batch with ``logits``, ``act``, ``state`` and ``taus``.
    """
    if model == "model_old":
        sample_size = self._target_sample_size
    else:
        sample_size = (
            self._online_sample_size if self.training else self._sample_size
        )
    network = getattr(self, model)
    obs = batch[input]
    # The observation may be wrapped together with a mask; unwrap it if so.
    net_input = obs.obs if hasattr(obs, "obs") else obs
    (logits, taus), hidden = network(
        net_input, sample_size=sample_size, state=state, info=batch.info
    )
    q = self.compute_q_value(logits, getattr(obs, "mask", None))
    # Remember the action count the first time it is seen.
    if not hasattr(self, "max_action_num"):
        self.max_action_num = q.shape[1]
    _, greedy = q.max(dim=1)
    return Batch(logits=logits, act=to_numpy(greedy), state=hidden, taus=taus)
def preprocess_fn(obs=None, act=None, rew=None, done=None, obs_next=None,
                  info=None, policy=None):
    """Resize the frames of ``obs_next`` (or, failing that, ``obs``).

    Only one of the two is processed: ``obs_next`` when it is given,
    otherwise ``obs`` when it is given. All other arguments are passed
    through unchanged into the returned Batch; ``policy`` is not used.
    """

    def _resize_frames(frames):
        # Merge the two leading axes and move the result last for cv2.resize.
        merged = np.reshape(frames, (-1, *frames.shape[2:]))
        merged = np.moveaxis(merged, 0, -1)
        resized = cv2.resize(merged, (SIZE, SIZE))
        resized = np.asanyarray(resized, dtype=np.uint8)
        # Regroup into (-1, FRAME, SIZE, SIZE), then move axis 1 last.
        grouped = np.reshape(resized, (-1, FRAME, SIZE, SIZE))
        return np.moveaxis(grouped, 1, -1)

    if obs_next is not None:
        obs_next = _resize_frames(obs_next)
    elif obs is not None:
        obs = _resize_frames(obs)
    return Batch(obs=obs, act=act, rew=rew, done=done,
                 obs_next=obs_next, info=info)
def forward(self, batch, state=None):
    """Run the model on ``batch.obs`` and derive the action.

    In ``'discrete'`` mode the action is the index of the largest logit
    along dim 1; in any other mode the logits themselves are the action.

    :return: A Batch with ``logits``, ``act`` and ``state``.
    """
    logits, hidden = self.model(batch.obs, state=state)
    if self.mode != 'discrete':
        return Batch(logits=logits, act=logits, state=hidden)
    _, best = logits.max(dim=1)
    return Batch(logits=logits, act=best, state=hidden)
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch, np.ndarray]] = None,
    model: str = "actor",
    input: str = "obs",
    **kwargs: Any,
) -> Batch:
    """Compute action over the given batch data.

    :return: A :class:`~tianshou.data.Batch` which has 2 keys:

        * ``act`` the action.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    network = getattr(self, model)
    obs = batch[input]
    actions, hidden = network(obs, state=state, info=batch.info)
    # Exploration noise is added before scaling, so it is scaled as well.
    if self._noise and not self.updating:
        actions += to_torch_as(self._noise(actions.shape), actions)
    actions *= self._action_scale
    actions += self._action_bias
    # Clamp every action dimension to its own (low, high) bounds; this
    # stands in for a single clamp with self._range[0] / self._range[1].
    lows, highs = self._range[0], self._range[1]
    for dim, (low, high) in enumerate(zip(lows, highs)):
        actions[:, dim] = actions.clone()[:, dim].clamp(low, high)
    return Batch(act=actions, state=hidden)
def test_utils_to_torch():
    """Check that ``to_torch`` honours ``dtype`` for arrays and nested Batch."""
    nested = Batch(
        c=np.ones((1,), dtype=np.float64),
        d=torch.ones((1,), dtype=torch.float64),
    )
    batch = Batch(a=np.ones((1,), dtype=np.float64), b=nested)

    # A single numpy array is converted to the requested dtype.
    for wanted in (torch.float32, torch.float64):
        assert to_torch(batch.a, dtype=wanted).dtype == wanted

    # The requested dtype reaches every leaf of a nested Batch, including
    # leaves that already were torch tensors.
    batch_torch_float = to_torch(batch, dtype=torch.float32)
    assert batch_torch_float.a.dtype == torch.float32
    assert batch_torch_float.b.c.dtype == torch.float32
    assert batch_torch_float.b.d.dtype == torch.float32
def forward(self, batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            model: str = 'actor',
            input: str = 'obs',
            explorating: bool = True,
            **kwargs) -> Batch:
    """Compute action over the given batch data.

    :param bool explorating: whether exploration noise may be added; noise
        is only added while ``self.training`` is true as well.

    :return: A :class:`~tianshou.data.Batch` which has 2 keys:

        * ``act`` the action.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    network = getattr(self, model)
    obs = getattr(batch, input)
    actions, hidden = network(obs, state=state, info=batch.info)
    actions += self._action_bias
    add_noise = self.training and explorating
    if add_noise:
        actions += to_torch_as(self._noise(actions.shape), actions)
    low, high = self._range[0], self._range[1]
    return Batch(act=actions.clamp(low, high), state=hidden)
def __init__(
    self,
    size: int,
    stack_num: int = 1,
    ignore_obs_next: bool = False,
    save_only_last_obs: bool = False,
    sample_avail: bool = False,
    **kwargs: Any,  # otherwise PrioritizedVectorReplayBuffer will cause TypeError
) -> None:
    """Set up an empty buffer holding up to ``size`` entries.

    :param size: maximum size of the buffer, stored as ``self.maxsize``.
    :param stack_num: number of stacked frames; must be greater than 0.
    :param ignore_obs_next: when true, ``obs_next`` is not saved.
    :param save_only_last_obs: stored as ``self._save_only_last_obs``.
    :param sample_avail: stored as ``self._sample_avail``.
    :param kwargs: accepted and ignored.
    """
    # Keep the construction options together in one dict.
    self.options: Dict[str, Any] = dict(
        stack_num=stack_num,
        ignore_obs_next=ignore_obs_next,
        save_only_last_obs=save_only_last_obs,
        sample_avail=sample_avail,
    )
    super().__init__()
    self.maxsize = size
    assert stack_num > 0, "stack_num should be greater than 0"
    self.stack_num = stack_num
    self._indices = np.arange(size)
    self._save_obs_next = not ignore_obs_next
    self._save_only_last_obs = save_only_last_obs
    self._sample_avail = sample_avail
    self._meta: Batch = Batch()
    self.reset()
def forward(self, batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            model: str = 'model',
            input: str = 'obs',
            eps: Optional[float] = None,
            **kwargs) -> Batch:
    """Compute action over the given batch data.

    :param float eps: in [0, 1], for epsilon-greedy exploration method;
        falls back to ``self.eps`` when ``None``.

    :return: A :class:`~tianshou.data.Batch` which has 3 keys:

        * ``act`` the action.
        * ``logits`` the network's raw output.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    network = getattr(self, model)
    obs = getattr(batch, input)
    q, hidden = network(obs, state=state, info=batch.info)
    _, greedy = q.max(dim=1)
    act = greedy.detach().cpu().numpy()
    # add eps to act
    if eps is None:
        eps = self.eps
    if not np.isclose(eps, 0):
        num_actions = q.shape[1]
        # One draw per sample decides whether it gets a random action.
        for row in range(len(q)):
            if np.random.rand() < eps:
                act[row] = np.random.randint(num_actions)
    return Batch(logits=q, act=act, state=hidden)
def test_hdf5():
    """Round-trip replay buffers through HDF5 and check unsupported values.

    Every buffer is saved to a temporary ``.hdf5`` file, loaded back and
    compared with the original. The temporary files are removed afterwards,
    whether or not the comparisons succeed.
    """
    size = 100
    buffers = {
        "array": ReplayBuffer(size, stack_num=2),
        "list": ListReplayBuffer(),
        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    info_t = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': np.array([1, 2]),
            'done': i % 3 == 2,
            'info': {"number": {"n": i, "t": info_t}, 'extra': None},
        }
        buffers["array"].add(**kwargs)
        buffers["list"].add(**kwargs)
        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)

    paths = {}
    try:
        # save
        for k, buf in buffers.items():
            f, path = tempfile.mkstemp(suffix='.hdf5')
            os.close(f)
            # Register the path before saving so that the cleanup below also
            # covers a save that fails half-way.
            paths[k] = path
            buf.save_hdf5(path)

        # load replay buffer
        _buffers = {
            k: buffer_types[k].load_hdf5(path) for k, path in paths.items()
        }

        # compare
        for k in buffers.keys():
            assert len(_buffers[k]) == len(buffers[k])
            assert np.allclose(_buffers[k].act, buffers[k].act)
            assert _buffers[k].stack_num == buffers[k].stack_num
            assert _buffers[k].maxsize == buffers[k].maxsize
            assert np.all(_buffers[k]._indices == buffers[k]._indices)
        for k in ["array", "prioritized"]:
            assert _buffers[k]._index == buffers[k]._index
            assert isinstance(buffers[k].get(0, "info"), Batch)
            assert isinstance(_buffers[k].get(0, "info"), Batch)
        for k in ["array"]:
            assert np.all(
                buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
            assert np.all(
                buffers[k][:].info.extra == _buffers[k][:].info.extra)
    finally:
        # mkstemp leaves deletion to the caller; do not leak the files.
        for path in paths.values():
            if os.path.exists(path):
                os.remove(path)

    # raise exception when value cannot be pickled
    data = {"not_supported": lambda x: x * x}
    grp = h5py.Group
    with pytest.raises(NotImplementedError):
        to_hdf5(data, grp)
    # ndarray with data type not supported by HDF5 that cannot be pickled
    data = {"not_supported": np.array(lambda x: x * x)}
    grp = h5py.Group
    with pytest.raises(RuntimeError):
        to_hdf5(data, grp)
def forward(self, batch, state=None, input='obs', **kwargs):
    """Compute a tanh-squashed action, clamped to ``self._range``.

    With ``deterministic=True`` in ``kwargs`` the action is ``tanh(mu)``
    and both ``dist`` and ``log_prob`` are ``None``. Otherwise the action
    is a reparameterized sample from a Normal distribution and ``log_prob``
    holds its log-probability including the squashing correction.

    :return: A Batch with ``logits``, ``act``, ``state``, ``dist`` and
        ``log_prob``.
    """
    obs = getattr(batch, input)
    logits, hidden = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    mu, _sigma = logits
    dist = None
    log_prob = None
    if not kwargs.get('deterministic', False):
        dist = torch.distributions.Normal(*logits)
        pre_tanh = dist.rsample()
        squashed = torch.tanh(pre_tanh)
        # Squashing correction; __eps keeps the argument of the log positive.
        correction = torch.log(
            self._action_scale * (1 - squashed.pow(2)) + self.__eps)
        log_prob = (dist.log_prob(pre_tanh) - correction).sum(
            -1, keepdim=True)
        act = squashed
    else:
        act = torch.tanh(mu)
    act = act.clamp(self._range[0], self._range[1])
    return Batch(
        logits=logits, act=act, state=hidden, dist=dist, log_prob=log_prob)
def simulate_environment(self):
    """Roll the policy out in a fresh simulated env and store the result.

    A new ``SimulationEnv`` is created and stepped
    ``self.args.n_simulator_step`` times with the actions of this policy
    (computed without gradients). Every resulting transition is then added
    to ``self.simulator_buffer``, iterating over index ``j`` of axis 1 in
    the outer loop and the time step in the inner loop.
    """
    self.simulation_env = SimulationEnv(self.args, self.simulator)
    obs_seq = [self.simulation_env.reset()]
    act_seq, rew_seq, done_seq, info_seq = [], [], [], []
    for _ in range(self.args.n_simulator_step):
        with torch.no_grad():
            policy_out = self(Batch(obs=obs_seq[-1], info={}))
            act_seq.append(policy_out.act.cpu().numpy())
        result = self.simulation_env.step(act_seq[-1])
        obs_seq.append(result[0])
        rew_seq.append(result[1])
        done_seq.append(result[2])
        info_seq.append(result[3])

    # Consecutive observations form the (obs, obs_next) pairs.
    obs_next = np.array(obs_seq[1:])
    obs = np.array(obs_seq[:-1])
    act = np.array(act_seq)
    rew = np.array(rew_seq)
    done = np.array(done_seq)

    for j in range(obs.shape[1]):
        for i in range(self.args.n_simulator_step):
            self.simulator_buffer.add(
                obs[i, j], act[i, j], rew[i, j], done[i, j], obs_next[i, j])
    return None
def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
    """Check which env ids an async vector env returns on each step.

    Four envs with sleep times of 2, 3, 5 and 7 times ``sleep`` are stepped
    with ``wait_num=num - 1`` and the given ``timeout``; the ids reported
    after each step are compared with a fixed expected sequence.
    """
    env_fns = [
        lambda factor=factor: MyTestEnv(size=size, sleep=sleep * factor)
        for factor in (2, 3, 5, 7)
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls.append(RayVectorEnv)
    expect_result = [
        [0, 1],
        [0, 1, 2],
        [0, 1, 3],
        [0, 1, 2],
        [0, 1],
        [0, 2, 3],
        [0, 1],
    ]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num - 1, timeout=timeout)
        v.reset()
        ids = np.arange(num)
        for res in expect_result:
            start = time.time()
            _, _, _, info = v.step([1] * len(ids), ids)
            elapsed = time.time() - start
            ids = Batch(info).env_id
            print(ids, elapsed)
            if cls != RayVectorEnv:  # ray-project/ray#10134
                assert np.allclose(sorted(ids), res)
                # The step beats the timeout exactly when wait_num envs
                # are expected to have finished.
                assert (elapsed < timeout) == (len(res) == num - 1)
def policy_forward(policy, obs, info=None, eps=0.0):
    """
    Map the observation to the action under the policy.

    Parameters
    ----------
    policy:
        a trained tianshou ddpg policy
    obs: array_like
        observation; a 2-D input is treated as a batch of observations,
        anything else is flattened into a single observation
    info:
        gym info (accepted, but the batch is always built with info=None)
    eps: float
        forwarded to the policy as ``eps``.
        The predicted action is extracted from an Gaussian distribution,
        eps*I is the covariance

    Returns
    -------
    The policy's action as a numpy array.
    """
    obs = np.array(obs)
    # Only a 2-D input carries its own batch dimension.
    batch_len = len(obs) if obs.ndim == 2 else 1
    obs = obs.reshape((batch_len, -1))
    result = policy(Batch(obs=obs, info=None), eps=eps)
    return result.act.detach().cpu().numpy()
def forward(self, batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            **kwargs) -> Batch:
    """Compute action over the given batch data.

    :return: A :class:`~tianshou.data.Batch` which has 4 keys:

        * ``act`` the action.
        * ``logits`` the network's raw output.
        * ``dist`` the action distribution.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
    # A tuple of logits is unpacked into the distribution's parameters.
    dist = (
        self.dist_fn(*logits) if isinstance(logits, tuple)
        else self.dist_fn(logits)
    )
    act = dist.sample()
    # Clamp only when an action range is configured.
    if self._range:
        low, high = self._range[0], self._range[1]
        act = act.clamp(low, high)
    return Batch(logits=logits, act=act, state=hidden, dist=dist)
def __call__(self, batch, state=None,
             model='actor', input='obs', eps=None, **kwargs):
    """Compute action over the given batch data.

    :param float eps: in [0, 1], for exploration use; falls back to
        ``self._eps`` when ``None``. It scales the standard normal noise
        that is added to the action.

    :return: A :class:`~tianshou.data.Batch` which has 2 keys:

        * ``act`` the action.
        * ``state`` the hidden state.

    More information can be found at
    :meth:`~tianshou.policy.BasePolicy.__call__`.
    """
    network = getattr(self, model)
    obs = getattr(batch, input)
    logits, hidden = network(obs, state=state, info=batch.info)
    logits += self._action_bias
    noise_scale = self._eps if eps is None else eps
    if noise_scale > 0:
        noise = torch.randn(size=logits.shape, device=logits.device)
        logits += noise * noise_scale
    low, high = self._range[0], self._range[1]
    return Batch(act=logits.clamp(low, high), state=hidden)
def forward(self, batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            model: str = 'actor',
            input: str = 'obs',
            eps: Optional[float] = None,
            **kwargs) -> Batch:
    """Compute action over the given batch data.

    :param float eps: in [0, 1], for exploration use; falls back to
        ``self._eps`` when ``None``. It scales the standard normal noise
        that is added to the action.

    :return: A :class:`~tianshou.data.Batch` which has 2 keys:

        * ``act`` the action.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    network = getattr(self, model)
    obs = getattr(batch, input)
    logits, hidden = network(obs, state=state, info=batch.info)
    logits += self._action_bias
    noise_scale = self._eps if eps is None else eps
    if noise_scale > 0:
        noise = torch.randn(size=logits.shape, device=logits.device)
        logits += noise * noise_scale
    low, high = self._range[0], self._range[1]
    return Batch(act=logits.clamp(low, high), state=hidden)
def forward(self, batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            model: str = 'model',
            input: str = 'obs',
            eps: Optional[float] = None,
            **kwargs) -> Batch:
    """Compute action over the given batch data.

    If you need to mask the action, please add a "mask" into batch.obs, for
    example, if we have an environment that has "0/1/2" three actions:
    ::

        batch == Batch(
            obs=Batch(
                obs="original obs, with batch_size=1 for demonstration",
                mask=np.array([[False, True, False]]),
                # action 1 is available
                # action 0 and 2 are unavailable
            ),
            ...
        )

    :param float eps: in [0, 1], for epsilon-greedy exploration method;
        falls back to ``self.eps`` when ``None``.

    :return: A :class:`~tianshou.data.Batch` which has 3 keys:

        * ``act`` the action.
        * ``logits`` the network's raw output.
        * ``state`` the hidden state.

    .. seealso::

        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    network = getattr(self, model)
    obs = getattr(batch, input)
    # The observation may be wrapped together with a mask; unwrap it if so.
    net_input = obs.obs if hasattr(obs, 'obs') else obs
    q, hidden = network(net_input, state=state, info=batch.info)
    act = to_numpy(q.max(dim=1)[1])
    has_mask = hasattr(obs, 'mask')
    if has_mask:
        # some of actions are masked, they cannot be selected
        masked_q = to_numpy(q)
        masked_q[~obs.mask] = -np.inf
        act = masked_q.argmax(axis=1)
    # add eps to act
    if eps is None:
        eps = self.eps
    if not np.isclose(eps, 0):
        for row in range(len(q)):
            if np.random.rand() < eps:
                # Random scores make argmax pick a random available action.
                random_scores = np.random.rand(*q[row].shape)
                if has_mask:
                    random_scores[~obs.mask[row]] = -np.inf
                act[row] = random_scores.argmax()
    return Batch(logits=q, act=act, state=hidden)