def test_capacity_no_episodes(self):
    """
    Checks that inserting more records than the buffer can hold caps the size
    at capacity and wraps the write index; episode semantics are disabled, so
    no episode indices are involved.
    """
    memory = RingBuffer(capacity=self.capacity, episode_semantics=False)
    harness = ComponentTest(component=memory, input_spaces=self.input_spaces_no_episodes)

    # Handles to the buffer's internal state variables.
    variables = memory.get_variables(self.memory_variables, global_scope=False)
    size_var = variables['size']
    index_var = variables['index']

    # A fresh buffer must report size 0 and write index 0.
    size_now, index_now = harness.read_variable_values(size_var, index_var)
    self.assertEqual(size_now, 0)
    self.assertEqual(index_now, 0)

    # Push exactly one record more than fits.
    overflow_count = self.capacity + 1
    records = self.record_space.sample(size=overflow_count)
    harness.test(("insert_records", records), expected_outputs=None)

    size_now, index_now = harness.read_variable_values(size_var, index_var)
    # A full buffer reports its capacity as size.
    self.assertEqual(size_now, self.capacity)
    # capacity + 1 inserts wrap the write index around (modulo) to 1.
    self.assertEqual(index_now, 1)
def test_latest_batch(self):
    """
    Checks, for each backend under test, that `get_records` returns the most
    recently inserted records.
    """
    for backend in (None, "python"):
        memory = RingBuffer(capacity=self.capacity, backend=backend)
        harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

        # Basic round trip: insert 5 non-terminal records and read all 5 back.
        records = non_terminal_records(self.record_space, 5)
        harness.test(("insert_records", records), expected_outputs=None)
        fetched = harness.test(("get_records", 5), expected_outputs=None)
        recursive_assert_almost_equal(fetched, records)

        # Insert another `capacity` records on top of the first five.
        records = non_terminal_records(self.record_space, self.capacity)
        harness.test(("insert_records", records), expected_outputs=None)

        # Fetching `capacity` records must yield exactly the latest insert.
        fetched = harness.test(("get_records", self.capacity), expected_outputs=None)
        recursive_assert_almost_equal(fetched, records)

        # Fetching n records must yield the trailing n of the latest insert.
        for last_n in range(1, 6):
            fetched = harness.test(("get_records", last_n), expected_outputs=None)
            for group, field in (("actions", "action1"), ("states", "state2")):
                recursive_assert_almost_equal(
                    fetched[group][field], records[group][field][-last_n:]
                )
            recursive_assert_almost_equal(
                fetched["terminals"], records["terminals"][-last_n:]
            )
def test_latest_batch(self):
    """
    Tests if we can fetch the latest steps.

    Inserts 5 records and checks that 5 come back, then inserts `capacity`
    more records and checks that every newly inserted action value is present
    in a fetch of `capacity` records (order is not checked).
    """
    ring_buffer = RingBuffer(capacity=self.capacity)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Insert 5 random elements.
    observation = non_terminal_records(self.record_space, 5)
    test.test(("insert_records", observation), expected_outputs=None)

    # First, test if the basic computation works.
    batch = test.test(("get_records", 5), expected_outputs=None)
    self.assertEqual(len(batch['terminals']), 5)

    # Next, insert capacity more elements.
    observation = non_terminal_records(self.record_space, self.capacity)
    test.test(("insert_records", observation), expected_outputs=None)

    # Fetching `capacity` elements should now return the records just inserted.
    batch = test.test(("get_records", self.capacity), expected_outputs=None)

    # Every inserted element must be contained, even if not in the same order.
    # assertIn reports the missing value on failure, unlike assertTrue(x in y).
    retrieved_action = batch['actions']['action1']
    for action_value in observation['actions']['action1']:
        self.assertIn(action_value, retrieved_action)
