Exemplo n.º 1
0
    def __call__(self, batch: SampleBatchType):
        """Store the rows of a multi-agent batch in the per-policy buffers.

        For every policy id found in ``self.policies_to_train``, each row of
        that policy's batch is wrapped in a one-row ``SampleBatch``,
        compressed, and added to the policy's replay buffer. Rows whose
        ``"mode"`` equals ``MODE.best_response.value`` are additionally added
        to the policy's reservoir buffer as an (observation, action) pair.

        Each ``steps`` counter is incremented on the buffer collection that
        actually received the data.

        Args:
            batch: Object exposing ``policy_batches``, a mapping from policy
                id to a batch that provides ``rows()``.

        Returns:
            The same ``batch`` object that was passed in.
        """
        for policy_id, s in batch.policy_batches.items():
            if policy_id not in self.policies_to_train:
                continue

            for row in s.rows():
                if row["mode"] == MODE.best_response.value:
                    # Transition must be inserted in the reservoir buffer
                    self.reservoir_buffers.buffers[policy_id].add(
                        pack_if_needed(row["obs"]), row["actions"])
                    # Count the insertion on the reservoir side, not on the
                    # replay side.
                    self.reservoir_buffers.steps[policy_id] += 1

                # Array columns are reshaped to a single row; scalar columns
                # are wrapped in a length-1 array.
                bb = SampleBatch({
                    'obs': row["obs"].reshape(1, -1),
                    'actions': row['actions'].reshape(1, -1),
                    'rewards': row['rewards'].reshape(1, -1),
                    'new_obs': row['new_obs'].reshape(1, -1),
                    'dones': np.array([row['dones']]),
                    "eps_id": np.array([row['eps_id']]),
                    'unroll_id': np.array([row['unroll_id']]),
                    'agent_index': np.array([row['agent_index']]),
                })
                bb.compress(bulk=True)
                self.replay_buffers.buffers[policy_id].add_batch(bb)
                # Count the insertion on the replay side.
                self.replay_buffers.steps[policy_id] += 1

        return batch
Exemplo n.º 2
0
 def __call__(self, batch: SampleBatchType):
     """Add every row of every policy batch to that policy's replay buffer.

     Each row is turned into a one-row ``SampleBatch``: NumPy array values
     are reshaped to ``(1, -1)`` and any other value is wrapped in a
     length-1 array. The batch is compressed before being stored, and the
     policy's step counter is incremented by one per stored row.

     Args:
         batch: Object exposing ``policy_batches``, a mapping from policy
             id to a batch that provides ``rows()``.

     Returns:
         The same ``batch`` object that was passed in.
     """
     for pid, policy_batch in batch.policy_batches.items():
         for transition in policy_batch.rows():
             columns = {
                 key: (value.reshape(1, -1)
                       if isinstance(value, np.ndarray)
                       else np.array([value]))
                 for key, value in transition.items()
             }
             single_row = SampleBatch(columns)
             single_row.compress(bulk=True)
             self.replay_buffers.buffers[pid].add_batch(single_row)
             self.replay_buffers.steps[pid] += 1
     return batch
Exemplo n.º 3
0
    def test_compression(self):
        """Check that a SampleBatch compresses and decompresses in place."""
        a_values = [1, 2, 3, 2, 3, 4]
        c_values = [4, 5, 6, 5, 6, 7]
        batch = SampleBatch({
            "a": np.array(a_values),
            "b": {"c": np.array(c_values)},
        })

        # Compressing must modify the batch itself, including the nested
        # column, while keeping the nested container a dict.
        batch.compress(columns={"a", "b"}, bulk=True)
        self.assertTrue(is_compressed(batch["a"]))
        self.assertTrue(is_compressed(batch["b"]["c"]))
        self.assertTrue(isinstance(batch["b"], dict))

        # Decompressing must also modify the batch itself and restore the
        # original values.
        batch.decompress_if_needed(columns={"a", "b"})
        check(batch["a"], a_values)
        check(batch["b"]["c"], c_values)

        # Skip the first row and verify the second one.
        row_iter = batch.rows()
        next(row_iter)
        second_row = next(row_iter)
        check(second_row, {"a": 2, "b": {"c": 5}})
Exemplo n.º 4
0
    def __call__(self, batch: SampleBatchType):
        """Store a multi-agent batch in the per-policy buffers, by episode.

        For every policy id found in ``self.policies_to_train``:

        * each row whose ``"mode"`` equals ``MODE.best_response.value`` is
          added to the policy's reservoir buffer as an (observation, action)
          pair;
        * the policy's batch is split by unique ``eps_id`` value, and each
          group is compressed and added to the policy's replay buffer as one
          ``SampleBatch``.

        Each ``steps`` counter is incremented on the buffer collection that
        actually received the data: by one per reservoir insertion, and by
        the stored batch's ``count`` per replay insertion.

        Args:
            batch: Object exposing ``policy_batches``, a mapping from policy
                id to a batch that provides ``rows()`` and column access by
                key.

        Returns:
            The same ``batch`` object that was passed in.
        """
        for policy_id, s in batch.policy_batches.items():
            if policy_id not in self.policies_to_train:
                continue

            for row in s.rows():
                if row["mode"] == MODE.best_response.value:
                    # Transition must be inserted in the reservoir buffer
                    self.reservoir_buffers.buffers[policy_id].add(
                        pack_if_needed(row["obs"]), row["actions"])
                    # Count the insertion on the reservoir side, not on the
                    # replay side.
                    self.reservoir_buffers.steps[policy_id] += 1

            # Look the episode-id column up once; it drives both the set of
            # groups and the per-group row selection.
            eps_ids = s['eps_id']
            for ep_id in np.unique(eps_ids):
                sample_ids = np.where(eps_ids == ep_id)
                bb = SampleBatch({
                    'obs': s["obs"][sample_ids],
                    'actions': s['actions'][sample_ids],
                    'rewards': s['rewards'][sample_ids],
                    'new_obs': s['new_obs'][sample_ids],
                    'dones': np.array(s['dones'][sample_ids]),
                    "eps_id": np.array(eps_ids[sample_ids]),
                    'unroll_id': np.array(s['unroll_id'][sample_ids]),
                    'agent_index': np.array(s['agent_index'][sample_ids]),
                })
                bb.compress(bulk=True)
                self.replay_buffers.buffers[policy_id].add_batch(bb)
                # Count every stored row on the replay side.
                self.replay_buffers.steps[policy_id] += bb.count

        return batch