def test_reservoir_buffer_max_capacity(self):
    """Adding three entries to a capacity-2 buffer leaves its length at 2."""
    capacity = 2
    buffer = nfsp.ReservoirBuffer(reservoir_buffer_capacity=capacity)
    # One more entry than the buffer was created to hold.
    for entry in ("entry1", "entry2", "entry3"):
        buffer.add(entry)
    self.assertEqual(len(buffer), capacity)
def test_reservoir_buffer_add(self):
    """Each `add` below capacity grows the length by one and keeps the entry."""
    entries = ("entry1", "entry2")
    buffer = nfsp.ReservoirBuffer(reservoir_buffer_capacity=10)
    self.assertEqual(len(buffer), 0)

    # The length is checked after every single insertion.
    for expected_len, entry in enumerate(entries, start=1):
        buffer.add(entry)
        self.assertEqual(len(buffer), expected_len)

    # Membership is checked only once all entries have been added.
    for entry in entries:
        self.assertIn(entry, buffer)
def test_reservoir_buffer_sample(self):
    """Sampling 3 items from a buffer holding 3 entries yields all of them."""
    entries = ("entry1", "entry2", "entry3")
    buffer = nfsp.ReservoirBuffer(reservoir_buffer_capacity=3)
    for entry in entries:
        buffer.add(entry)

    # Request as many samples as there are stored entries.
    samples = buffer.sample(len(entries))

    for entry in entries:
        self.assertIn(entry, samples)
def test_reservoir_uniform(self):
    """Chi-square test that values kept by the buffer look uniformly drawn.

    For each trial the buffer is cleared, fed the integers
    `0..max_value - 1`, and then `size` items are sampled from it. The
    per-value counts accumulated over all trials are compared against a
    uniform expectation.
    """
    size = 10
    max_value = 100
    num_trials = 1000
    # Under uniformity each value is expected size / max_value times per
    # trial, i.e. 100 times over all trials.
    expected_count = 1. / max_value * size * num_trials

    reservoir_buffer = nfsp.ReservoirBuffer(reservoir_buffer_capacity=size)
    counter = collections.Counter()
    for _ in range(num_trials):
        reservoir_buffer.clear()
        for idx in range(max_value):
            reservoir_buffer.add(idx)
        counter.update(reservoir_buffer.sample(size))

    # Build one observed frequency per possible value. A Counter returns 0
    # for keys it never saw, so a value that was never sampled still
    # contributes a zero-count category instead of being left out of the test.
    observed = [counter[value] for value in range(max_value)]

    # Tests the null hypothesis (H0) that data has the given frequencies.
    # We reject the null hypothesis if we get a p-value below our threshold.
    # NOTE(review): no RNG seed is set in this test, so the check can fail by
    # chance even for a correct buffer - confirm whether seeding happens
    # elsewhere.
    pvalue = stats.chisquare(observed, expected_count).pvalue
    self.assertGreater(pvalue, 0.05)  # We cannot reject H0.