Exemple #1
0
def test_batch_cat_and_stack():
    """Exercise ``Batch.cat``, ``Batch.cat_`` and ``Batch.stack``.

    Covers concatenation of nested batches (in-place vs. out-of-place),
    stacking along the default and a non-default axis, and construction of
    a Batch from an object array of dicts with a missing nested key.
    """
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat((b1, b2))
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # In-place and out-of-place concatenation must agree on every leaf.
    # (The original asserted ``a.d.e`` twice and never compared ``a.b``.)
    assert np.all(b12_cat_in.a.b == b12_cat_out.a.b)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    # Stacking along axis=1 must match numpy's behaviour on each leaf.
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    # The second dict lacks 'b.d'; the assertions below expect it to be 0.0.
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0
Exemple #2
0
def test_batch_cat_and_stack_and_empty():
    """Exercise ``Batch.cat``/``cat_``/``stack`` and ``Batch.empty``/``empty_``.

    The first half checks concatenation and stacking of nested batches; the
    second half checks that assigning ``Batch.empty(...)`` into a slot resets
    that slot to the values asserted below for each leaf type.
    """
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat((b1, b2))
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # In-place and out-of-place concatenation must agree on every leaf.
    # (The original asserted ``a.d.e`` twice and never compared ``a.b``.)
    assert np.all(b12_cat_in.a.b == b12_cat_out.a.b)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2
    b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    # The second dict lacks 'b.d'; the assertions below expect it to be 0.0.
    b5_dict = np.array([{
        'a': False,
        'b': {
            'c': 2.0,
            'd': 1.0
        }
    }, {
        'a': True,
        'b': {
            'c': 3.0
        }
    }])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0
    # Overwrite the second entry with an "empty" copy of the first.
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])
    # ``np.int`` was only an alias of the builtin ``int`` and has been
    # removed from NumPy (1.24); the builtin yields the same dtype.
    data = Batch(a=[False, True],
                 b={
                     'c': [2., 'st'],
                     'd': [1, None],
                     'e': [2., float('nan')]
                 },
                 c=np.array([1, 3, 4], dtype=int),
                 t=torch.tensor([4, 5, 6, 7.]))
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == ['2.0', '']
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
    # Emptying an already empty Batch must be a no-op.
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []
Exemple #3
0
    def get(
        self,
        index: Union[int, np.integer, np.ndarray],
        key: str,
        stack_num: Optional[int] = None,
    ) -> Union[Batch, np.ndarray]:
        """Fetch ``key`` at ``index``, optionally stacked over previous steps.

        With a stack depth of 4 the result is laid out as
        [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is the stored ``key`` and
        t is ``index``. When ``stack_num`` is None, ``self.stack_num`` is
        used. An IndexError is swallowed (returning an empty ``Batch``) only
        when the stored value is itself an empty ``Batch``.
        """
        depth = self.stack_num if stack_num is None else stack_num
        data = self._meta[key]
        try:
            if depth == 1:  # fast path, by far the most common request
                return data[index]
            frames: List[Any] = []
            cursor = np.asarray(index)
            for _ in range(depth):
                # Walk backwards in time, keeping the oldest frame first.
                frames.insert(0, data[cursor])
                cursor = self.prev(cursor)
            stack_fn = Batch.stack if isinstance(data, Batch) else np.stack
            return stack_fn(frames, axis=cursor.ndim)
        except IndexError as e:
            if isinstance(data, Batch) and data.is_empty():
                return Batch()
            raise e  # data is not an empty Batch(): a genuine error
