def test_episode_indices_when_inserting(self):
    """
    Tests if episodes indices and counts are set correctly when inserting terminals.
    """
    # NOTE(review): another method named `test_episode_indices_when_inserting` is defined
    # further down in this file. If both live in the same class, the later definition
    # overrides this one and this body never runs - consider deleting this version.
    ring_buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Internal memory variables; only the episode bookkeeping is checked here.
    ring_buffer_variables = ring_buffer.get_variables(self.ring_buffer_variables, global_scope=False)
    num_episodes = ring_buffer_variables["num-episodes"]
    episode_indices = ring_buffer_variables["episode-indices"]

    # First, we insert a single terminal record.
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    num_episodes_value, episode_index_values = test.read_variable_values(num_episodes, episode_indices)

    # One episode should be present.
    self.assertEqual(num_episodes_value, 1)
    # However, the index of that episode is 0, so we cannot fetch it.
    self.assertEqual(sum(episode_index_values), 0)

    # Next, we insert 1 non-terminal, then 1 terminal element.
    observation = non_terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    # Now, we expect to have 2 episodes with episode indices at 0 and 2.
    num_episodes_value, episode_index_values = test.read_variable_values(num_episodes, episode_indices)
    self.assertEqual(num_episodes_value, 2)
    self.assertEqual(episode_index_values[1], 2)
def test_episode_fetching(self):
    """
    Test if we can accurately fetch most recent episodes.
    """
    # NOTE(review): another method named `test_episode_fetching` is defined further down
    # in this file. If both live in the same class, the later definition overrides this
    # one and this body never runs - consider deleting this version.
    ring_buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Insert 2 non-terminals, 1 terminal.
    observation = non_terminal_records(self.record_space, 2)
    test.test(("insert_records", observation), expected_outputs=None)
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    # We should now be able to retrieve one episode.
    # NOTE(review): 3 records were inserted, but this expects 2 rewards in the episode;
    # confirm which length the episode-semantics implementation is meant to return.
    episode = test.test(("get_episodes", 1), expected_outputs=None)
    self.assertEqual(len(episode['reward']), 2)

    # Requesting two episodes when only one exists should still return just that one.
    episode = test.test(("get_episodes", 2), expected_outputs=None)
    self.assertEqual(len(episode['reward']), 2)

    # Insert 7 non-terminals, 1 terminal -> last terminal is now at buffer index 0 as
    # we inserted 3 + 8 = 11 elements in total (this assumes self.capacity == 10).
    observation = non_terminal_records(self.record_space, 7)
    test.test(("insert_records", observation), expected_outputs=None)
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    # Check if we can fetch 2 episodes.
    episodes = test.test(("get_episodes", 2), expected_outputs=None)

    # We now expect to have retrieved:
    # - self.capacity time steps
    # - 2 terminal values of 1
    # - terminal flags at positions 0 and 2, due to the insertion order
    self.assertEqual(len(episodes['terminals']), self.capacity)
    self.assertEqual(episodes['terminals'][0], True)
    self.assertEqual(episodes['terminals'][2], True)
def test_episode_indices_when_inserting(self):
    """
    Tests if episodes indices and counts are set correctly when inserting terminals.

    The scenario is run once for backend=None and once for backend="python".
    """
    for backend in (None, "python"):
        ring_buffer = RingBuffer(capacity=self.capacity, backend=backend)
        test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

        # First, we insert a single terminal record.
        observation = terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        # Internal memory variables.
        ring_buffer_variables = test.get_variable_values(self.ring_buffer_variables)
        num_episodes_value = ring_buffer_variables["num-episodes"]
        episode_index_values = ring_buffer_variables["episode-indices"]

        # One episode should be present.
        self.assertEqual(num_episodes_value, 1)
        # However, the index of that episode is 0, so we cannot fetch it.
        self.assertEqual(sum(episode_index_values), 0)

        # Next, we insert 1 non-terminal, then 1 terminal element.
        observation = non_terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)
        observation = terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        # Now, we expect to have 2 episodes with episode indices at 0 and 2.
        ring_buffer_variables = test.get_variable_values(self.ring_buffer_variables)
        num_episodes_value = ring_buffer_variables["num-episodes"]
        episode_index_values = ring_buffer_variables["episode-indices"]
        self.assertEqual(num_episodes_value, 2)
        # Check both indices promised by the comment above, not just the new one.
        self.assertEqual(episode_index_values[0], 0)
        self.assertEqual(episode_index_values[1], 2)
