def test_capacity_no_episodes(self):
    """
    Tests that overfilling the ring buffer wraps it correctly when episode
    semantics are disabled (no episode indices are maintained).
    """
    buffer = RingBuffer(capacity=self.capacity, episode_semantics=False)
    component_test = ComponentTest(component=buffer, input_spaces=self.input_spaces_no_episodes)

    # Handles to the buffer's internal size and write-index variables.
    variables = buffer.get_variables(self.memory_variables, global_scope=False)
    size_var = variables['size']
    index_var = variables['index']

    # A freshly built buffer holds nothing and writes from position 0.
    size, index = component_test.read_variable_values(size_var, index_var)
    self.assertEqual(size, 0)
    self.assertEqual(index, 0)

    # Overfill by exactly one record to force a wrap-around.
    records = self.record_space.sample(size=self.capacity + 1)
    component_test.test(("insert_records", records), expected_outputs=None)

    size, index = component_test.read_variable_values(size_var, index_var)
    # A full buffer reports its capacity as its size ...
    self.assertEqual(size, self.capacity)
    # ... and the write index has wrapped (modulo capacity) to position 1.
    self.assertEqual(index, 1)
def test_capacity_with_episodes(self):
    """
    Tests that overfilling the buffer with non-terminal records works while
    episode semantics are enabled, leaving episode bookkeeping untouched.

    This does not exercise episode semantics themselves; dedicated episode
    tests cover those.
    """
    buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    component_test = ComponentTest(component=buffer, input_spaces=self.input_spaces)

    # Handles to the buffer's internal bookkeeping variables, in read order.
    variables = buffer.get_variables(self.ring_buffer_variables, global_scope=False)
    tracked = (
        variables["size"],
        variables["index"],
        variables["num-episodes"],
        variables["episode-indices"],
    )

    # Before any insert, every counter and every episode index is zero.
    size, index, episode_count, episode_ends = component_test.read_variable_values(*tracked)
    self.assertEqual(size, 0)
    self.assertEqual(index, 0)
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_ends), 0)

    # Overfill by one record. Unlike the plain replay test, terminal flags
    # matter here: none of these records are terminal, so the episode-index
    # update must not disturb the episode bookkeeping.
    records = non_terminal_records(self.record_space, self.capacity + 1)
    component_test.test(("insert_records", records), expected_outputs=None)

    size, index, episode_count, episode_ends = component_test.read_variable_values(*tracked)
    # Full buffer: size equals capacity, and the write index wrapped to 1.
    self.assertEqual(size, self.capacity)
    self.assertEqual(index, 1)
    # No terminals were inserted, so no episodes are recorded.
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_ends), 0)
def test_only_terminal_with_episodes(self):
    """
    Edge case: only terminal records are inserted while episode semantics are
    enabled, so every inserted record closes its own episode.
    """
    # Enable episode semantics explicitly (as the other episode tests do) so the
    # "num-episodes"/"episode-indices" variables read below are guaranteed to
    # exist, independent of the constructor's default.
    ring_buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)
    ring_buffer_variables = ring_buffer.get_variables(self.ring_buffer_variables, global_scope=False)
    num_episodes = ring_buffer_variables["num-episodes"]
    episode_indices = ring_buffer_variables["episode-indices"]

    # Fill the buffer exactly to capacity with terminal records.
    observation = terminal_records(self.record_space, self.capacity)
    test.test(("insert_records", observation), expected_outputs=None)

    num_episodes_value, episode_index_values = test.read_variable_values(num_episodes, episode_indices)
    # One episode per terminal record.
    self.assertEqual(num_episodes_value, self.capacity)
    # Every episode index should correspond to its buffer position.
    for i in range_(self.capacity):
        self.assertEqual(episode_index_values[i], i)