Exemple #4
0
def test_priortized_replaybuffer(size=32, bufsize=15):
    """Fill a prioritized buffer and its vectorized variant, then check weights.

    Each environment step is added once to the single buffer and three times
    (one per sub-buffer) to the vector buffer. Afterwards the stored dtypes,
    the effect of ``update_weight`` and the resulting weight pattern are
    verified.
    """
    env = MyTestEnv(size)
    single = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
    vector = PrioritizedVectorReplayBuffer(
        bufsize, buffer_num=3, alpha=0.5, beta=0.5)
    obs = env.reset()
    actions = [1] * 5 + [0] * 10 + [1] * 10
    for step, action in enumerate(actions):
        obs_next, rew, done, info = env.step(action)
        transition = Batch(
            obs=obs,
            act=action,
            rew=rew,
            done=done,
            obs_next=obs_next,
            info=info,
            policy=np.random.randn() - 0.5,
        )
        tripled = Batch.stack([transition, transition, transition])
        single.add(Batch.stack([transition]), buffer_ids=[0])
        vector.add(tripled, buffer_ids=[0, 1, 2])
        obs = obs_next
        # The test expects a request for 0 samples to return the whole buffer.
        half = len(single) // 2
        data, indices = single.sample(half)
        expected_len = len(single) if half == 0 else half
        assert len(data) == expected_len
        assert len(single) == min(bufsize, step + 1)
        assert len(vector) == min(bufsize, 3 * (step + 1))

    # --- single buffer ---
    assert single.info.key.shape == (single.maxsize, )
    assert single.rew.dtype == float
    assert single.done.dtype == bool
    data, indices = single.sample(len(single) // 2)
    single.update_weight(indices, -data.weight / 2)
    expected_weight = np.abs(-data.weight / 2)**single._alpha
    assert np.allclose(single.weight[indices], expected_weight)

    # --- vector buffer ---
    every_slot = np.arange(vector.maxsize)
    assert np.allclose(vector[every_slot].weight, 1)
    batch, indices = vector.sample(10)
    vector.update_weight(indices, batch.weight * 0)
    weight = vector[every_slot].weight
    mask = np.isin(every_slot, indices)
    touched = weight[mask]
    untouched = weight[~mask]
    # Updated and non-updated slots each share one common weight value.
    assert np.all(touched == touched[0])
    assert np.all(untouched == untouched[0])
    assert untouched[0] < touched[0] <= 1
Exemple #5
0
def test_batch_cat_and_stack():
    """Check that ``Batch.cat``, ``Batch.cat_`` and ``Batch.stack`` agree.

    Two single-element nested batches are concatenated out-of-place and
    in-place; both results must match leaf by leaf. Stacking the same inputs
    must add one extra dimension.
    """
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b_cat_out = Batch.cat((b1, b2))
    b_cat_in = copy.deepcopy(b1)
    b_cat_in.cat_(b2)
    # The original asserted ``a.d.e`` twice and never compared ``a.b``.
    assert np.all(b_cat_in.a.b == b_cat_out.a.b)
    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
    assert isinstance(b_cat_in.a.d.e, np.ndarray)
    assert b_cat_in.a.d.e.ndim == 1
    b_stack = Batch.stack((b1, b2))
    assert isinstance(b_stack.a.d.e, np.ndarray)
    assert b_stack.a.d.e.ndim == 2
Exemple #6
0
 def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
         stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
     """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
     where s is self.key, t is indice. The stack_num (here equals to 4) is
     given from buffer initialization procedure.

     If ``key`` is ``'obs_next'`` and ``self._save_s_`` is false, the value
     is read from ``'obs'`` at the following position instead (or at the
     same position where ``done`` is set). An IndexError is swallowed,
     returning an empty ``Batch``, only when the stored value is itself an
     empty ``Batch``; otherwise it is re-raised.
     """
     if stack_num is None:
         stack_num = self.stack_num
     if stack_num == 1:  # the most often case
         if key != 'obs_next' or self._save_s_:
             val = self._meta.__dict__[key]
             try:
                 return val[indice]
             except IndexError as e:
                 if not (isinstance(val, Batch) and val.is_empty()):
                     raise e  # val != Batch()
                 return Batch()
     # Resolve the request against the currently filled part of the buffer.
     indice = self._indices[:self._size][indice]
     done = self._meta.__dict__['done']
     # obs_next was requested but is not stored: read it from obs, one step
     # later (or at the same step when that step is terminal).
     if key == 'obs_next' and not self._save_s_:
         # Build a fresh array rather than using ``+=``: for a slice request
         # ``indice`` may be a view of self._indices (assuming that is a
         # NumPy array), which an in-place update would modify; for an int
         # request it is a NumPy scalar, which does not support the masked
         # assignment below.
         indice = np.asarray(indice + 1 - done[indice].astype(int))
         indice[indice == self._size] = 0
         key = 'obs'
     val = self._meta.__dict__[key]
     try:
         if stack_num == 1:
             return val[indice]
         stack = []
         for _ in range(stack_num):
             stack = [val[indice]] + stack
             # Step one position back.
             pre_indice = np.asarray(indice - 1)
             # Stepping back past index 0 wraps around to the last entry.
             pre_indice[pre_indice == -1] = self._size - 1
             # If the previous position is terminal, undo the step so the
             # stack does not cross that boundary.
             indice = np.asarray(
                 pre_indice + done[pre_indice].astype(int))
             # Undoing the step may run past the end: wrap to the start.
             indice[indice == self._size] = 0
         if isinstance(val, Batch):
             stack = Batch.stack(stack, axis=indice.ndim)
         else:
             stack = np.stack(stack, axis=indice.ndim)
         return stack
     except IndexError as e:
         if not (isinstance(val, Batch) and val.is_empty()):
             raise e  # val != Batch()
         return Batch()
Exemple #7
0
    def get(
        self,
        indice: Union[slice, int, np.integer, np.ndarray],
        key: str,
        stack_num: Optional[int] = None,
    ) -> Union[Batch, np.ndarray]:
        """Return the stacked result.

        E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
        indice. The stack_num (here equals to 4) is given from buffer
        initialization procedure.

        If ``key`` is ``"obs_next"`` and ``self._save_s_`` is false, the
        value is read from ``"obs"`` at the following position instead (or
        at the same position where ``done`` is set). An IndexError is
        swallowed, returning an empty ``Batch``, only when the stored value
        is itself an empty ``Batch``; otherwise it is re-raised.
        """
        if stack_num is None:
            stack_num = self.stack_num
        if stack_num == 1:  # the most often case
            if key != "obs_next" or self._save_s_:
                val = self._meta.__dict__[key]
                try:
                    return val[indice]
                except IndexError as e:
                    if not (isinstance(val, Batch) and val.is_empty()):
                        raise e  # val != Batch()
                    return Batch()
        # Resolve the request against the filled part of the buffer.
        indice = self._indices[:self._size][indice]
        done = self._meta.__dict__["done"]
        if key == "obs_next" and not self._save_s_:
            # Build a fresh array rather than using ``+=``: for a slice
            # request ``indice`` may be a view of self._indices (assuming
            # that is a NumPy array), which an in-place update would modify;
            # for an int request it is a NumPy scalar, which does not
            # support the masked assignment below.
            indice = np.asarray(indice + 1 - done[indice].astype(int))
            indice[indice == self._size] = 0
            key = "obs"
        val = self._meta.__dict__[key]
        try:
            if stack_num == 1:
                return val[indice]
            stack: List[Any] = []
            for _ in range(stack_num):
                stack = [val[indice]] + stack
                # Step back one position, wrapping from 0 to the last entry.
                pre_indice = np.asarray(indice - 1)
                pre_indice[pre_indice == -1] = self._size - 1
                # Where the previous position is terminal, undo the step so
                # the stack does not cross that boundary; wrap if needed.
                indice = np.asarray(
                    pre_indice + done[pre_indice].astype(int))
                indice[indice == self._size] = 0
            if isinstance(val, Batch):
                return Batch.stack(stack, axis=indice.ndim)
            else:
                return np.stack(stack, axis=indice.ndim)
        except IndexError as e:
            if not (isinstance(val, Batch) and val.is_empty()):
                raise e  # val != Batch()
            return Batch()
Exemple #8
0
    def get(
        self,
        index: Union[int, List[int], np.ndarray],
        key: str,
        default_value: Any = None,
        stack_num: Optional[int] = None,
    ) -> Union[Batch, np.ndarray]:
        """Fetch ``key`` at ``index``, stacked over the preceding steps.

        For example ``key = "obs", stack_num = 4, index = t`` yields
        ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``.

        :param index: the index for getting stacked data.
        :param str key: the key to get, should be one of the reserved_keys.
        :param default_value: returned as-is when ``key`` is absent from the
            stored data and this value is not None.
        :param int stack_num: Default to self.stack_num.
        """
        missing = key not in self._meta
        if missing and default_value is not None:
            return default_value
        data = self._meta[key]
        depth = self.stack_num if stack_num is None else stack_num
        try:
            if depth == 1:  # fast path, by far the most common request
                return data[index]
            frames: List[Any] = []
            # A plain list is converted so it can be used for array indexing.
            cursor = np.array(index) if isinstance(index, list) else index
            for _ in range(depth):
                # Walk backwards in time, keeping the oldest frame first.
                frames.insert(0, data[cursor])
                cursor = self.prev(cursor)
            stack_fn = Batch.stack if isinstance(data, Batch) else np.stack
            return stack_fn(frames, axis=cursor.ndim)  # type: ignore
        except IndexError as e:
            if isinstance(data, Batch) and data.is_empty():
                return Batch()
            raise e  # data is not an empty Batch(): a genuine error
Exemple #9
0
def test_batch_cat_and_stack():
    """Exercise ``Batch.cat``/``cat_`` and ``Batch.stack``/``stack_``.

    Sections, in order: cat with compatible keys, cat with length inference
    for empty sub-batches, cat with incompatible and reserved (empty
    ``Batch()``) keys, stack with compatible and incompatible keys, stack
    with empty batches, illegal input formats, and the error cases.
    """
    # test cat with compatible keys
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # In-place and out-of-place concatenation must agree on every leaf.
    # (The original asserted ``a.d.e`` twice and never compared ``a.b``.)
    assert np.all(b12_cat_in.a.b == b12_cat_out.a.b)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)

    # test cat with lens infer
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    assert np.allclose(ans.a.a,
                       np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    # stack_ works in place and returns None
    assert b1.stack_([b2]) is None
    assert isinstance(b1.a.d.e, np.ndarray)
    assert b1.a.d.e.ndim == 2

    # test cat with incompatible keys
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with reserved keys (values are Batch())
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with all reserved keys (values are Batch())
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=Batch(),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert ans.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test stack with compatible keys
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0

    # test stack with incompatible keys
    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])

    # test stack with empty Batch()
    assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
    a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
    b = Batch(a=4, b=5, d=6, e=Batch())
    c = Batch(c=7, b=6, d=9, e=Batch())
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])
    assert d.e.is_empty()
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2], axis=-1)
    assert test.a.is_empty()
    assert test.b.is_empty()
    assert np.allclose(test.common.c,
                       np.stack([b1.common.c, b2.common.c], axis=-1))

    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2])
    ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
                b=torch.stack([torch.zeros(4, 6), b2.b]),
                common=Batch(c=np.stack([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test with illegal input format
    with pytest.raises(ValueError):
        Batch.cat([[Batch(a=1)], [Batch(a=1)]])
    with pytest.raises(ValueError):
        Batch.stack([[Batch(a=1)], [Batch(a=1)]])

    # exceptions
    assert Batch.cat([]).is_empty()
    assert Batch.stack([]).is_empty()
    b1 = Batch(e=[4, 5], d=6)
    b2 = Batch(e=[4, 6])
    with pytest.raises(ValueError):
        Batch.cat([b1, b2])
    with pytest.raises(ValueError):
        Batch.stack([b1, b2], axis=1)
Exemple #10
0
def test_batch():
    """Broad smoke test of ``Batch``.

    Covers emptiness checks, ``update``/``pop``, constructor type errors,
    attribute and item access, ``split``, indexing and slicing, arithmetic,
    item assignment, automatic dtype conversion, ``None`` values and
    ``networkx`` graphs as stored objects.
    """
    # --- emptiness ---
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()

    # --- object dtype for non-numeric content ---
    b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
    assert b.c.dtype == object
    b = Batch(d=[None], e=[starmap], f=Batch)
    assert b.d.dtype == b.e.dtype == object and b.f == Batch

    # --- update / pop ---
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert 'a' in b and b.a == 3
    assert b.pop('a') == 3
    assert 'a' not in b

    # --- constructor validation ---
    with pytest.raises(AssertionError):
        Batch({1: 2})
    assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])

    # --- tensors and in-place cat ---
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    batch.cat_(batch)
    assert torch.allclose(batch.a, torch.ones(4, 3))
    Batch(a=[])
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)

    # --- split ---
    batch.obs = np.arange(5)
    for pos, piece in enumerate(batch.split(1, shuffle=False)):
        if pos != 5:
            assert piece.obs == batch[pos].obs
        else:
            with pytest.raises(AttributeError):
                batch[pos].obs
            with pytest.raises(AttributeError):
                piece.obs
    print(batch)
    batch = Batch(a=np.arange(10))
    with pytest.raises(AssertionError):
        list(batch.split(0))
    # (chunk size, merge_last, expected chunks)
    split_cases = [
        (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]),
        (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]),
        (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]),
        (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
        (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]),
        (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
        (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
    ]
    for size, merge_last, expected in split_cases:
        chunks = list(batch.split(size, shuffle=False, merge_last=merge_last))
        assert [chunk.a.tolist() for chunk in chunks] == expected

    # --- leaf types survive indexing ---
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']

    # --- shape, len, indexing errors ---
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))}])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert Batch(a={1, 2, 1}).shape == []
    assert batch2.shape[0] == 1
    assert 'a' in batch2 and all(name in batch2.a for name in 'bcd')
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)

    # --- rebuilding from iteration, slicing ---
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e

    # --- arithmetic ---
    batch2.a.d.f = {}
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    assert batch2_sum.a.d.f.is_empty()
    with pytest.raises(TypeError):
        batch2 += [1]

    # --- item assignment ---
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(ValueError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    with pytest.raises(ValueError):
        batch3[0] = Batch(a={"c": 2, "e": 1})

    # auto convert
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])

    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None

    # nx.Graph corner case
    empty_graphs = np.array([nx.Graph(), nx.Graph()], dtype=object)
    assert Batch(a=empty_graphs).a.dtype == object
    g1 = nx.Graph()
    g1.add_nodes_from(list(range(10)))
    g2 = nx.Graph()
    g2.add_nodes_from(list(range(20)))
    assert Batch(a=np.array([g1, g2])).a.dtype == object
