Example #1
def __init__(self, policy, env, buffer=None, stat_size=100):
    super().__init__()
    self.env = env
    self.env_num = 1
    self.collect_step = 0
    self.collect_episode = 0
    self.collect_time = 0
    if buffer is None:
        self.buffer = ReplayBuffer(100)
    else:
        self.buffer = buffer
    self.policy = policy
    self.process_fn = policy.process_fn
    self._multi_env = isinstance(env, BaseVectorEnv)
    self._multi_buf = False  # True if self.buffer is a list of buffers
    # cache buffers are needed only when storing into one shared buffer
    self._cached_buf = []
    if self._multi_env:
        self.env_num = len(env)
        if isinstance(self.buffer, list):
            assert len(self.buffer) == self.env_num, \
                'The number of data buffers does not match the number ' \
                'of input envs.'
            self._multi_buf = True
        elif isinstance(self.buffer, ReplayBuffer):
            self._cached_buf = [
                ListReplayBuffer() for _ in range(self.env_num)
            ]
        else:
            raise TypeError('The buffer in the data collector is invalid!')
    self.reset_env()
    self.reset_buffer()
    # the state over a batch is a list, an np.ndarray, or a torch.Tensor
    self.state = None
    self.step_speed = MovAvg(stat_size)
    self.episode_speed = MovAvg(stat_size)
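A minimal way to exercise this constructor end to end, sketched under the assumption that the snippet comes from a tianshou-0.2-era Collector; RandomPolicy below is a hypothetical stand-in, not a tianshou class:

# Sketch only: RandomPolicy is hypothetical; Collector and ReplayBuffer are
# the same names the snippet above uses.
import gym
import numpy as np
from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    """Uniform random actions; just enough policy to drive a collector."""
    def __init__(self, action_space):
        super().__init__()
        self._space = action_space

    def forward(self, batch, state=None, **kwargs):
        # one random action per observation in the incoming batch
        acts = np.array([self._space.sample() for _ in batch.obs])
        return Batch(act=acts)

    def learn(self, batch, **kwargs):
        return {}

env = gym.make('CartPole-v0')
collector = Collector(RandomPolicy(env.action_space), env,
                      buffer=ReplayBuffer(2000))
collector.collect(n_step=100)  # fills the buffer with 100 transitions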
Example #2
def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        obs_adv_atk: Attack,
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Optional[Callable[..., Batch]] = None,
        reward_metric: Optional[Callable[[np.ndarray], float]] = None,
        atk_frequency: float = 0.5,
        test: bool = False,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
) -> None:
    super().__init__()
    if not isinstance(env, BaseVectorEnv):
        env = DummyVectorEnv([lambda: env])
    self.env = env
    self.env_num = len(env)
    self.device = device
    self.obs_adv_atk = obs_adv_atk
    self.obs_adv_atk.targeted = False
    self.atk_frequency = atk_frequency
    self.test = test
    # environments that are available in step():
    # all environments in synchronous simulation,
    # but only a subset of environments in asynchronous simulation
    self._ready_env_ids = np.arange(self.env_num)
    # cache buffers are needed before storing into the main buffer
    self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
    self.buffer = buffer
    self.policy = policy
    self.preprocess_fn = preprocess_fn
    self.process_fn = policy.process_fn
    self._action_space = env.action_space
    self._rew_metric = reward_metric or \
        adversarial_training_collector._default_rew_metric
    # avoid creating attributes outside __init__
    self.reset()
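The first branch of this constructor normalizes a bare gym.Env into a one-element vector env. A quick standalone look at what that wrapping produces, using tianshou's DummyVectorEnv, which the snippet itself relies on:

import gym
from tianshou.env import DummyVectorEnv

env = gym.make('CartPole-v0')
venv = DummyVectorEnv([lambda: env])  # the same wrapping as above
print(len(venv))           # 1 -> what self.env_num becomes
print(venv.action_space)   # attribute access is forwarded to each sub-env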
Example #3
import numpy as np
from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer

def test_init():
    for _ in np.arange(1e5):
        _ = ReplayBuffer(1e5)
        _ = PrioritizedReplayBuffer(size=int(1e5), alpha=0.5, beta=0.5)
        _ = ListReplayBuffer()
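Construction is only half the story; the buffers share a small add/sample API. A minimal round of use, assuming the same tianshou era as the tests here (values are illustrative):

import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=10)
for i in range(3):
    buf.add(obs=i, act=i, rew=float(i), done=False, obs_next=i + 1, info={})
batch, indices = buf.sample(batch_size=2)  # a uniform sample and its indices
print(len(buf), batch.rew)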
Example #4
import os
import tempfile

import h5py
import numpy as np
import pytest
import torch

from tianshou.data import (Batch, ListReplayBuffer, PrioritizedReplayBuffer,
                           ReplayBuffer)
from tianshou.data.utils.converter import to_hdf5


def test_hdf5():
    size = 100
    buffers = {
        "array": ReplayBuffer(size, stack_num=2),
        "list": ListReplayBuffer(),
        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rew = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': rew,
            'done': 0,
            'info': {
                "number": {
                    "n": i
                },
                'extra': None
            },
        }
        buffers["array"].add(**kwargs)
        buffers["list"].add(**kwargs)
        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)

    # save
    paths = {}
    for k, buf in buffers.items():
        f, path = tempfile.mkstemp(suffix='.hdf5')
        os.close(f)
        buf.save_hdf5(path)
        paths[k] = path

    # load the replay buffers back from disk
    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths}

    # compare
    for k in buffers.keys():
        assert len(_buffers[k]) == len(buffers[k])
        assert np.allclose(_buffers[k].act, buffers[k].act)
        assert _buffers[k].stack_num == buffers[k].stack_num
        assert _buffers[k]._maxsize == buffers[k]._maxsize
        assert _buffers[k]._index == buffers[k]._index
        assert np.all(_buffers[k]._indices == buffers[k]._indices)
    for k in ["array", "prioritized"]:
        assert isinstance(buffers[k].get(0, "info"), Batch)
        assert isinstance(_buffers[k].get(0, "info"), Batch)
    for k in ["array"]:
        assert np.all(
            buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
        assert np.all(buffers[k][:].info.extra == _buffers[k][:].info.extra)

    for path in paths.values():
        os.remove(path)

    # raise an exception when a value cannot be pickled
    data = {"not_supported": lambda x: x * x}
    grp = h5py.Group  # a placeholder; both calls below fail before writing
    with pytest.raises(NotImplementedError):
        to_hdf5(data, grp)
    # an ndarray whose dtype HDF5 rejects and that cannot be pickled either
    data = {"not_supported": np.array(lambda x: x * x)}
    grp = h5py.Group
    with pytest.raises(RuntimeError):
        to_hdf5(data, grp)
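Stripped of the test scaffolding, the round trip reduces to two calls. A minimal sketch with the same save_hdf5/load_hdf5 API the test exercises (the file name is illustrative):

import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=4)
buf.add(obs=np.zeros(3), act=0, rew=1.0, done=False, obs_next=np.ones(3))
buf.save_hdf5('buffer.hdf5')                      # serialize the whole buffer
restored = ReplayBuffer.load_hdf5('buffer.hdf5')  # classmethod, as in the test
assert len(restored) == len(buf)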