def test_only_terminal_with_episodes(self):
    """
    Edge case: What if only terminals are inserted?

    Fills the buffer with `self.capacity` terminal records, so every record closes its
    own episode.
    """
    ring_buffer = RingBuffer(capacity=self.capacity)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    observation = terminal_records(self.record_space, self.capacity)
    test.test(("insert_records", observation), expected_outputs=None)

    # Read the internal memory variables through the same accessor the other
    # episode tests use.
    ring_buffer_variables = test.get_variable_values(self.ring_buffer_variables)
    num_episodes_value = ring_buffer_variables["num-episodes"]
    episode_index_values = ring_buffer_variables["episode-indices"]

    # One episode per inserted terminal.
    self.assertEqual(num_episodes_value, self.capacity)
    # Every episode index should correspond to its position.
    for i in range_(self.capacity):
        self.assertEqual(episode_index_values[i], i)
def test_episode_fetching(self):
    """
    Test if we can accurately fetch most recent episodes.

    The scenario is run once for backend=None and once for backend="python".
    """
    for backend in (None, "python"):
        ring_buffer = RingBuffer(capacity=self.capacity, backend=backend)
        test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

        # Insert 2 non-terminals, 1 terminal.
        observation = non_terminal_records(self.record_space, 2)
        test.test(("insert_records", observation), expected_outputs=None)
        observation = terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        ring_buffer_variables = test.get_variable_values(self.ring_buffer_variables)
        num_episodes_value = ring_buffer_variables["num-episodes"]
        episode_index_values = ring_buffer_variables["episode-indices"]

        # One episode, whose terminal sits at buffer index 2.
        self.assertEqual(num_episodes_value, 1)
        expected_indices = [0] * self.capacity
        expected_indices[0] = 2
        recursive_assert_almost_equal(episode_index_values, expected_indices)

        # We should now be able to retrieve one episode of length 3.
        expected_terminals = [0, 0, 1]
        episode = test.test(("get_episodes", 1), expected_outputs=None)
        recursive_assert_almost_equal(episode["terminals"], expected_terminals)

        # Requesting two episodes when only one exists should still return just that one.
        episode = test.test(("get_episodes", 2), expected_outputs=None)
        recursive_assert_almost_equal(episode["terminals"], expected_terminals)

        # Insert 7 non-terminals.
        observation = non_terminal_records(self.record_space, 7)
        test.test(("insert_records", observation), expected_outputs=None)
        ring_buffer_variables = test.get_variable_values(self.ring_buffer_variables)
        index_value = ring_buffer_variables["index"]
        episode_index_values = ring_buffer_variables["episode-indices"]

        # Episode indices should not have changed.
        recursive_assert_almost_equal(episode_index_values, expected_indices)

        # Inserted 2 non-terminal, 1 terminal, 7 non-terminal at capacity 10 -> should be at 0 again.
        # NOTE(review): the hard-coded 7 assumes self.capacity == 10.
        self.assertEqual(index_value, 0)

        # Now inserting one terminal so the terminal buffer has layout [1 0 1 0 0 0 0 0 0 0].
        observation = terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        ring_buffer_variables = test.get_variable_values(self.ring_buffer_variables)
        num_episodes_value = ring_buffer_variables["num-episodes"]
        self.assertEqual(num_episodes_value, 2)

        # Check if we can fetch 2 episodes.
        episodes = test.test(("get_episodes", 2), expected_outputs=None)

        # We now expect to have retrieved:
        # - self.capacity time steps
        # - 2 terminal values of 1
        # - terminal flags at positions 0 and 2, matching the buffer layout above
        self.assertEqual(len(episodes['terminals']), self.capacity)
        self.assertEqual(episodes['terminals'][0], True)
        self.assertEqual(episodes['terminals'][2], True)