def test_stack(data):
    """Run out-of-place and in-place stacking 10000 times on fixture data.

    ``data['batch0']`` is stacked with itself, and is also stacked in place
    onto each of the first 10000 entries of ``data['batchs2']``.
    """
    for idx in range(10000):
        base = data['batch0']
        Batch.stack((base, base))
        data['batchs2'][idx].stack_([data['batch0']])
Exemple #12
0
def test_multibuf_hdf5():
    """Round-trip multi-buffer replay buffers through HDF5 files.

    A ``ReplayBufferManager`` and a ``CachedReplayBuffer`` are filled, saved
    with ``save_hdf5``, reloaded with ``load_hdf5`` and compared. The
    original buffers are then checked to still share memory with their
    sub-buffers and to accept further data.
    """
    size = 100
    buffers = {
        "vector": ReplayBufferManager([ReplayBuffer(size) for _ in range(4)]),
        "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
    }
    buffer_types = {name: buf.__class__ for name, buf in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    info_t = torch.tensor([1.]).to(device)
    for i in range(4):
        transition = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': np.array([1, 2]),
            'done': i % 3 == 2,
            'info': {
                "number": {
                    "n": i,
                    "t": info_t
                },
                'extra': None
            },
        }
        buffers["vector"].add(
            **Batch.stack([transition, transition, transition]),
            buffer_ids=[0, 1, 2])
        buffers["cached"].add(
            **Batch.stack([transition, transition, transition]),
            cached_buffer_ids=[0, 1, 2])

    # save
    paths = {}
    for name, buf in buffers.items():
        # mkstemp returns an open descriptor; close it before writing.
        fd, path = tempfile.mkstemp(suffix='.hdf5')
        os.close(fd)
        buf.save_hdf5(path)
        paths[name] = path

    # load replay buffer
    loaded = {name: buffer_types[name].load_hdf5(paths[name])
              for name in paths}

    # compare
    for name in buffers:
        restored, original = loaded[name], buffers[name]
        assert len(restored) == len(original)
        assert np.allclose(restored.act, original.act)
        assert restored.stack_num == original.stack_num
        assert restored.maxsize == original.maxsize
        assert np.all(restored._indices == original._indices)
    # check shallow copy in ReplayBufferManager
    for name in ["vector", "cached"]:
        buffers[name].info.number.n[0] = -100
        assert buffers[name].buffers[0].info.number.n[0] == -100
    # check if still behave normally
    for name in ["vector", "cached"]:
        transition = {
            'obs': Batch(index=np.array([5])),
            'act': 5,
            'rew': np.array([2, 1]),
            'done': False,
            'info': {
                "number": {
                    # ``i`` is the value left over from the fill loop above
                    "n": i
                },
                'Timelimit.truncate': True
            },
        }
        buffers[name].add(
            **Batch.stack([transition, transition, transition, transition]))
        act = np.zeros(buffers[name].maxsize)
        head5 = np.arange(5)
        head3 = np.arange(3)
        if name == "vector":
            act[head5] = np.array([0, 1, 2, 3, 5])
            act[head5 + size] = np.array([0, 1, 2, 3, 5])
            act[head5 + size * 2] = np.array([0, 1, 2, 3, 5])
            act[size * 3] = 5
        elif name == "cached":
            act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
            act[head3 + size] = np.array([3, 5, 2])
            act[head3 + size * 2] = np.array([3, 5, 2])
            act[head3 + size * 3] = np.array([3, 5, 2])
            act[size * 4] = 5
        assert np.allclose(buffers[name].act, act)

    # clean up the temporary files
    for path in paths.values():
        os.remove(path)