def test_episode_indices_when_inserting(self):
    """
    Tests if episode indices and counts are set correctly when inserting terminals.
    """
    ring_buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Internal memory variables.
    ring_buffer_variables = ring_buffer.get_variables(self.ring_buffer_variables, global_scope=False)
    buffer_size = ring_buffer_variables["size"]
    buffer_index = ring_buffer_variables["index"]
    num_episodes = ring_buffer_variables["num-episodes"]
    episode_indices = ring_buffer_variables["episode-indices"]

    # First, insert a single terminal record (buffer position 0).
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)
    num_episodes_value, episode_index_values = test.read_variable_values(num_episodes, episode_indices)
    # One episode should be present.
    self.assertEqual(num_episodes_value, 1)
    # That episode ends at position 0, so the episode-index array still sums to 0.
    self.assertEqual(sum(episode_index_values), 0)

    # Next, insert 1 non-terminal (position 1), then 1 terminal (position 2).
    observation = non_terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    size_value, index_value, num_episodes_value, episode_index_values = test.read_variable_values(
        buffer_size, buffer_index, num_episodes, episode_indices)
    # Three records stored so far, well below capacity: nothing overwritten.
    self.assertEqual(size_value, 3)
    self.assertEqual(index_value, 3)
    # Now we expect 2 episodes, ending at positions 0 and 2.
    self.assertEqual(num_episodes_value, 2)
    self.assertEqual(episode_index_values[0], 0)
    self.assertEqual(episode_index_values[1], 2)
def test_episode_fetching(self):
    """
    Test if we can accurately fetch the most recent episodes, including after
    the buffer's write index has wrapped around.
    """
    # Enable episode semantics explicitly, consistent with the other episode
    # tests, so the episode variables and "get_episodes" are available
    # regardless of the constructor's default.
    ring_buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Insert 2 non-terminals, then 1 terminal: one episode at positions 0..2.
    observation = non_terminal_records(self.record_space, 2)
    test.test(("insert_records", observation), expected_outputs=None)
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    ring_buffer_variables = ring_buffer.get_variables(self.ring_buffer_variables, global_scope=False)
    num_episodes = ring_buffer_variables["num-episodes"]
    episode_indices = ring_buffer_variables["episode-indices"]
    buffer_index = ring_buffer_variables["index"]

    num_episodes_value, episode_index_values = test.read_variable_values(num_episodes, episode_indices)
    # One episode, ending at position 2.
    self.assertEqual(num_episodes_value, 1)
    expected_indices = [0] * self.capacity
    expected_indices[0] = 2
    recursive_assert_almost_equal(episode_index_values, expected_indices)

    # We should now be able to retrieve one episode of length 3.
    episode = test.test(("get_episodes", 1), expected_outputs=None)
    expected_terminals = [0, 0, 1]
    recursive_assert_almost_equal(episode["terminals"], expected_terminals)

    # Only one episode is stored, so requesting two must still return just that one.
    episode = test.test(("get_episodes", 2), expected_outputs=None)
    recursive_assert_almost_equal(episode["terminals"], expected_terminals)

    # Insert 7 non-terminals (positions 3..9).
    observation = non_terminal_records(self.record_space, 7)
    test.test(("insert_records", observation), expected_outputs=None)
    num_episodes_value, episode_index_values, buffer_index_value = test.read_variable_values(
        num_episodes, episode_indices, buffer_index)
    # No terminals were inserted, so episode count and indices are unchanged.
    self.assertEqual(num_episodes_value, 1)
    recursive_assert_almost_equal(episode_index_values, expected_indices)
    # Inserted 2 non-terminal, 1 terminal, 7 non-terminal at capacity 10 -> write index back at 0.
    self.assertEqual(buffer_index_value, 0)

    # Insert one terminal at position 0, so the terminal buffer has layout
    # [1 0 1 0 0 0 0 0 0 0].
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    num_episodes_value, episode_index_values = test.read_variable_values(num_episodes, episode_indices)
    # A second episode now exists, ending at position 0.
    self.assertEqual(num_episodes_value, 2)
    expected_indices[1] = 0
    recursive_assert_almost_equal(episode_index_values, expected_indices)

    # Fetch both episodes. We expect to retrieve:
    # - `capacity` (10) time steps,
    # - exactly 2 terminal flags set,
    # - those terminals at positions 0 and 2, matching the layout above.
    episodes = test.test(("get_episodes", 2), expected_outputs=None)
    self.assertEqual(len(episodes['terminals']), self.capacity)
    self.assertEqual(sum(episodes['terminals']), 2)
    self.assertEqual(episodes['terminals'][0], True)
    self.assertEqual(episodes['terminals'][2], True)