def test_capacity_with_episodes(self):
    """
    Checks that inserting only non-terminal records works when episode
    semantics are switched on.

    Only the size/index/episode bookkeeping is verified here; episode
    semantics themselves are not exercised by this test.
    """
    memory = RingBuffer(capacity=self.capacity, episode_semantics=True)
    harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

    # Handles to the internal memory variables, in the order they are read.
    variables = memory.get_variables(self.ring_buffer_variables, global_scope=False)
    tracked = (
        variables["size"],
        variables["index"],
        variables["num-episodes"],
        variables["episode-indices"],
    )

    # Before any insert, all counters and all episode indices are zero.
    size_now, index_now, episode_count, episode_slots = harness.read_variable_values(*tracked)
    self.assertEqual(size_now, 0)
    self.assertEqual(index_now, 0)
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_slots), 0)

    # Insert one more element than capacity. With episode semantics it matters
    # whether records are terminal, so this checks that episode index updating
    # copes with a batch that contains no terminal at all.
    records = non_terminal_records(self.record_space, self.capacity + 1)
    harness.test(("insert_records", records), expected_outputs=None)

    size_now, index_now, episode_count, episode_slots = harness.read_variable_values(*tracked)
    # A full buffer reports its capacity as size.
    self.assertEqual(size_now, self.capacity)
    # capacity + 1 inserts wrap the write index around (modulo) to 1.
    self.assertEqual(index_now, 1)
    # No terminal was inserted, so no episode may have been registered.
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_slots), 0)
def test_only_terminal_with_episodes(self):
    """
    Edge case: fills the buffer with terminal records only and checks the
    episode bookkeeping.

    NOTE(review): the buffer is built without an explicit `episode_semantics`
    argument; presumably episode semantics are on by default -- confirm.
    """
    memory = RingBuffer(capacity=self.capacity)
    harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

    variables = memory.get_variables(self.ring_buffer_variables, global_scope=False)
    episode_count_var = variables["num-episodes"]
    episode_slots_var = variables["episode-indices"]

    # Insert `capacity` records, each of which is terminal.
    records = terminal_records(self.record_space, self.capacity)
    harness.test(("insert_records", records), expected_outputs=None)

    episode_count, episode_slots = harness.read_variable_values(
        episode_count_var, episode_slots_var
    )
    # One episode per inserted terminal.
    self.assertEqual(episode_count, self.capacity)

    # Each episode index must equal its own position in the index array.
    for position in range_(self.capacity):
        self.assertEqual(episode_slots[position], position)
def test_capacity_with_episodes(self):
    """
    Checks that inserting only non-terminal records works, and that the most
    recent records can be fetched afterwards.

    Only size/index/episode bookkeeping and record retrieval are verified;
    episode semantics themselves are not exercised by this test.
    """
    memory = RingBuffer(capacity=self.capacity)
    harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

    def snapshot():
        # Current values of the internal memory variables, in a fixed order.
        values = harness.get_variable_values(memory, self.ring_buffer_variables)
        return (
            values["size"],
            values["index"],
            values["num-episodes"],
            values["episode-indices"],
        )

    # Before any insert, all counters and all episode indices are zero.
    size_now, index_now, episode_count, episode_slots = snapshot()
    self.assertEqual(size_now, 0)
    self.assertEqual(index_now, 0)
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_slots), 0)

    # Insert one more element than capacity. With episode semantics it matters
    # whether records are terminal, so this checks that episode index updating
    # copes with a batch that contains no terminal at all.
    records = non_terminal_records(self.record_space, self.capacity + 1)
    harness.test(("insert_records", records), expected_outputs=None)

    size_now, index_now, episode_count, episode_slots = snapshot()
    # A full buffer reports its capacity as size.
    self.assertEqual(size_now, self.capacity)
    # capacity + 1 inserts wrap the write index around (modulo) to 1.
    self.assertEqual(index_now, 1)
    # No terminal was inserted, so no episode may have been registered.
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_slots), 0)

    # Fetching n records must yield the trailing n of what was inserted.
    for last_n in range(1, 6):
        fetched = harness.test(("get_records", last_n), expected_outputs=None)
        for group, field in (("actions", "action1"), ("states", "state2")):
            recursive_assert_almost_equal(
                fetched[group][field], records[group][field][-last_n:]
            )
        recursive_assert_almost_equal(
            fetched["terminals"], records["terminals"][-last_n:]
        )