Exemple #13
0
def test_batch():
    """Broad smoke test of ``Batch``.

    Covers emptiness checks, ``update``, constructor type errors, attribute
    and item access, ``split``, indexing and slicing, arithmetic, item
    assignment, automatic dtype conversion and ``None`` values.
    """
    assert list(Batch()) == []
    assert Batch().is_empty()
    assert not Batch(b={'c': {}}).is_empty()
    assert Batch(b={'c': {}}).is_empty(recurse=True)
    assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
    assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
    assert not Batch(d=1).is_empty()
    assert not Batch(a=np.float64(1.0)).is_empty()
    assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
    assert not Batch(a=[1, 2, 3]).is_empty()
    b = Batch()
    b.update()
    assert b.is_empty()
    b.update(c=[3, 5])
    assert np.allclose(b.c, [3, 5])
    # mimic the behavior of dict.update, where kwargs can overwrite keys
    b.update({'a': 2}, a=3)
    assert b.a == 3
    # constructor validation
    with pytest.raises(AssertionError):
        Batch({1: 2})
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
    with pytest.raises(TypeError):
        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
    batch = Batch(a=[torch.ones(3), torch.ones(3)])
    assert torch.allclose(batch.a, torch.ones(2, 3))
    Batch(a=[])
    batch = Batch(obs=[0], np=np.zeros([3, 4]))
    assert batch.obs == batch["obs"]
    batch.obs = [1]
    assert batch.obs == [1]
    batch.cat_(batch)
    assert np.allclose(batch.obs, [1, 1])
    assert batch.np.shape == (6, 4)
    assert np.allclose(batch[0].obs, batch[1].obs)
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1, shuffle=False)):
        if i != 5:
            assert b.obs == batch[i].obs
        else:
            with pytest.raises(AttributeError):
                batch[i].obs
            with pytest.raises(AttributeError):
                b.obs
    print(batch)
    # leaf types survive indexing
    batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])}
    batch_item = Batch({'a': [batch_dict]})[0]
    assert isinstance(batch_item.a.b, np.ndarray)
    assert batch_item.a.b == batch_dict['b']
    assert isinstance(batch_item.a.c, float)
    assert batch_item.a.c == batch_dict['c']
    assert isinstance(batch_item.a.d, torch.Tensor)
    assert batch_item.a.d == batch_dict['d']
    batch2 = Batch(a=[{
        'b': np.float64(1.0),
        'c': np.zeros(1),
        'd': Batch(e=np.array(3.0))
    }])
    assert len(batch2) == 1
    assert Batch().shape == []
    assert Batch(a=1).shape == []
    assert batch2.shape[0] == 1
    with pytest.raises(IndexError):
        batch2[-2]
    with pytest.raises(IndexError):
        batch2[1]
    assert batch2[0].shape == []
    with pytest.raises(IndexError):
        batch2[0][0]
    with pytest.raises(TypeError):
        len(batch2[0])
    assert isinstance(batch2[0].a.c, np.ndarray)
    assert isinstance(batch2[0].a.b, np.float64)
    assert isinstance(batch2[0].a.d.e, np.float64)
    batch2_from_list = Batch(list(batch2))
    batch2_from_comp = Batch([e for e in batch2])
    assert batch2_from_list.a.b == batch2.a.b
    assert batch2_from_list.a.c == batch2.a.c
    assert batch2_from_list.a.d.e == batch2.a.d.e
    assert batch2_from_comp.a.b == batch2.a.b
    assert batch2_from_comp.a.c == batch2.a.c
    assert batch2_from_comp.a.d.e == batch2.a.d.e
    for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
        assert batch_slice.a.b == batch2.a.b
        assert batch_slice.a.c == batch2.a.c
        assert batch_slice.a.d.e == batch2.a.d.e
    batch2_sum = (batch2 + 1.0) * 2
    assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
    assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
    assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
    batch3 = Batch(a={
        'c': np.zeros(1),
        'd': Batch(e=np.array([0.0]), f=np.array([3.0]))
    })
    batch3.a.d[0] = {'e': 4.0}
    assert batch3.a.d.e[0] == 4.0
    batch3.a.d[0] = Batch(f=5.0)
    assert batch3.a.d.f[0] == 5.0
    with pytest.raises(KeyError):
        batch3.a.d[0] = Batch(f=5.0, g=0.0)
    # auto convert
    # ``np.object`` was only an alias of the builtin ``object`` and has been
    # removed from NumPy (1.24); comparing against the builtin is equivalent.
    batch4 = Batch(a=np.array(['a', 'b']))
    assert batch4.a.dtype == object  # auto convert to object dtype
    batch4.update(a=np.array(['c', 'd']))
    assert list(batch4.a) == ['c', 'd']
    assert batch4.a.dtype == object  # auto convert to object dtype
    batch5 = Batch(a=np.array([{'index': 0}]))
    assert isinstance(batch5.a, Batch)
    assert np.allclose(batch5.a.index, [0])
    batch5.b = np.array([{'index': 1}])
    assert isinstance(batch5.b, Batch)
    assert np.allclose(batch5.b.index, [1])

    # None is a valid object and can be stored in Batch
    a = Batch.stack([Batch(a=None), Batch(b=None)])
    assert a.a[0] is None and a.a[1] is None
    assert a.b[0] is None and a.b[1] is None
