def sample(self, batch_size: int) -> Batch:
    """Sample a data batch from the internal replay buffer.

    Every sampled batch is passed through
    :meth:`~tianshou.policy.BasePolicy.process_fn` before being returned.

    :param int batch_size: ``0`` means it will extract all the data from the
        buffer, otherwise it will extract the data with the given
        batch_size.

    :return: the processed batch. With multiple buffers, the processed
        per-buffer batches are concatenated in buffer order.
    """
    if not self._multi_buf:
        batch_data, indice = self.buffer.sample(batch_size)
        return self.process_fn(batch_data, self.buffer, indice)

    if batch_size > 0:
        # Choose a source buffer for every requested sample, weighting each
        # buffer by the number of items it currently holds.
        lens = [len(b) for b in self.buffer]
        total = sum(lens)
        batch_index = np.random.choice(
            len(self.buffer), batch_size, p=np.array(lens) / total)
    else:
        # Nothing to distribute: every buffer is asked for ``0`` samples
        # below, i.e. "extract everything" per the batch_size convention.
        batch_index = np.array([])

    batch_data = Batch()
    for i, b in enumerate(self.buffer):
        cur_batch = (batch_index == i).sum()
        # For a positive batch_size, skip buffers that were not selected:
        # asking them for 0 samples would pull in their whole content.
        if batch_size and cur_batch or batch_size <= 0:
            batch, indice = b.sample(cur_batch)
            batch = self.process_fn(batch, b, indice)
            # ``cat_`` concatenates in place. ``Batch.cat`` builds and
            # returns a new Batch, so calling it here would silently
            # discard the sampled data and leave ``batch_data`` empty.
            batch_data.cat_(batch)
    return batch_data
def test_batch_cat_and_stack():
    """Check ``Batch.cat``/``cat_`` and ``Batch.stack`` on small batches."""
    first = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    second = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])

    # The out-of-place and in-place concatenations must agree.
    cat_out = Batch.cat((first, second))
    cat_in = copy.deepcopy(first)
    cat_in.cat_(second)
    assert np.all(cat_in.a.d.e == cat_out.a.d.e)
    assert isinstance(cat_in.a.d.e, np.ndarray)
    assert cat_in.a.d.e.ndim == 1

    # Stacking yields one more dimension than concatenating.
    stacked = Batch.stack((first, second))
    assert isinstance(stacked.a.d.e, np.ndarray)
    assert stacked.a.d.e.ndim == 2

    # Stack along a non-default axis, mixing numpy, torch and nested Batch.
    left = Batch(a=np.zeros((3, 4)),
                 b=torch.ones((2, 5)),
                 c=Batch(d=[[1], [2]]))
    right = Batch(a=np.ones((3, 4)),
                  b=torch.ones((2, 5)),
                  c=Batch(d=[[0], [3]]))
    stacked_axis1 = Batch.stack((left, right), axis=1)
    assert np.all(stacked_axis1.a == np.stack((left.a, right.a), axis=1))
    paired_d = [list(pair) for pair in zip(left.c.d, right.c.d)]
    assert np.all(stacked_axis1.c.d == paired_d)

    # Build a Batch from an array of dicts whose nested keys differ.
    records = np.array([
        {'a': False, 'b': {'c': 2.0, 'd': 1.0}},
        {'a': True, 'b': {'c': 3.0}},
    ])
    from_records = Batch(records)
    assert from_records.a[0] == np.array(False)
    assert from_records.a[1] == np.array(True)
    expected_c = np.stack([rec['b']['c'] for rec in records], axis=0)
    assert np.all(from_records.b.c == expected_c)
    assert from_records.b.d[0] == records[0]['b']['d']
    # The second record has no 'd'; the slot is expected to read as 0.
    assert from_records.b.d[1] == 0.0
