Example #1
    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.

        :param int batch_size: ``0`` means extracting all the data from the
            buffer; otherwise, a batch of the given size is sampled.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(len(self.buffer),
                                               batch_size,
                                               p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if (batch_size and cur_batch) or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.cat(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
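
A note on the multi-buffer branch: each sub-buffer is sampled in proportion to its length. Below is a minimal standalone sketch of that weighting logic using plain NumPy and hypothetical buffer lengths (the names lens, buffer_index, and per_buffer are illustrative, not Tianshou API):

import numpy as np

# Hypothetical sub-buffer lengths; in the method above these come from
# [len(b) for b in self.buffer].
lens = np.array([100, 300, 600])
batch_size = 64

# For each requested sample, draw the index of the sub-buffer it should
# come from, weighted by how much data each sub-buffer holds.
buffer_index = np.random.choice(len(lens), batch_size, p=lens / lens.sum())

# Count how many samples each sub-buffer must provide, mirroring
# (batch_index == i).sum() in the loop above.
per_buffer = [(buffer_index == i).sum() for i in range(len(lens))]
print(per_buffer)  # roughly proportional to [100, 300, 600]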
Example #2
def test_batch_cat_and_stack():
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat((b1, b2))
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0
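
The ndim assertions follow the usual NumPy convention: concatenation joins along an existing batch axis, while stacking adds a new one. A NumPy-only analogue of why cat yields a 1-dim `e` and stack a 2-dim `e` (a sketch independent of the Batch class):

import numpy as np

# Each source batch contributes one scalar-like entry for key `e`.
e1 = np.array([3.0])
e2 = np.array([6.0])

print(np.concatenate([e1, e2]).shape)  # (2,)   -> ndim 1, as in b12_cat
print(np.stack([e1, e2]).shape)        # (2, 1) -> ndim 2, as in b12_stack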
Example #3
def test_async_env(size=10000, num=8, sleep=0.1):
    # simplify the test case, just keep stepping
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        v.seed(None)
        v.reset()
        # for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
        # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
        # expectation of v is n / (n + 1)
        # for a synchronous environment, the following actions should take
        # about 7 * sleep * num / (num + 1) seconds
        # for async simulation, the analysis is complicated, but the time cost
        # should be smaller
        action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
        current_idx_start = 0
        act = action_list[:num]
        env_ids = list(range(num))
        o = []
        spent_time = time.time()
        while current_idx_start < len(action_list):
            A, B, C, D = v.step(action=act, id=env_ids)
            b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
            env_ids = b.info.env_id
            o.append(b)
            current_idx_start += len(act)
            # the action list may be shorter than len(A) at the very end
            act = action_list[current_idx_start:current_idx_start + len(A)]
            # keep only the first len(act) entries of env_ids; typically
            # len(env_ids) == len(A) == len(act), except for the last batch
            # when there are not enough actions left
            env_ids = env_ids[:len(act)]
        spent_time = time.time() - spent_time
        Batch.cat(o)
        v.close()
        # ensure at least a 1/7 improvement over the synchronous estimate
        if sys.platform == "linux" and cls != RayVectorEnv:
            # macOS/Windows cannot pass this check
            assert spent_time < 6.0 * sleep * num / (num + 1)
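
The timing bound in the comments uses the fact that the expectation of the maximum of num i.i.d. U[0, 1] variables is num / (num + 1). A quick NumPy check of that fact, independent of the environment classes above:

import numpy as np

n, trials = 8, 100_000
u = np.random.rand(trials, n)   # n independent U[0, 1] draws per trial
v = u.max(axis=1)               # v = max{u_1, ..., u_n}

# P(v <= x) = x**n, so the density is n * x**(n - 1) and
# E[v] = integral_0^1 x * n * x**(n - 1) dx = n / (n + 1)
print(v.mean(), n / (n + 1))    # both close to 0.888...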
Example #4
def test_batch_cat_and_stack_and_empty():
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat((b1, b2))
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2
    b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])
    data = Batch(a=[False, True],
                 b={
                     'c': [2., 'st'],
                     'd': [1, None],
                     'e': [2., float('nan')]
                 },
                 c=np.array([1, 3, 4], dtype=int),
                 t=torch.tensor([4, 5, 6, 7.]))
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == ['2.0', '']
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []
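
The assertions after Batch.empty reflect a "reset to a zero-like value" rule: numeric entries become 0, object entries become None, and string entries become the empty string. A small per-dtype sketch of those expected values (an illustration only, not Tianshou's implementation; zero_like_value is a hypothetical helper):

import numpy as np

def zero_like_value(arr: np.ndarray):
    """Return the reset value implied by the asserts above for one dtype."""
    if arr.dtype == object:
        return None                           # data.b.d -> [1, None]
    if arr.dtype.kind in "US":
        return ""                             # data.b.c -> ['2.0', '']
    return np.zeros((), dtype=arr.dtype)      # numbers/bools -> 0 / 0.0 / False

print(zero_like_value(np.array([2.0, float("nan")])))      # 0.0
print(zero_like_value(np.array(["2.0", "st"])))            # ''
print(zero_like_value(np.array([1, None], dtype=object)))  # None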
Example #5
def test_batch_cat_and_stack():
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b_cat_out = Batch.cat((b1, b2))
    b_cat_in = copy.deepcopy(b1)
    b_cat_in.cat_(b2)
    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
    assert np.all(b_cat_in.a.d.e == b_cat_out.a.d.e)
    assert isinstance(b_cat_in.a.d.e, np.ndarray)
    assert b_cat_in.a.d.e.ndim == 1
    b_stack = Batch.stack((b1, b2))
    assert isinstance(b_stack.a.d.e, np.ndarray)
    assert b_stack.a.d.e.ndim == 2
Example #6
    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Dispatch batch data from obs.agent_id to every policy's forward.

        :param state: if None, all agents are assumed to have no state.
            Otherwise, it should contain the keys "agent_1", "agent_2", ...

        :return: a Batch with the following contents:

        ::

            {
                "act": actions corresponding to the input
                "state": {
                    "agent_1": output state of agent_1's policy for the state
                    "agent_2": xxx
                    ...
                    "agent_n": xxx}
                "out": {
                    "agent_1": output of agent_1's policy for the input
                    "agent_2": xxx
                    ...
                    "agent_n": xxx}
            }
        """
        results = []
        for policy in self.policies:
            # This part of the code is easiest to follow with a two-agent
            # example:
            # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)
            # each agent plays for three transitions
            # agent_index for agent 1 is [0, 2, 4]
            # agent_index for agent 2 is [1, 3, 5]
            # we separate the transitions of each agent according to agent_id
            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
            if len(agent_index) == 0:
                # (has_data, agent_index, out, act, state)
                results.append((False, None, Batch(), None, Batch()))
                continue
            tmp_batch = batch[agent_index]
            if isinstance(tmp_batch.rew, np.ndarray):
                # the reward can be an empty Batch (after the initial reset)
                # or an np.ndarray
                tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
            out = policy(batch=tmp_batch, state=None if state is None
                         else state["agent_" + str(policy.agent_id)],
                         **kwargs)
            act = out.act
            each_state = out.state \
                if (hasattr(out, "state") and out.state is not None) \
                else Batch()
            results.append((True, agent_index, out, act, each_state))
        holder = Batch.cat([{"act": act} for
                            (has_data, agent_index, out, act, each_state)
                            in results if has_data])
        state_dict, out_dict = {}, {}
        for policy, (has_data, agent_index, out, act, state) in zip(
                self.policies, results):
            if has_data:
                holder.act[agent_index] = act
            state_dict["agent_" + str(policy.agent_id)] = state
            out_dict["agent_" + str(policy.agent_id)] = out
        holder["out"] = out_dict
        holder["state"] = state_dict
        return holder
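
The per-agent dispatch boils down to selecting the transitions whose agent_id matches the current policy. A standalone NumPy sketch of that index split, using the two-agent layout from the comment above:

import numpy as np

# agent_id of each transition in an interleaved two-agent rollout
# (the same layout as the comment above, batch_size == 6)
agent_id = np.array([1, 2, 1, 2, 1, 2])

for aid in (1, 2):
    agent_index = np.nonzero(agent_id == aid)[0]
    print(aid, agent_index)
# 1 [0 2 4]
# 2 [1 3 5]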
Example #7
def test_batch_cat_and_stack():
    # test cat with compatible keys
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)

    # test cat with lens infer
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    assert np.allclose(ans.a.a,
                       np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    assert b1.stack_([b2]) is None
    assert isinstance(b1.a.d.e, np.ndarray)
    assert b1.a.d.e.ndim == 2

    # test cat with incompatible keys
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with reserved keys (values are Batch())
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with all reserved keys (values are Batch())
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=Batch(),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert ans.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test stack with compatible keys
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0

    # test stack with incompatible keys
    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])

    # test stack with empty Batch()
    assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
    a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
    b = Batch(a=4, b=5, d=6, e=Batch())
    c = Batch(c=7, b=6, d=9, e=Batch())
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])
    assert d.e.is_empty()
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2], axis=-1)
    assert test.a.is_empty()
    assert test.b.is_empty()
    assert np.allclose(test.common.c,
                       np.stack([b1.common.c, b2.common.c], axis=-1))

    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2])
    ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
                b=torch.stack([torch.zeros(4, 6), b2.b]),
                common=Batch(c=np.stack([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test with illegal input format
    with pytest.raises(ValueError):
        Batch.cat([[Batch(a=1)], [Batch(a=1)]])
    with pytest.raises(ValueError):
        Batch.stack([[Batch(a=1)], [Batch(a=1)]])

    # exceptions
    assert Batch.cat([]).is_empty()
    assert Batch.stack([]).is_empty()
    b1 = Batch(e=[4, 5], d=6)
    b2 = Batch(e=[4, 6])
    with pytest.raises(ValueError):
        Batch.cat([b1, b2])
    with pytest.raises(ValueError):
        Batch.stack([b1, b2], axis=1)
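
For keys missing from one input, the tests above expect cat/stack to pad with zeros of the matching shape. A NumPy-level sketch of the padding the `ans` batches are built from:

import numpy as np

# b1 has key 'a' with 3 rows; b2 (4 rows) lacks it, so b2's share of 'a'
# is filled with zeros before concatenation -- exactly the `ans.a` above.
a_b1 = np.random.rand(3, 4)
padded_a = np.concatenate([a_b1, np.zeros((4, 4))])
print(padded_a.shape)  # (7, 4) == len(b1) + len(b2)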
Example #8
def test_cat(data):
    """Test cat"""
    for i in range(10000):
        Batch.cat((data['batch0'], data['batch0']))
        data['batchs1'][i].cat_(data['batch0'])
Example #9
def test_multibuf_hdf5():
    size = 100
    buffers = {
        "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]),
        "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    info_t = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': np.array([1, 2]),
            'done': i % 3 == 2,
            'info': {"number": {"n": i, "t": info_t}, 'extra': None},
        }
        buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
                              buffer_ids=[0, 1, 2])
        buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
                              cached_buffer_ids=[0, 1, 2])

    # save
    paths = {}
    for k, buf in buffers.items():
        f, path = tempfile.mkstemp(suffix='.hdf5')
        os.close(f)
        buf.save_hdf5(path)
        paths[k] = path

    # load replay buffer
    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}

    # compare
    for k in buffers.keys():
        assert len(_buffers[k]) == len(buffers[k])
        assert np.allclose(_buffers[k].act, buffers[k].act)
        assert _buffers[k].stack_num == buffers[k].stack_num
        assert _buffers[k].maxsize == buffers[k].maxsize
        assert np.all(_buffers[k]._indices == buffers[k]._indices)
    # check shallow copy in ReplayBufferManager
    for k in ["vector", "cached"]:
        buffers[k].info.number.n[0] = -100
        assert buffers[k].buffers[0].info.number.n[0] == -100
    # check if still behave normally
    for k in ["vector", "cached"]:
        kwargs = {
            'obs': Batch(index=np.array([5])),
            'act': 5,
            'rew': np.array([2, 1]),
            'done': False,
            'info': {"number": {"n": i}, 'Timelimit.truncate': True},
        }
        buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]]))
        act = np.zeros(buffers[k].maxsize)
        if k == "vector":
            act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
            act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
            act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
            act[size * 3] = 5
        elif k == "cached":
            act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
            act[np.arange(3) + size] = np.array([3, 5, 2])
            act[np.arange(3) + size * 2] = np.array([3, 5, 2])
            act[np.arange(3) + size * 3] = np.array([3, 5, 2])
            act[size * 4] = 5
        assert np.allclose(buffers[k].act, act)

    for path in paths.values():
        os.remove(path)
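
A minimal HDF5 round trip built only from the calls exercised above (save_hdf5 and the classmethod load_hdf5). It assumes a plain ReplayBuffer accepts the same kwargs-style add used in this test; treat it as a sketch rather than a drop-in snippet:

import os
import tempfile

import numpy as np
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(20)
for i in range(5):
    # kwargs-style add, mirroring the buffers[k].add(**...) calls above
    buf.add(obs=Batch(index=np.array([i])), act=i, rew=1.0,
            done=i % 3 == 2, info={})

fd, path = tempfile.mkstemp(suffix=".hdf5")
os.close(fd)
buf.save_hdf5(path)

restored = ReplayBuffer.load_hdf5(path)
assert len(restored) == len(buf)
assert np.allclose(restored.act, buf.act)
os.remove(path)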