Exemple #14
0
def test_replaybuffer(size=10, bufsize=20):
    """Exercise basic ReplayBuffer behavior end to end.

    Covers, in order: filling a buffer from an environment rollout,
    sampling, adding single transitions with nested ``info`` dicts,
    keys that first appear in a later ``add``, batch-style adding with
    ``buffer_ids``, out-of-range indexing, and ``prev`` / ``next``
    index lookups.

    :param size: passed to ``MyTestEnv``; sampled observations are
        asserted to be strictly smaller than this value.
    :param bufsize: passed to ``ReplayBuffer``; ``len(buf)`` is asserted
        to cap at this value.
    """
    env = MyTestEnv(size)
    buf = ReplayBuffer(bufsize)
    # nothing has been added yet: updating the buffer with itself must
    # leave its string form as the bare "ClassName()"
    buf.update(buf)
    assert str(buf) == buf.__class__.__name__ + '()'
    obs = env.reset()
    # 25 actions in total, i.e. more than the default bufsize of 20, so
    # the length check below is also exercised once the buffer is full
    action_list = [1] * 5 + [0] * 10 + [1] * 10
    for i, a in enumerate(action_list):
        obs_next, rew, done, info = env.step(a)
        # act is given as a one-element list; the shape assertion after
        # the loop expects a trailing dimension of 1
        buf.add(
            Batch(obs=obs,
                  act=[a],
                  rew=rew,
                  done=done,
                  obs_next=obs_next,
                  info=info))
        obs = obs_next
        assert len(buf) == min(bufsize, i + 1)
    assert buf.act.dtype == int
    assert buf.act.shape == (bufsize, 1)
    # request twice as many samples as bufsize; every returned index
    # must still point inside the buffer
    data, indices = buf.sample(bufsize * 2)
    assert (indices < len(buf)).all()
    assert (data.obs < size).all()
    assert (0 <= data.done).all() and (data.done <= 1).all()
    b = ReplayBuffer(size=10)
    # a negative batch size should return an empty index array
    assert b.sample_indices(-1).tolist() == []
    # add() returns four values; each is checked below to have shape (1,)
    ptr, ep_rew, ep_len, ep_idx = b.add(
        Batch(obs=1,
              act=1,
              rew=1,
              done=1,
              obs_next='str',
              info={
                  'a': 3,
                  'b': {
                      'c': 5.0
                  }
              }))
    assert b.obs[0] == 1
    assert b.done[0]
    assert b.obs_next[0] == 'str'
    # slots that were never written: numeric fields are expected to be 0,
    # the string-valued obs_next field is expected to compare equal to None
    assert np.all(b.obs[1:] == 0)
    assert np.all(b.obs_next[1:] == np.array(None))
    assert b.info.a[0] == 3 and b.info.a.dtype == int
    assert np.all(b.info.a[1:] == 0)
    assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
    assert np.all(b.info.b.c[1:] == 0.0)
    assert ptr.shape == (1, ) and ptr[0] == 0
    assert ep_rew.shape == (1, ) and ep_rew[0] == 1
    assert ep_len.shape == (1, ) and ep_len[0] == 1
    assert ep_idx.shape == (1, ) and ep_idx[0] == 0
    # new keys may show up in later transitions and the buffer should
    # handle them dynamically: this info omits "b" and introduces "d"
    batch = Batch(obs=2,
                  act=2,
                  rew=2,
                  done=0,
                  obs_next="str2",
                  info={
                      "a": 4,
                      "d": {
                          "e": -np.inf
                      }
                  })
    b.add(batch)
    info_keys = ["a", "b", "d"]
    assert set(b.info.keys()) == set(info_keys)
    # the omitted key "b" is expected to read as 0 at the new index
    assert b.info.a[1] == 4 and b.info.b.c[1] == 0
    assert b.info.d.e[1] == -np.inf
    # test the batch-style adding method, where len(batch) == 1:
    # the same transition is stacked into a length-1 batch, marked done,
    # and added with an explicit buffer id
    batch.done = [1]
    batch.info.e = np.zeros([1, 4])
    batch = Batch.stack([batch])
    ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
    # expected: written at index 2, closing a two-step episode that began
    # at index 1 (rew 2 + 2 == 4)
    assert ptr.shape == (1, ) and ptr[0] == 2
    assert ep_rew.shape == (1, ) and ep_rew[0] == 4
    assert ep_len.shape == (1, ) and ep_len[0] == 2
    assert ep_idx.shape == (1, ) and ep_idx[0] == 1
    assert set(b.info.keys()) == set(info_keys + ["e"])
    assert b.info.e.shape == (b.maxsize, 1, 4)
    # b was created with size=10, so index 22 must raise IndexError
    with pytest.raises(IndexError):
        b[22]
    # test prev / next: at this point index 0 holds a done step and
    # index 2 holds a done step, with index 1 not done in between
    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
    # add one more step with done=0; index 3 is then expected to map to
    # itself for both prev and next
    batch.done = [0]
    b.add(batch, buffer_ids=[0])
    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])