def test_async_env(size=10000, num=8, sleep=0.1):
    """Step asynchronous vector envs and check they beat the sync estimate.

    :param size: ``size`` passed to the first ``MyTestEnv``; the i-th env
        gets ``size + i``.
    :param num: number of environments per vector env.
    :param sleep: ``sleep`` argument forwarded to every ``MyTestEnv``.
    """
    # simplify the test case, just keep stepping
    env_fns = [
        # ``i=i`` binds the current value; a plain closure would see only
        # the last ``i`` of the comprehension.
        lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
        for i in range(size, size + num)
    ]
    test_cls = [SubprocVectorEnv, ShmemVectorEnv]
    if has_ray():
        test_cls += [RayVectorEnv]
    for cls in test_cls:
        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
        # Close the vector env even when stepping or concatenation raises,
        # so a failing run does not leave it open.
        try:
            v.seed(None)
            v.reset()
            # for a random variable u ~ U[0, 1], let v = max{u1, ..., un}
            # P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
            # expectation of v is n / (n + 1)
            # for a synchronous environment, the following actions should
            # take about 7 * sleep * num / (num + 1) seconds
            # for async simulation, the analysis is complicated, but the
            # time cost should be smaller
            action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
            current_idx_start = 0
            act = action_list[:num]
            env_ids = list(range(num))
            o = []
            spent_time = time.time()
            while current_idx_start < len(action_list):
                A, B, C, D = v.step(action=act, id=env_ids)
                b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
                # The next actions go to the envs that just returned.
                env_ids = b.info.env_id
                o.append(b)
                current_idx_start += len(act)
                # len of action may be smaller than len(A) in the end
                act = action_list[
                    current_idx_start:current_idx_start + len(A)]
                # truncate env_ids with the first terms
                # typically len(env_ids) == len(A) == len(action), except
                # for the last batch when actions are not enough
                env_ids = env_ids[:len(act)]
            spent_time = time.time() - spent_time
            # All collected step results must be concatenable.
            Batch.cat(o)
        finally:
            v.close()
        # assure 1/7 improvement
        if sys.platform == "linux" and cls != RayVectorEnv:
            # macOS/Windows cannot pass this check
            assert spent_time < 6.0 * sleep * num / (num + 1)
def test_batch_cat_and_stack_and_empty():
    """Check ``Batch.cat``/``stack`` plus ``Batch.empty``/``empty_``."""
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])

    # The out-of-place and in-place concatenations must agree.
    b12_cat_out = Batch.cat((b1, b2))
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    # Stacking yields one more dimension than concatenating.
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2

    # Stack along a non-default axis, mixing numpy, torch and nested Batch.
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))

    # Build a Batch from an array of dicts whose nested keys differ.
    b5_dict = np.array([
        {'a': False, 'b': {'c': 2.0, 'd': 1.0}},
        {'a': True, 'b': {'c': 3.0}},
    ])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
    assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    # The second record has no 'd'; the slot is expected to read as 0.
    assert b5.b.d[1] == 0.0

    # Overwrite one row with the "empty" version of another row.
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])

    # Same check across bool, string, object, float, int and torch data.
    data = Batch(
        a=[False, True],
        b={
            'c': [2., 'st'],
            'd': [1, None],
            'e': [2., float('nan')],
        },
        # Builtin ``int`` instead of the ``np.int`` alias, which was removed
        # in NumPy 1.24; the resulting dtype is the same.
        c=np.array([1, 3, 4], dtype=int),
        t=torch.tensor([4, 5, 6, 7.]),
    )
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == ['2.0', '']
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))

    # Emptying a Batch that has no keys must work and keep an empty shape.
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []
def test_batch_cat_and_stack():
    """Check that ``cat``/``cat_`` agree and ``stack`` adds a dimension."""
    first = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    second = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])

    # Out-of-place concatenation versus in-place on a deep copy.
    out_of_place = Batch.cat((first, second))
    in_place = copy.deepcopy(first)
    in_place.cat_(second)
    assert np.all(in_place.a.d.e == out_of_place.a.d.e)
    assert isinstance(in_place.a.d.e, np.ndarray)
    assert in_place.a.d.e.ndim == 1

    # Stacking the same inputs gives a 2-D leaf.
    stacked = Batch.stack((first, second))
    assert isinstance(stacked.a.d.e, np.ndarray)
    assert stacked.a.d.e.ndim == 2
def forward(
    self,
    batch: Batch,
    state: Optional[Union[dict, Batch]] = None,
    **kwargs: Any,
) -> Batch:
    """Route each agent's share of ``batch`` to that agent's policy.

    :param state: if None, it means all agents have no state. If not
        None, it should contain keys of "agent_1", "agent_2", ...

    :return: a Batch with the following contents:

    ::

        {
            "act": actions corresponding to the input
            "state": {
                "agent_1": output state of agent_1's policy for the state
                "agent_2": xxx
                ...
                "agent_n": xxx}
            "out": {
                "agent_1": output of agent_1's policy for the input
                "agent_2": xxx
                ...
                "agent_n": xxx}
        }
    """
    # One tuple per policy: (has_data, agent_index, out, act, state).
    per_policy = []
    for policy in self.policies:
        # Worked example with two agents and batch_size == 6:
        #   batch.obs.agent_id == [1, 2, 1, 2, 1, 2]
        #   rows of agent 1 -> [0, 2, 4]
        #   rows of agent 2 -> [1, 3, 5]
        # i.e. the transitions are split by the agent that produced them.
        rows = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
        if len(rows) == 0:
            # This agent has no transition in the batch: keep placeholders
            # so the per-agent dicts still receive an entry for it.
            per_policy.append((False, None, Batch(), None, Batch()))
            continue

        agent_batch = batch[rows]
        # reward can be empty Batch (after initial reset) or nparray.
        if isinstance(agent_batch.rew, np.ndarray):
            # Keep only the reward column at index ``agent_id - 1``.
            agent_batch.rew = agent_batch.rew[:, policy.agent_id - 1]

        if state is None:
            agent_state = None
        else:
            agent_state = state["agent_" + str(policy.agent_id)]
        out = policy(batch=agent_batch, state=agent_state, **kwargs)
        act = out.act
        if hasattr(out, "state") and out.state is not None:
            next_state = out.state
        else:
            next_state = Batch()
        per_policy.append((True, rows, out, act, next_state))

    # Allocate an "act" holder of the combined size, then write every
    # agent's actions back to the rows they came from.
    holder = Batch.cat(
        [{"act": entry[3]} for entry in per_policy if entry[0]])
    state_dict, out_dict = {}, {}
    for policy, entry in zip(self.policies, per_policy):
        has_data, rows, out, act, next_state = entry
        agent_key = "agent_" + str(policy.agent_id)
        if has_data:
            holder.act[rows] = act
        state_dict[agent_key] = next_state
        out_dict[agent_key] = out
    holder["out"] = out_dict
    holder["state"] = state_dict
    return holder
