def test_replay(self) -> None:
    """Replaying a one-dimensional log yields its entries in order, then fails.

    Each call to the generator should return the next element of ``log``;
    once every element has been consumed, further calls must raise.
    """
    log = np.arange(0, 10)
    r = ReplayRandomGenerator(log)
    for expected in log:
        self.assertEqual(r(), np.array(expected))
    # The log ran out at this point, so the generator should throw.
    # Asserting IndexError (the same exception test_replay_multidim expects)
    # instead of a bare Exception keeps unrelated errors, e.g. a TypeError
    # from a broken implementation, from passing this test silently.
    # The repeated call checks that the generator keeps raising once
    # exhausted.
    self.assertRaises(IndexError, r)
    self.assertRaises(IndexError, r)
def test_replay_multidim(self) -> None:
    """Replaying a two-dimensional log yields whole rows in order, then fails.

    Each call should return the next row of ``log`` as an array; after the
    last row, a further call must raise IndexError.
    """
    log = np.array([[0, 1], [2, 3]])
    r = ReplayRandomGenerator(log)
    for expected in (np.array([0, 1]), np.array([2, 3])):
        actual = r()
        # np.array_equal also requires matching shapes; the msg makes a
        # failure show both arrays instead of just "False is not true".
        self.assertTrue(
            np.array_equal(actual, expected),
            msg=f"replayed {actual!r}, expected {expected!r}",
        )
    self.assertRaises(IndexError, r)