def __init__(self, policy, env, buffer=None, stat_size=100):
    super().__init__()
    self.env = env
    self.env_num = 1
    self.collect_step = 0
    self.collect_episode = 0
    self.collect_time = 0
    if buffer is None:
        self.buffer = ReplayBuffer(100)
    else:
        self.buffer = buffer
    self.policy = policy
    self.process_fn = policy.process_fn
    self._multi_env = isinstance(env, BaseVectorEnv)
    self._multi_buf = False  # True if buf is a list
    # need multiple cache buffers only if storing in one buffer
    self._cached_buf = []
    if self._multi_env:
        self.env_num = len(env)
        if isinstance(self.buffer, list):
            assert len(self.buffer) == self.env_num, \
                'The number of data buffers does not match the number of ' \
                'input envs.'
            self._multi_buf = True
        elif isinstance(self.buffer, ReplayBuffer):
            self._cached_buf = [
                ListReplayBuffer() for _ in range(self.env_num)
            ]
        else:
            raise TypeError('The buffer in data collector is invalid!')
    self.reset_env()
    self.reset_buffer()
    # state over batch is either a list, an np.ndarray, or a torch.Tensor
    self.state = None
    self.step_speed = MovAvg(stat_size)
    self.episode_speed = MovAvg(stat_size)
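# Hedged usage sketch (added, not part of the original source): how this
# constructor is typically exercised, assuming the Tianshou 0.2-era Collector
# API shown above. `make_env` and `policy` are hypothetical placeholders; the
# buffer size and step count are arbitrary.
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv

def _example_collector_usage(policy, make_env):
    envs = SubprocVectorEnv([make_env for _ in range(8)])
    # a single shared ReplayBuffer triggers the per-env ListReplayBuffer caches
    collector = Collector(policy, envs, buffer=ReplayBuffer(20000))
    return collector.collect(n_step=1000)  # assumed collect() signature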
def __init__(
    self,
    policy: BasePolicy,
    env: Union[gym.Env, BaseVectorEnv],
    obs_adv_atk: Attack,
    buffer: Optional[ReplayBuffer] = None,
    preprocess_fn: Optional[Callable[..., Batch]] = None,
    reward_metric: Optional[Callable[[np.ndarray], float]] = None,
    atk_frequency: float = 0.5,
    test: bool = False,
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
) -> None:
    super().__init__()
    if not isinstance(env, BaseVectorEnv):
        env = DummyVectorEnv([lambda: env])
    self.env = env
    self.env_num = len(env)
    self.device = device
    self.obs_adv_atk = obs_adv_atk
    self.obs_adv_atk.targeted = False
    self.atk_frequency = atk_frequency
    self.test = test
    # environments that are available in step()
    # this means all environments in synchronous simulation
    # but only a subset of environments in asynchronous simulation
    self._ready_env_ids = np.arange(self.env_num)
    # need cache buffers before storing in the main buffer
    self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
    self.buffer = buffer
    self.policy = policy
    self.preprocess_fn = preprocess_fn
    self.process_fn = policy.process_fn
    self._action_space = env.action_space
    self._rew_metric = reward_metric or \
        adversarial_training_collector._default_rew_metric
    # avoid creating attribute outside __init__
    self.reset()
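# Hedged usage sketch (added): wiring an untargeted observation attack into
# this collector. GradientSignAttack (FGSM) is a real advertorch class, but
# the model/loss wiring, env choice, eps, and atk_frequency values here are
# illustrative assumptions; note the constructor above forces atk.targeted
# to False regardless of how the attack was built.
import gym
import torch.nn.functional as F
from advertorch.attacks import GradientSignAttack

def _example_adv_collector(policy, model):
    atk = GradientSignAttack(model, loss_fn=F.cross_entropy, eps=0.1)
    env = gym.make('CartPole-v0')
    # perturb observations on roughly half of the collected steps
    return adversarial_training_collector(
        policy, env, obs_adv_atk=atk, atk_frequency=0.5)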
def test_init():
    for _ in np.arange(1e5):
        _ = ReplayBuffer(1e5)
        _ = PrioritizedReplayBuffer(size=int(1e5), alpha=0.5, beta=0.5)
        _ = ListReplayBuffer()
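# Hedged follow-up sketch (added): what the prioritized buffer constructed
# above is for, assuming the Tianshou 0.3-era API where sample() returns a
# (batch, indices) pair and update_weight() refreshes priorities. alpha scales
# how strongly priorities skew sampling; beta scales the importance-sampling
# correction applied to the sampled batch.
import numpy as np
from tianshou.data import PrioritizedReplayBuffer

def _example_prioritized_sampling():
    buf = PrioritizedReplayBuffer(size=16, alpha=0.5, beta=0.5)
    for i in range(8):
        buf.add(obs=i, act=0, rew=float(i), done=False, weight=1.0)
    batch, indices = buf.sample(4)           # assumed (batch, indices) return
    buf.update_weight(indices, np.ones(4))   # assumed priority-update hook
    return batch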
def test_hdf5():
    size = 100
    buffers = {
        "array": ReplayBuffer(size, stack_num=2),
        "list": ListReplayBuffer(),
        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rew = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': rew,
            'done': 0,
            'info': {"number": {"n": i}, 'extra': None},
        }
        buffers["array"].add(**kwargs)
        buffers["list"].add(**kwargs)
        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)

    # save
    paths = {}
    for k, buf in buffers.items():
        f, path = tempfile.mkstemp(suffix='.hdf5')
        os.close(f)
        buf.save_hdf5(path)
        paths[k] = path

    # load replay buffer
    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}

    # compare
    for k in buffers.keys():
        assert len(_buffers[k]) == len(buffers[k])
        assert np.allclose(_buffers[k].act, buffers[k].act)
        assert _buffers[k].stack_num == buffers[k].stack_num
        assert _buffers[k]._maxsize == buffers[k]._maxsize
        assert _buffers[k]._index == buffers[k]._index
        assert np.all(_buffers[k]._indices == buffers[k]._indices)
    for k in ["array", "prioritized"]:
        assert isinstance(buffers[k].get(0, "info"), Batch)
        assert isinstance(_buffers[k].get(0, "info"), Batch)
    for k in ["array"]:
        assert np.all(
            buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
        assert np.all(buffers[k][:].info.extra == _buffers[k][:].info.extra)
    for path in paths.values():
        os.remove(path)

    # raise exception when value cannot be pickled
    data = {"not_supported": lambda x: x * x}
    grp = h5py.Group
    with pytest.raises(NotImplementedError):
        to_hdf5(data, grp)
    # ndarray with data type not supported by HDF5 that cannot be pickled
    data = {"not_supported": np.array(lambda x: x * x)}
    grp = h5py.Group
    with pytest.raises(RuntimeError):
        to_hdf5(data, grp)
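# Hedged round-trip sketch (added): the test above passes the h5py.Group class
# itself as a placeholder, which only works because both error paths raise
# before the group is touched. Against a real file, to_hdf5/from_hdf5 (import
# path assumed from the Tianshou source tree) round-trip a nested dict through
# an actual HDF5 group:
import h5py
import numpy as np
from tianshou.data.utils.converter import from_hdf5, to_hdf5

def _example_hdf5_roundtrip(path='demo.hdf5'):
    data = {'obs': np.zeros((4, 3)), 'meta': {'n': np.arange(3)}}
    with h5py.File(path, 'w') as f:
        to_hdf5(data, f)         # an h5py.File is itself the root Group
    with h5py.File(path, 'r') as f:
        restored = from_hdf5(f)  # returns a plain dict of the stored values
    return restored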