def test_batch_cat_and_stack():
    """Check Batch.cat / cat_ and Batch.stack on nested Batch/dict values.

    Fix: the original repeated the same assertion twice (copy-paste
    leftover); the duplicate is removed.
    """
    # cat: nested dict and Batch leaves must merge identically whether done
    # out-of-place (Batch.cat) or in-place (cat_).
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat((b1, b2))
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1
    # stack adds a new leading axis on top of the concatenated shape.
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2
    # stack along a non-default axis, mixing numpy / torch / list leaves.
    b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    # Building a Batch from an array of dicts with missing keys: absent
    # entries are zero-filled.
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0
def test_batch_cat_and_stack_and_empty():
    """Check cat/stack on nested Batch values plus Batch.empty zero-filling.

    Fixes: ``dtype=np.int`` used the deprecated NumPy alias of the builtin
    ``int`` (removed in NumPy 1.24) — replaced with ``int``; a duplicated
    assertion is removed.
    """
    # cat: out-of-place Batch.cat and in-place cat_ must agree.
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat((b1, b2))
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1
    # stack adds a new leading axis.
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2
    # stack along a custom axis, mixing numpy / torch / list leaves.
    b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    # Building a Batch from an array of dicts: missing keys are zero-filled.
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0
    # Assigning Batch.empty(...) into a row zeroes that row in-place.
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])
    data = Batch(a=[False, True],
                 b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
                 c=np.array([1, 3, 4], dtype=int),
                 t=torch.tensor([4, 5, 6, 7.]))
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    # String entries are emptied, object entries keep their values.
    assert list(data.b.c) == ['2.0', '']
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
    # empty_() on an already-empty Batch is a no-op.
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []
def get(
    self,
    index: Union[int, np.integer, np.ndarray],
    key: str,
    stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]:
    """Return the stacked result.

    E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
    index.

    :param index: buffer position(s) at which the stack ends.
    :param key: name of the stored entry in ``self._meta``.
    :param stack_num: number of frames to stack; defaults to
        ``self.stack_num``.
    """
    if stack_num is None:
        stack_num = self.stack_num
    val = self._meta[key]
    try:
        if stack_num == 1:  # the most often case
            return val[index]
        # Collect stack_num frames ending at `index`, walking backwards
        # through the buffer via self.prev(); prepend so the oldest frame
        # ends up first.
        stack: List[Any] = []
        indice = np.asarray(index)
        for _ in range(stack_num):
            stack = [val[indice]] + stack
            indice = self.prev(indice)
        if isinstance(val, Batch):
            return Batch.stack(stack, axis=indice.ndim)
        else:
            return np.stack(stack, axis=indice.ndim)
    except IndexError as e:
        if not (isinstance(val, Batch) and val.is_empty()):
            raise e  # val != Batch()
        # A reserved (empty Batch) entry legitimately raises IndexError;
        # return an empty Batch instead of propagating.
        return Batch()
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Exercise PrioritizedReplayBuffer and PrioritizedVectorReplayBuffer:
    adding, sampling, length bookkeeping, and priority-weight updates.

    :param size: environment episode length.
    :param bufsize: capacity of each (sub-)buffer.
    """
    env = MyTestEnv(size)
    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)  # alpha=0.5, beta=0.5
    buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5,
                                         beta=0.5)
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        # `policy` carries a random scalar so sampled transitions differ.
        batch = Batch(obs=obs, act=a, rew=rew, done=done, obs_next=obs_next,
                      info=info, policy=np.random.randn() - 0.5)
        batch_stack = Batch.stack([batch, batch, batch])
        buf.add(Batch.stack([batch]), buffer_ids=[0])
        buf2.add(batch_stack, buffer_ids=[0, 1, 2])
        obs = obs_next
        data, indices = buf.sample(len(buf) // 2)
        # sample(0) falls back to sampling the whole buffer.
        if len(buf) // 2 == 0:
            assert len(data) == len(buf)
        else:
            assert len(data) == len(buf) // 2
        assert len(buf) == min(bufsize, i + 1)
        assert len(buf2) == min(bufsize, 3 * (i + 1))
    # check single buffer's data
    assert buf.info.key.shape == (buf.maxsize, )
    assert buf.rew.dtype == float
    assert buf.done.dtype == bool
    data, indices = buf.sample(len(buf) // 2)
    # Weights are updated as |priority| ** alpha; negative input exercises
    # the absolute-value path.
    buf.update_weight(indices, -data.weight / 2)
    assert np.allclose(buf.weight[indices],
                       np.abs(-data.weight / 2)**buf._alpha)
    # check multi buffer's data
    assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1)
    batch, indices = buf2.sample(10)
    # Setting weight to 0 for the sampled indices must lower their weight
    # below the untouched entries (and all weights stay <= 1).
    buf2.update_weight(indices, batch.weight * 0)
    weight = buf2[np.arange(buf2.maxsize)].weight
    mask = np.isin(np.arange(buf2.maxsize), indices)
    assert np.all(weight[mask] == weight[mask][0])
    assert np.all(weight[~mask] == weight[~mask][0])
    assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1
def test_batch_cat_and_stack():
    """Concatenating and stacking nested Batch objects keeps leaf arrays."""
    left = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    right = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    # Out-of-place concatenation and in-place cat_ must produce the same
    # nested leaf values.
    cat_result = Batch.cat((left, right))
    inplace = copy.deepcopy(left)
    inplace.cat_(right)
    for _ in range(2):
        assert np.all(inplace.a.d.e == cat_result.a.d.e)
    assert isinstance(inplace.a.d.e, np.ndarray)
    assert inplace.a.d.e.ndim == 1
    # Stacking introduces an extra leading dimension on the leaf array.
    stacked = Batch.stack((left, right))
    assert isinstance(stacked.a.d.e, np.ndarray)
    assert stacked.a.d.e.ndim == 2
def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
        stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
    """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
    where s is self.key, t is indice. The stack_num (here equals to 4) is
    given from buffer initialization procedure.

    Fix: ``astype(np.int)`` used the deprecated NumPy alias of the builtin
    ``int`` (removed in NumPy 1.24) — replaced with ``int``.  Original
    Chinese comments translated to English.
    """
    if stack_num is None:
        stack_num = self.stack_num
    if stack_num == 1:  # the most often case
        if key != 'obs_next' or self._save_s_:
            val = self._meta.__dict__[key]
            try:
                return val[indice]
            except IndexError as e:
                if not (isinstance(val, Batch) and val.is_empty()):
                    raise e  # val != Batch()
                return Batch()
    # Map the logical indices onto the physical buffer layout.
    indice = self._indices[:self._size][indice]
    done = self._meta.__dict__['done']
    # If obs_next is requested but not stored, step forward by one (unless
    # the episode terminated there) and read from obs instead.
    if key == 'obs_next' and not self._save_s_:
        indice += 1 - done[indice].astype(int)
        indice[indice == self._size] = 0
        key = 'obs'
    val = self._meta.__dict__[key]
    try:
        if stack_num == 1:
            return val[indice]
        stack = []
        for _ in range(stack_num):
            stack = [val[indice]] + stack
            # Step the indices one frame backwards.
            pre_indice = np.asarray(indice - 1)
            # Wrap index -1 around to the end of the buffer.
            pre_indice[pre_indice == -1] = self._size - 1
            # Undo the step when it would cross an episode boundary (done).
            indice = np.asarray(
                pre_indice + done[pre_indice].astype(int))
            # Wrap back to the start when stepping past the end.
            indice[indice == self._size] = 0
        if isinstance(val, Batch):
            return Batch.stack(stack, axis=indice.ndim)
        else:
            return np.stack(stack, axis=indice.ndim)
    except IndexError as e:
        if not (isinstance(val, Batch) and val.is_empty()):
            raise e  # val != Batch()
        return Batch()
def get(
    self,
    indice: Union[slice, int, np.integer, np.ndarray],
    key: str,
    stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]:
    """Return the stacked result.

    E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
    indice. The stack_num (here equals to 4) is given from buffer
    initialization procedure.

    Fix: ``astype(np.int)`` used the deprecated NumPy alias of the builtin
    ``int`` (removed in NumPy 1.24) — replaced with ``int``.
    """
    if stack_num is None:
        stack_num = self.stack_num
    if stack_num == 1:  # the most often case
        if key != "obs_next" or self._save_s_:
            val = self._meta.__dict__[key]
            try:
                return val[indice]
            except IndexError as e:
                if not (isinstance(val, Batch) and val.is_empty()):
                    raise e  # val != Batch()
                return Batch()
    # Map the logical indices onto the physical buffer layout.
    indice = self._indices[:self._size][indice]
    done = self._meta.__dict__["done"]
    # obs_next not stored: advance by one step (unless the episode ended)
    # and read from obs instead.
    if key == "obs_next" and not self._save_s_:
        indice += 1 - done[indice].astype(int)
        indice[indice == self._size] = 0
        key = "obs"
    val = self._meta.__dict__[key]
    try:
        if stack_num == 1:
            return val[indice]
        # Walk backwards stack_num frames, stopping at episode boundaries
        # and wrapping around the circular buffer.
        stack: List[Any] = []
        for _ in range(stack_num):
            stack = [val[indice]] + stack
            pre_indice = np.asarray(indice - 1)
            pre_indice[pre_indice == -1] = self._size - 1
            indice = np.asarray(
                pre_indice + done[pre_indice].astype(int))
            indice[indice == self._size] = 0
        if isinstance(val, Batch):
            return Batch.stack(stack, axis=indice.ndim)
        else:
            return np.stack(stack, axis=indice.ndim)
    except IndexError as e:
        if not (isinstance(val, Batch) and val.is_empty()):
            raise e  # val != Batch()
        return Batch()
def get(
    self,
    index: Union[int, List[int], np.ndarray],
    key: str,
    default_value: Any = None,
    stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]:
    """Return the stacked result.

    E.g., if you set ``key = "obs", stack_num = 4, index = t``, it returns
    the stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``.

    :param index: the index for getting stacked data.
    :param str key: the key to get, should be one of the reserved_keys.
    :param default_value: if the given key's data is not found and
        default_value is set, return this default_value.
    :param int stack_num: Default to self.stack_num.
    """
    if key not in self._meta and default_value is not None:
        return default_value
    val = self._meta[key]
    if stack_num is None:
        stack_num = self.stack_num
    try:
        if stack_num == 1:  # the most often case
            return val[index]
        # Collect stack_num frames ending at `index`, walking backwards via
        # self.prev(); prepend so the oldest frame comes first.
        stack: List[Any] = []
        if isinstance(index, list):
            indices = np.array(index)
        else:
            indices = index  # type: ignore
        for _ in range(stack_num):
            stack = [val[indices]] + stack
            indices = self.prev(indices)
        if isinstance(val, Batch):
            return Batch.stack(stack, axis=indices.ndim)
        else:
            return np.stack(stack, axis=indices.ndim)
    except IndexError as e:
        if not (isinstance(val, Batch) and val.is_empty()):
            raise e  # val != Batch()
        # A reserved (empty Batch) entry legitimately raises IndexError.
        return Batch()
def test_batch_cat_and_stack():
    """Exhaustive cat/stack tests: compatible, incompatible, reserved and
    empty keys, axis handling, and illegal-input errors."""
    # test cat with compatible keys
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    # NOTE(review): the assertion below duplicates the one above — likely a
    # copy-paste leftover.
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1
    # cat with a fully-reserved middle element contributes no rows.
    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)
    # test cat with lens infer
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    # The reserved a.a of `b` is zero-filled with the inferred length (3).
    assert np.allclose(ans.a.a,
                       np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()
    # stack_ is an in-place operation and returns None.
    assert b1.stack_([b2]) is None
    assert isinstance(b1.a.d.e, np.ndarray)
    assert b1.a.d.e.ndim == 2
    # test cat with incompatible keys
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    # Keys missing on one side are padded with zeros of matching shape.
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)
    # test cat with reserved keys (values are Batch())
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(), b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)
    # test cat with all reserved keys (values are Batch())
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(), b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=Batch(),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert ans.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)
    # test stack with compatible keys
    b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    # Array of dicts with missing keys: absent entries are zero-filled.
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0
    # test stack with incompatible keys
    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])
    # test stack with empty Batch()
    assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
    a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
    b = Batch(a=4, b=5, d=6, e=Batch())
    c = Batch(c=7, b=6, d=9, e=Batch())
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])
    # A key that is reserved (empty Batch) everywhere stays empty.
    assert d.e.is_empty()
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2], axis=-1)
    assert test.a.is_empty()
    assert test.b.is_empty()
    assert np.allclose(test.common.c,
                       np.stack([b1.common.c, b2.common.c], axis=-1))
    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2])
    ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
                b=torch.stack([torch.zeros(4, 6), b2.b]),
                common=Batch(c=np.stack([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)
    # test with illegal input format
    with pytest.raises(ValueError):
        Batch.cat([[Batch(a=1)], [Batch(a=1)]])
    with pytest.raises(ValueError):
        Batch.stack([[Batch(a=1)], [Batch(a=1)]])
    # exceptions
    assert Batch.cat([]).is_empty()
    assert Batch.stack([]).is_empty()
    # Stacking/concatenating along axis != 0 with partially-missing keys is
    # ambiguous and must raise.
    b1 = Batch(e=[4, 5], d=6)
    b2 = Batch(e=[4, 6])
    with pytest.raises(ValueError):
        Batch.cat([b1, b2])
    with pytest.raises(ValueError):
        Batch.stack([b1, b2], axis=1)
def test_batch():
    """Exercise core Batch behaviors: emptiness, update/pop, construction
    errors, cat_, indexing/splitting, slicing, arithmetic, key assignment
    and automatic dtype conversion."""
    # --- emptiness and length semantics ---
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    # --- object-dtype storage ---
    b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object and b.f == Batch
    # --- update / pop ---
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert 'a' in b and b.a == 3
    assert b.pop('a') == 3
    assert 'a' not in b
    # --- invalid constructions ---
    with pytest.raises(AssertionError):
        Batch({1: 2})
    # Ragged first dimension is allowed (object array)...
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    # ...but mismatched trailing dims / mixed tensor types are not.
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    # --- cat_ on torch and numpy leaves ---
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    Batch(a=[])
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    # --- split: out-of-range access raises AttributeError ---
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    # --- split sizes and merge_last behavior ---
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    # (size, merge_last, expected chunks)
    data = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, result in data:
        bs = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [bs[i].a.tolist() for i in range(len(bs))] == result
    # --- indexing into nested dict values ---
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a=set((1, 2, 1))).shape == []
    assert batch2.shape[0] == 1
    assert 'a' in batch2 and all([i in batch2.a for i in 'bcd'])
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    # --- round-tripping through list / comprehension ---
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    # --- elementwise arithmetic (empty sub-Batch untouched) ---
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]
    # --- item assignment with partial / unknown keys ---
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])
    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None
    # nx.Graph corner case
    assert Batch(a=np.array([nx.Graph(), nx.Graph()],
                            dtype=object)).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2])).a.dtype == object
def test_stack(data):
    """Repeatedly stack batches to exercise Batch.stack and stack_.

    :param data: fixture dict providing 'batch0' and the 'batchs2' list.
    """
    base = data['batch0']
    targets = data['batchs2']
    for idx in range(10000):
        # Out-of-place stack of the same batch with itself.
        Batch.stack((base, base))
        # In-place stack_ onto each pre-built batch in turn.
        targets[idx].stack_([base])
def test_multibuf_hdf5():
    """Round-trip ReplayBufferManager / CachedReplayBuffer through HDF5 and
    verify the loaded buffers match and still accept new data."""
    size = 100
    buffers = {
        "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]),
        "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    info_t = torch.tensor([1.]).to(device)
    # Fill each buffer with 4 steps, replicated across 3 sub-buffers.
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': np.array([1, 2]),
            'done': i % 3 == 2,
            'info': {
                "number": {
                    "n": i,
                    "t": info_t
                },
                'extra': None
            },
        }
        buffers["vector"].add(**Batch.stack([kwargs, kwargs, kwargs]),
                              buffer_ids=[0, 1, 2])
        buffers["cached"].add(**Batch.stack([kwargs, kwargs, kwargs]),
                              cached_buffer_ids=[0, 1, 2])
    # save
    paths = {}
    for k, buf in buffers.items():
        f, path = tempfile.mkstemp(suffix='.hdf5')
        os.close(f)
        buf.save_hdf5(path)
        paths[k] = path
    # load replay buffer
    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
    # compare
    for k in buffers.keys():
        assert len(_buffers[k]) == len(buffers[k])
        assert np.allclose(_buffers[k].act, buffers[k].act)
        assert _buffers[k].stack_num == buffers[k].stack_num
        assert _buffers[k].maxsize == buffers[k].maxsize
        assert np.all(_buffers[k]._indices == buffers[k]._indices)
    # check shallow copy in ReplayBufferManager
    for k in ["vector", "cached"]:
        buffers[k].info.number.n[0] = -100
        assert buffers[k].buffers[0].info.number.n[0] == -100
    # check if still behave normally
    for k in ["vector", "cached"]:
        kwargs = {
            'obs': Batch(index=np.array([5])),
            'act': 5,
            'rew': np.array([2, 1]),
            'done': False,
            # NOTE(review): `i` here is the leftover value (3) from the fill
            # loop above, not a fresh counter — presumably intentional since
            # only `act` is asserted below; confirm if this test is extended.
            'info': {
                "number": {
                    "n": i
                },
                'Timelimit.truncate': True
            },
        }
        buffers[k].add(**Batch.stack([kwargs, kwargs, kwargs, kwargs]))
        # Expected per-slot actions after the extra add, laid out per
        # sub-buffer region of the flat storage.
        act = np.zeros(buffers[k].maxsize)
        if k == "vector":
            act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
            act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
            act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
            act[size * 3] = 5
        elif k == "cached":
            act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
            act[np.arange(3) + size] = np.array([3, 5, 2])
            act[np.arange(3) + size * 2] = np.array([3, 5, 2])
            act[np.arange(3) + size * 3] = np.array([3, 5, 2])
            act[size * 4] = 5
        assert np.allclose(buffers[k].act, act)
    # Clean up the temporary HDF5 files.
    for path in paths.values():
        os.remove(path)
def test_batch():
    """Exercise core Batch behaviors: emptiness, update, construction
    errors, indexing, slicing, arithmetic and automatic dtype conversion.

    Fix: ``np.object`` used the deprecated NumPy alias of the builtin
    ``object`` (removed in NumPy 1.24) — replaced with ``object``.
    """
    # --- emptiness and length semantics ---
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    # --- update ---
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert b.a == 3
    # --- invalid constructions ---
    with pytest.raises(AssertionError):
        Batch({1: 2})
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    Batch(a=[])
    # --- item access, cat_ and split ---
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    # --- nested dict construction and indexing ---
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))
    }])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert batch2.shape[0] == 1
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    # --- round-tripping through list / comprehension ---
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    # --- elementwise arithmetic ---
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    # --- item assignment with partial / unknown keys ---
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))
    })
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(KeyError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])
    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None
def test_replaybuffer(size=10, bufsize=20):
    """Exercise ReplayBuffer: adding transitions, sampling, dynamic info
    keys, batch-style adds, and prev/next episode navigation.

    :param size: environment episode length.
    :param bufsize: buffer capacity.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # Updating a buffer from itself should be a no-op.
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        buf.add(
            Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next,
                  info=info))
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    assert buf.act.dtype == int
    assert buf.act.shape == (bufsize, 1)
    # Oversampling is allowed; indices stay within the occupied region.
    data, indices = buf.sample(bufsize * 2)
    assert (indices < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    # neg bsz should return empty index
    assert b.sample_indices(-1).tolist() == []
    # add() returns (insert ptr, episode reward, episode length, episode
    # start index), each as a length-1 array for a single transition.
    ptr, ep_rew, ep_len, ep_idx = b.add(
        Batch(obs=1, act=1, rew=1, done=1, obs_next='str',
              info={
                  'a': 3,
                  'b': {
                      'c': 5.0
                  }
              }))
    assert b.obs[0] == 1
    assert b.done[0]
    assert b.obs_next[0] == 'str'
    # Untouched slots are zero / None filled.
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.obs_next[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == int
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
    assert np.all(b.info.b.c[1:] == 0.0)
    assert ptr.shape == (1, ) and ptr[0] == 0
    assert ep_rew.shape == (1, ) and ep_rew[0] == 1
    assert ep_len.shape == (1, ) and ep_len[0] == 1
    assert ep_idx.shape == (1, ) and ep_idx[0] == 0
    # test extra keys pop up, the buffer should handle it dynamically
    batch = Batch(obs=2, act=2, rew=2, done=0, obs_next="str2",
                  info={
                      "a": 4,
                      "d": {
                          "e": -np.inf
                      }
                  })
    b.add(batch)
    info_keys = ["a", "b", "d"]
    assert set(b.info.keys()) == set(info_keys)
    assert b.info.a[1] == 4 and b.info.b.c[1] == 0
    assert b.info.d.e[1] == -np.inf
    # test batch-style adding method, where len(batch) == 1
    batch.done = [1]
    batch.info.e = np.zeros([1, 4])
    batch = Batch.stack([batch])
    ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
    assert ptr.shape == (1, ) and ptr[0] == 2
    assert ep_rew.shape == (1, ) and ep_rew[0] == 4
    assert ep_len.shape == (1, ) and ep_len[0] == 2
    assert ep_idx.shape == (1, ) and ep_idx[0] == 1
    assert set(b.info.keys()) == set(info_keys + ["e"])
    assert b.info.e.shape == (b.maxsize, 1, 4)
    with pytest.raises(IndexError):
        b[22]
    # test prev / next
    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
    batch.done = [0]
    b.add(batch, buffer_ids=[0])
    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])