def __call__(self, batch: SampleBatchType):
    """Copy the transitions of trained policies into the buffers, row by row.

    Policies whose id is not in ``self.policies_to_train`` are skipped.
    For every remaining row:

    * if ``row["mode"]`` equals ``MODE.best_response.value``, the
      (observation, action) pair is added to that policy's reservoir
      buffer;
    * the row is always wrapped in a one-row ``SampleBatch``, compressed
      and added to that policy's replay buffer.

    Args:
        batch: Object exposing ``policy_batches``, a mapping from policy id
            to a per-policy batch that supports ``rows()``.

    Returns:
        The same ``batch`` object that was passed in.
    """
    for policy_id, s in batch.policy_batches.items():
        if policy_id not in self.policies_to_train:
            continue

        for row in s.rows():
            if row["mode"] == MODE.best_response.value:
                # Transition must be inserted in the reservoir buffer
                self.reservoir_buffers.buffers[policy_id].add(
                    pack_if_needed(row["obs"]), row["actions"])
                # NOTE(review): the counter bumped here belongs to
                # `replay_buffers` although the insert above went to
                # `reservoir_buffers` (and the reverse happens below).
                # Kept as-is; confirm the pairing is intended.
                self.replay_buffers.steps[policy_id] += 1

            # Give every field a leading axis of length 1 so the single
            # transition forms a one-row batch.
            bb = SampleBatch({
                "obs": row["obs"].reshape(1, -1),
                "actions": row["actions"].reshape(1, -1),
                "rewards": row["rewards"].reshape(1, -1),
                "new_obs": row["new_obs"].reshape(1, -1),
                "dones": np.array([row["dones"]]),
                "eps_id": np.array([row["eps_id"]]),
                "unroll_id": np.array([row["unroll_id"]]),
                "agent_index": np.array([row["agent_index"]]),
            })
            bb.compress(bulk=True)
            self.replay_buffers.buffers[policy_id].add_batch(bb)
            # NOTE(review): see the note above -- this increments the
            # reservoir counter after a replay-buffer insert.
            self.reservoir_buffers.steps[policy_id] += 1

    return batch
def __call__(self, batch: SampleBatchType):
    """Push every row of every policy batch into the replay buffers.

    Each row is converted into a one-row ``SampleBatch``: values that are
    already ``np.ndarray`` are reshaped to ``(1, -1)``, anything else is
    wrapped in a one-element array. The resulting batch is compressed,
    added to the replay buffer of the owning policy, and that policy's
    step counter is incremented by one.

    Args:
        batch: Object exposing ``policy_batches``, a mapping from policy id
            to a per-policy batch that supports ``rows()``.

    Returns:
        The same ``batch`` object that was passed in.
    """
    for policy_id, policy_batch in batch.policy_batches.items():
        for row in policy_batch.rows():
            single_row = SampleBatch({
                key: (value.reshape(1, -1)
                      if isinstance(value, np.ndarray)
                      else np.array([value]))
                for key, value in row.items()
            })
            single_row.compress(bulk=True)
            self.replay_buffers.buffers[policy_id].add_batch(single_row)
            self.replay_buffers.steps[policy_id] += 1
    return batch
def test_compression(self):
    """Check that compress / decompress_if_needed round-trip a batch.

    Uses a batch with one flat column ("a") and one nested dict column
    ("b" -> "c"), compresses both, verifies the stored values report as
    compressed, decompresses them again and compares against the inputs.
    """
    a_values = [1, 2, 3, 2, 3, 4]
    c_values = [4, 5, 6, 5, 6, 7]
    batch = SampleBatch({
        "a": np.array(a_values),
        "b": {
            "c": np.array(c_values)
        },
    })

    # Compression must modify the batch itself (no new object returned
    # is used here), including the leaf inside the nested dict.
    batch.compress(columns={"a", "b"}, bulk=True)
    self.assertTrue(is_compressed(batch["a"]))
    nested = batch["b"]
    self.assertTrue(is_compressed(nested["c"]))
    # The nested container must still be a dict after compression.
    self.assertTrue(isinstance(batch["b"], dict))

    # Decompression must likewise restore the values on the same batch.
    batch.decompress_if_needed(columns={"a", "b"})
    check(batch["a"], a_values)
    check(batch["b"]["c"], c_values)

    # Skip the first row and inspect the second one.
    row_iter = batch.rows()
    next(row_iter)
    second_row = next(row_iter)
    check(second_row, {"a": 2, "b": {"c": 5}})
def __call__(self, batch: SampleBatchType):
    """Copy the transitions of trained policies into the buffers.

    Policies whose id is not in ``self.policies_to_train`` are skipped.
    For every remaining policy batch:

    * each row whose ``"mode"`` equals ``MODE.best_response.value`` has its
      (observation, action) pair added to that policy's reservoir buffer;
    * the batch is then split by ``"eps_id"`` and each per-episode slice is
      compressed and added to that policy's replay buffer as one
      ``SampleBatch``.

    Args:
        batch: Object exposing ``policy_batches``, a mapping from policy id
            to a per-policy batch that supports ``rows()`` and column
            lookup by key.

    Returns:
        The same ``batch`` object that was passed in.
    """
    for policy_id, s in batch.policy_batches.items():
        if policy_id not in self.policies_to_train:
            continue

        for row in s.rows():
            if row["mode"] == MODE.best_response.value:
                # Transition must be inserted in the reservoir buffer
                self.reservoir_buffers.buffers[policy_id].add(
                    pack_if_needed(row["obs"]), row["actions"])
                # NOTE(review): the counter bumped here belongs to
                # `replay_buffers` although the insert above went to
                # `reservoir_buffers` (and the reverse happens below).
                # Kept as-is; confirm the pairing is intended.
                self.replay_buffers.steps[policy_id] += 1

        # Group the samples by episode id and store one batch per episode.
        episode_ids = np.unique(s["eps_id"])
        for ep_id in episode_ids:
            # Indices of all samples that belong to this episode.
            sample_ids = np.where(s["eps_id"] == ep_id)
            bb = SampleBatch({
                "obs": s["obs"][sample_ids],
                "actions": s["actions"][sample_ids],
                "rewards": s["rewards"][sample_ids],
                "new_obs": s["new_obs"][sample_ids],
                "dones": np.array(s["dones"][sample_ids]),
                "eps_id": np.array(s["eps_id"][sample_ids]),
                "unroll_id": np.array(s["unroll_id"][sample_ids]),
                "agent_index": np.array(s["agent_index"][sample_ids]),
            })
            bb.compress(bulk=True)
            self.replay_buffers.buffers[policy_id].add_batch(bb)
            # NOTE(review): see the note above -- this increments the
            # reservoir counter after a replay-buffer insert.
            self.reservoir_buffers.steps[policy_id] += bb.count

    return batch