def test_batch_cat_and_stack():
    """Exercise ``Batch.cat``/``cat_`` and ``Batch.stack``/``stack_``.

    Covers compatible keys, keys missing from some inputs, reserved keys
    (values that are an empty ``Batch()``), malformed input and inputs
    whose keys or shapes cannot be merged.
    """
    # ---- cat: compatible keys ----
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    cat_out = Batch.cat([b1, b2])
    cat_in = copy.deepcopy(b1)
    cat_in.cat_(b2)
    assert np.all(cat_in.a.d.e == cat_out.a.d.e)
    assert isinstance(cat_in.a.d.e, np.ndarray)
    assert cat_in.a.d.e.ndim == 1

    # A reserved nested key in the middle must not contribute any rows.
    nested = Batch(a=Batch(a=np.random.randn(3, 4)))
    doubled = np.concatenate([nested.a.a, nested.a.a])
    with_reserved = Batch.cat([nested, Batch(a=Batch(a=Batch())), nested])
    assert np.allclose(doubled, with_reserved.a.a)

    # ---- cat: length of a reserved key inferred from sibling keys ----
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    padded = np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])
    assert np.allclose(ans.a.a, padded)
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    # The in-place stack returns None and modifies its receiver.
    assert b1.stack_([b2]) is None
    assert isinstance(b1.a.d.e, np.ndarray)
    assert b1.a.d.e.ndim == 2

    # ---- cat: incompatible keys (missing entries are zero-filled) ----
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=np.concatenate([b1.a, np.zeros((4, 4))]),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # ---- cat: reserved keys (values are Batch()) ----
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=np.concatenate([b1.a, np.zeros((4, 4))]),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # ---- cat: all reserved keys (values are Batch()) ----
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(
        a=Batch(),
        b=torch.cat([torch.zeros(3, 3), b2.b]),
        common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
    )
    assert ans.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # ---- stack: compatible keys ----
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
    paired_d = [list(pair) for pair in zip(b3.c.d, b4.c.d)]
    assert np.all(b34_stack.c.d == paired_d)

    # Build a Batch from an array of dicts whose nested keys differ.
    b5_dict = np.array([
        {'a': False, 'b': {'c': 2.0, 'd': 1.0}},
        {'a': True, 'b': {'c': 3.0}},
    ])
    b5 = Batch(b5_dict)
    assert b5.a[0] == np.array(False)
    assert b5.a[1] == np.array(True)
    expected_c = np.stack([rec['b']['c'] for rec in b5_dict], axis=0)
    assert np.all(b5.b.c == expected_c)
    assert b5.b.d[0] == b5_dict[0]['b']['d']
    assert b5.b.d[1] == 0.0

    # ---- stack: incompatible keys (missing entries become 0) ----
    a = Batch(a=1, b=2, c=3)
    b = Batch(a=4, b=5, d=6)
    c = Batch(c=7, b=6, d=9)
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])

    # ---- stack: empty Batch() ----
    assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
    a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
    b = Batch(a=4, b=5, d=6, e=Batch())
    c = Batch(c=7, b=6, d=9, e=Batch())
    d = Batch.stack([a, b, c])
    assert np.allclose(d.a, [1, 4, 0])
    assert np.allclose(d.b, [2, 5, 6])
    assert np.allclose(d.c, [3, 0, 7])
    assert np.allclose(d.d, [0, 6, 9])
    assert d.e.is_empty()

    # Reserved keys that appear on one side only stay empty after stacking.
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2], axis=-1)
    assert test.a.is_empty()
    assert test.b.is_empty()
    expected_common = np.stack([b1.common.c, b2.common.c], axis=-1)
    assert np.allclose(test.common.c, expected_common)

    # Keys present on one side only are zero-filled on the other side.
    b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
    b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.stack([b1, b2])
    ans = Batch(
        a=np.stack([b1.a, np.zeros((4, 4))]),
        b=torch.stack([torch.zeros(4, 6), b2.b]),
        common=Batch(c=np.stack([b1.common.c, b2.common.c])),
    )
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # ---- illegal input format: lists of lists are rejected ----
    with pytest.raises(ValueError):
        Batch.cat([[Batch(a=1)], [Batch(a=1)]])
    with pytest.raises(ValueError):
        Batch.stack([[Batch(a=1)], [Batch(a=1)]])

    # ---- exceptions and degenerate input ----
    assert Batch.cat([]).is_empty()
    assert Batch.stack([]).is_empty()
    b1 = Batch(e=[4, 5], d=6)
    b2 = Batch(e=[4, 6])
    with pytest.raises(ValueError):
        Batch.cat([b1, b2])
    with pytest.raises(ValueError):
        Batch.stack([b1, b2], axis=1)