def test_episode_indices_when_inserting(self):
    """
    Checks that the episode count and the episode indices are updated
    correctly as terminal records are inserted.
    """
    memory = RingBuffer(capacity=self.capacity, episode_semantics=True)
    harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

    # Handles to the internal memory variables, in the order they are read.
    variables = memory.get_variables(self.ring_buffer_variables, global_scope=False)
    tracked = (
        variables["size"],
        variables["index"],
        variables["num-episodes"],
        variables["episode-indices"],
    )

    def insert(records):
        harness.test(("insert_records", records), expected_outputs=None)

    # Step 1: a single terminal record.
    insert(terminal_records(self.record_space, 1))
    _size, _index, episode_count, episode_slots = harness.read_variable_values(*tracked)

    # Exactly one episode is registered ...
    self.assertEqual(num_episodes_value := episode_count, 1) if False else self.assertEqual(episode_count, 1)
    # ... and it ends at index 0, so all stored episode indices still sum to 0.
    self.assertEqual(sum(episode_slots), 0)

    # Step 2: one non-terminal record followed by one terminal record.
    insert(non_terminal_records(self.record_space, 1))
    insert(terminal_records(self.record_space, 1))

    # Two episodes are now expected, ending at buffer indices 0 and 2.
    _size, _index, episode_count, episode_slots = harness.read_variable_values(*tracked)
    print('Episode indices after = {}'.format(episode_slots))
    self.assertEqual(episode_count, 2)
    self.assertEqual(episode_slots[1], 2)
def test_insert_no_episodes(self):
    """
    Smoke test for the insert op with episode semantics disabled; the
    buffer's internal state is not inspected.
    """
    memory = RingBuffer(capacity=self.capacity, episode_semantics=False)
    harness = ComponentTest(component=memory, input_spaces=self.input_spaces_no_episodes)

    # Insert a single record first, then a batch of ten.
    for batch_size in (1, 10):
        records = self.record_space.sample(size=batch_size)
        harness.test(("insert_records", records), expected_outputs=None)
def test_episode_fetching(self):
    """
    Test if we can accurately fetch the most recent episodes.

    Inserts one complete episode, fetches it, then inserts a second episode
    that wraps around the buffer and fetches both.
    """
    ring_buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Insert 2 non-terminals, 1 terminal.
    observation = non_terminal_records(self.record_space, 2)
    test.test(("insert_records", observation), expected_outputs=None)
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    # We should now be able to retrieve one episode.
    # NOTE(review): 3 records were inserted, yet 2 'reward' entries are
    # expected here -- confirm that this is the intended episode length.
    episode = test.test(("get_episodes", 1), expected_outputs=None)
    # assertEqual reports the actual length on failure, unlike assertTrue(a == b).
    self.assertEqual(len(episode['reward']), 2)

    # Asking for two episodes while only one exists must still return just one.
    episode = test.test(("get_episodes", 2), expected_outputs=None)
    self.assertEqual(len(episode['reward']), 2)

    # Insert 7 non-terminals, 1 terminal: 3 + 8 = 11 elements in total, so
    # (assuming a capacity of 10) the last terminal lands at buffer index 0.
    observation = non_terminal_records(self.record_space, 7)
    test.test(("insert_records", observation), expected_outputs=None)
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    # Check if we can fetch 2 episodes.
    episodes = test.test(("get_episodes", 2), expected_outputs=None)

    # We now expect to have retrieved:
    # - `capacity` time steps
    # - terminal flags set at positions 0 and 2, due to the insertion order
    self.assertEqual(len(episodes['terminals']), self.capacity)
    self.assertEqual(episodes['terminals'][0], True)
    self.assertEqual(episodes['terminals'][2], True)
