Exemplo n.º 1
0
 def testBatchIds(self):
     """Check unroll ids: one per sampled batch, two after concatenation.

     Two batches are sampled from the same worker. Each must expose a
     single distinct "unroll_id" value, and concatenating them must
     yield exactly two distinct values.
     """
     worker = RolloutWorker(
         env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
     first = worker.sample()
     second = worker.sample()

     # Every row inside one sampled batch shares the same unroll id.
     for batch in (first, second):
         self.assertEqual(len(set(batch["unroll_id"])), 1)

     # The two batches must not share their unroll id.
     merged = SampleBatch.concat(first, second)
     self.assertEqual(len(set(merged["unroll_id"])), 2)
Exemplo n.º 2
0
 def test_concat(self):
     """Check pairwise concat() and list-based concat_samples().

     Both operations must append the columns of the later batches to
     those of the earlier ones, keeping the row order.
     """
     first = SampleBatch({
         "a": np.array([1, 2, 3]),
         "b": np.array([4, 5, 6]),
     })
     second = SampleBatch({"a": np.array([1]), "b": np.array([4])})
     third = SampleBatch({"a": np.array([1]), "b": np.array([5])})

     # Pairwise concatenation through the instance method.
     pair = first.concat(second)
     for column, expected in (("a", [1, 2, 3, 1]), ("b", [4, 5, 6, 4])):
         self.assertEqual(pair[column].tolist(), expected)

     # Concatenation of an arbitrary list of batches.
     merged = SampleBatch.concat_samples([first, second, third])
     for column, expected in (("a", [1, 2, 3, 1, 1]),
                              ("b", [4, 5, 6, 4, 5])):
         self.assertEqual(merged[column].tolist(), expected)
Exemplo n.º 3
0
 def test_batch_ids(self):
     """Check unroll ids: one per sampled batch, two after concatenation.

     Two batches are sampled from the same worker. Each must expose a
     single distinct "unroll_id" value, and concatenating them must
     yield exactly two distinct values.
     """
     ev = RolloutWorker(env_creator=lambda _: gym.make("CartPole-v0"),
                        policy_spec=MockPolicy,
                        rollout_fragment_length=1)
     # Stop the worker even when sampling or an assertion fails, so a
     # failing test does not leave it running for the rest of the suite.
     try:
         batch1 = ev.sample()
         batch2 = ev.sample()
         self.assertEqual(len(set(batch1["unroll_id"])), 1)
         self.assertEqual(len(set(batch2["unroll_id"])), 1)
         self.assertEqual(
             len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
     finally:
         ev.stop()
Exemplo n.º 4
0
    def test_concat(self):
        """Check SampleBatch.concat() and SampleBatch.concat_samples().

        Uses batches with a nested dict column and verifies that both
        entry points produce the same concatenated result.
        """
        first = SampleBatch({
            "a": np.array([1, 2, 3]),
            "b": {"c": np.array([4, 5, 6])},
        })
        second = SampleBatch({
            "a": np.array([2, 3, 4]),
            "b": {"c": np.array([5, 6, 7])},
        })

        # List-based concatenation: flat and nested columns are appended.
        via_list = SampleBatch.concat_samples([first, second])
        check(via_list["a"], [1, 2, 3, 2, 3, 4])
        check(via_list["b"]["c"], [4, 5, 6, 5, 6, 7])

        # The first row yielded by rows() keeps the nested structure.
        first_row = next(via_list.rows())
        check(first_row, {"a": 1, "b": {"c": 4}})

        # The instance method must agree with the list-based variant.
        via_method = first.concat(second)
        check(via_list, via_method)