def test_cat(data):
    """Run ``Batch.cat`` and ``cat_`` 10000 times on the ``data`` fixture.

    ``data['batch0']`` is the operand of every call and
    ``data['batchs1'][i]`` is the in-place target of iteration ``i``.
    """
    operand = data['batch0']
    targets = data['batchs1']
    for idx in range(10000):
        # Out-of-place: the result is intentionally thrown away.
        Batch.cat((operand, operand))
        # In-place: each iteration extends its own target batch.
        targets[idx].cat_(operand)
def test_multibuf_hdf5():
    """Round-trip multi-buffer replay buffers through HDF5 files.

    Fills a ``ReplayBufferManager`` and a ``CachedReplayBuffer``, saves
    each to a temporary ``.hdf5`` file, reloads it, compares the reloaded
    buffer with the original, and then checks that the originals still
    accept new data.
    """
    size = 100
    buffers = {
        "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]),
        "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
    }
    buffer_types = {name: buf.__class__ for name, buf in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    info_t = torch.tensor([1.]).to(device)

    # Add four steps; each step goes to sub-buffers 0, 1 and 2.
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': np.array([1, 2]),
            'done': i % 3 == 2,
            'info': {"number": {"n": i, "t": info_t}, 'extra': None},
        }
        buffers["vector"].add(
            **Batch.cat([[kwargs], [kwargs], [kwargs]]),
            buffer_ids=[0, 1, 2])
        buffers["cached"].add(
            **Batch.cat([[kwargs], [kwargs], [kwargs]]),
            cached_buffer_ids=[0, 1, 2])

    # save
    paths = {}
    for name, buf in buffers.items():
        handle, path = tempfile.mkstemp(suffix='.hdf5')
        # Only the path is needed; release the OS-level handle right away.
        os.close(handle)
        buf.save_hdf5(path)
        paths[name] = path

    # load replay buffer
    loaded = {
        name: buffer_types[name].load_hdf5(paths[name]) for name in paths
    }

    # compare
    for name, original in buffers.items():
        restored = loaded[name]
        assert len(restored) == len(original)
        assert np.allclose(restored.act, original.act)
        assert restored.stack_num == original.stack_num
        assert restored.maxsize == original.maxsize
        assert np.all(restored._indices == original._indices)

    # check shallow copy in ReplayBufferManager
    for name in ["vector", "cached"]:
        buffers[name].info.number.n[0] = -100
        assert buffers[name].buffers[0].info.number.n[0] == -100

    # check if still behave normally
    for name in ["vector", "cached"]:
        kwargs = {
            'obs': Batch(index=np.array([5])),
            'act': 5,
            'rew': np.array([2, 1]),
            'done': False,
            # ``i`` still holds its final value from the fill loop above.
            'info': {"number": {"n": i}, 'Timelimit.truncate': True},
        }
        buffers[name].add(
            **Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]]))

        # Expected layout of the flat ``act`` storage after the extra add.
        act = np.zeros(buffers[name].maxsize)
        if name == "vector":
            first_five = np.arange(5)
            act[first_five] = np.array([0, 1, 2, 3, 5])
            act[first_five + size] = np.array([0, 1, 2, 3, 5])
            act[first_five + size * 2] = np.array([0, 1, 2, 3, 5])
            act[size * 3] = 5
        elif name == "cached":
            first_three = np.arange(3)
            act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
            act[first_three + size] = np.array([3, 5, 2])
            act[first_three + size * 2] = np.array([3, 5, 2])
            act[first_three + size * 3] = np.array([3, 5, 2])
            act[size * 4] = 5
        assert np.allclose(buffers[name].act, act)

    # Remove the temporary files created in the save step.
    for path in paths.values():
        os.remove(path)