def test_episode_indices_when_inserting(self):
    """
    Checks, for each backend under test, that the episode count and the
    episode indices are updated correctly as terminal records are inserted.
    """
    for backend in (None, "python"):
        memory = RingBuffer(capacity=self.capacity, backend=backend)
        harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

        # Step 1: a single terminal record.
        records = terminal_records(self.record_space, 1)
        harness.test(("insert_records", records), expected_outputs=None)

        # Read the internal memory variables.
        state = harness.get_variable_values(self.ring_buffer_variables)
        episode_count = state["num-episodes"]
        episode_slots = state["episode-indices"]

        # Exactly one episode is registered ...
        self.assertEqual(episode_count, 1)
        # ... and it ends at index 0, so all stored episode indices sum to 0.
        self.assertEqual(sum(episode_slots), 0)

        # Step 2: one non-terminal record followed by one terminal record.
        records = non_terminal_records(self.record_space, 1)
        harness.test(("insert_records", records), expected_outputs=None)
        records = terminal_records(self.record_space, 1)
        harness.test(("insert_records", records), expected_outputs=None)

        # Two episodes are now expected, ending at buffer indices 0 and 2.
        state = harness.get_variable_values(self.ring_buffer_variables)
        episode_count = state["num-episodes"]
        episode_slots = state["episode-indices"]
        print('Episode indices after = {}'.format(episode_slots))
        self.assertEqual(episode_count, 2)
        self.assertEqual(episode_slots[1], 2)
def test_episode_fetching(self):
    """
    Checks, for each backend under test, that `get_episodes` returns the most
    recently completed episodes and that the episode bookkeeping stays
    consistent while records are inserted.
    """
    for backend in (None, "python"):
        memory = RingBuffer(capacity=self.capacity, backend=backend)
        harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

        # First episode: 2 non-terminal records followed by 1 terminal record.
        records = non_terminal_records(self.record_space, 2)
        harness.test(("insert_records", records), expected_outputs=None)
        records = terminal_records(self.record_space, 1)
        harness.test(("insert_records", records), expected_outputs=None)

        state = harness.get_variable_values(self.ring_buffer_variables)
        episode_count = state["num-episodes"]
        episode_slots = state["episode-indices"]

        # One episode, ending at buffer index 2; all other slots remain 0.
        self.assertEqual(episode_count, 1)
        expected_indices = [0] * self.capacity
        expected_indices[0] = 2
        recursive_assert_almost_equal(episode_slots, expected_indices)

        # Fetching one episode yields the 3 records with the terminal last.
        episode = harness.test(("get_episodes", 1), expected_outputs=None)
        expected_terminals = [0, 0, 1]
        recursive_assert_almost_equal(episode["terminals"], expected_terminals)

        # Asking for two episodes while only one exists still returns just one.
        episode = harness.test(("get_episodes", 2), expected_outputs=None)
        expected_terminals = [0, 0, 1]
        recursive_assert_almost_equal(episode["terminals"], expected_terminals)

        # Add 7 non-terminal records.
        records = non_terminal_records(self.record_space, 7)
        harness.test(("insert_records", records), expected_outputs=None)

        state = harness.get_variable_values(self.ring_buffer_variables)
        index_now = state["index"]
        episode_slots = state["episode-indices"]

        # Non-terminal inserts must leave the episode indices untouched.
        expected_indices[0] = 2
        recursive_assert_almost_equal(episode_slots, expected_indices)

        # 2 + 1 + 7 records at capacity 10: the write index is back at 0.
        self.assertEqual(index_now, 0)

        # One more terminal gives the terminal buffer the layout
        # [1 0 1 0 0 0 0 0 0 0].
        records = terminal_records(self.record_space, 1)
        harness.test(("insert_records", records), expected_outputs=None)

        # A second episode must now be registered.
        state = harness.get_variable_values(self.ring_buffer_variables)
        episode_count = state["num-episodes"]
        recursive_assert_almost_equal(episode_count, 2)

        # Fetch both episodes. Expected result:
        # - `capacity` time steps
        # - terminal flags set at positions 0 and 2, due to the insertion order
        episodes = harness.test(("get_episodes", 2), expected_outputs=None)
        self.assertEqual(len(episodes['terminals']), self.capacity)
        self.assertEqual(episodes['terminals'][0], True)
        self.assertEqual(episodes['terminals'][2], True)