def test_latest_batch(self):
    """
    Checks that the ring buffer returns the most recently inserted records.

    The scenario is repeated for each backend setting in (None, "python").
    """
    for backend in (None, "python"):
        buffer_component = RingBuffer(capacity=self.capacity, backend=backend)
        test = ComponentTest(component=buffer_component, input_spaces=self.input_spaces)

        # Put 5 generated records in and read the same 5 back.
        inserted = non_terminal_records(self.record_space, 5)
        test.test(("insert_records", inserted), expected_outputs=None)
        fetched = test.test(("get_records", 5), expected_outputs=None)
        recursive_assert_almost_equal(fetched, inserted)

        # Insert `capacity` additional records.
        inserted = non_terminal_records(self.record_space, self.capacity)
        test.test(("insert_records", inserted), expected_outputs=None)

        # Fetching `capacity` records is expected to return exactly the
        # records of the latest insert.
        fetched = test.test(("get_records", self.capacity), expected_outputs=None)
        recursive_assert_almost_equal(fetched, inserted)

        # Fetching n records is expected to return the trailing n records.
        for n in range(1, 6):
            tail = slice(-n, None)
            fetched = test.test(("get_records", n), expected_outputs=None)
            recursive_assert_almost_equal(
                fetched["actions"]["action1"], inserted["actions"]["action1"][tail]
            )
            recursive_assert_almost_equal(
                fetched["states"]["state2"], inserted["states"]["state2"][tail]
            )
            recursive_assert_almost_equal(fetched["terminals"], inserted["terminals"][tail])
def test_batch_retrieve(self):
    """
    Checks that record retrieval from the prioritized replay respects capacity.

    The records are read from the first entry of the `get_records` output.
    """
    replay = PrioritizedReplay(capacity=self.capacity, alpha=self.alpha, beta=self.beta)
    test = ComponentTest(component=replay, input_spaces=self.input_spaces)

    # Start with 2 elements.
    inserted = non_terminal_records(self.record_space, 2)
    test.test(("insert_records", inserted), expected_outputs=None)

    # Those 2 elements can be fetched.
    fetched = test.test(("get_records", 2), expected_outputs=None)[0]
    print('Result batch = {}'.format(fetched))
    self.assertEqual(2, len(fetched['terminals']))

    # Asking for more than was inserted works because sampling may repeat
    # indices.
    fetched = test.test(("get_records", 5), expected_outputs=None)[0]
    self.assertEqual(5, len(fetched['terminals']))

    # Push the memory over capacity; every inserted element is non-terminal.
    inserted = non_terminal_records(self.record_space, self.capacity)
    test.test(("insert_records", inserted), expected_outputs=None)

    # Exactly `capacity` elements can be fetched.
    fetched = test.test(("get_records", self.capacity), expected_outputs=None)[0]
    self.assertEqual(self.capacity, len(fetched['terminals']))
def test_latest_batch(self):
    """
    Tests if we can fetch the latest steps from the ring buffer.

    Inserts 5 records and checks that 5 can be read back, then inserts
    `self.capacity` more records and checks that every `action1` value of
    that second insert is present when `self.capacity` records are fetched.
    The order of the fetched records is not checked.
    """
    ring_buffer = RingBuffer(capacity=self.capacity)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Insert 5 random elements.
    observation = non_terminal_records(self.record_space, 5)
    test.test(("insert_records", observation), expected_outputs=None)

    # First, test if the basic computation works.
    batch = test.test(("get_records", 5), expected_outputs=None)
    self.assertEqual(len(batch['terminals']), 5)

    # Next, insert capacity more elements.
    observation = non_terminal_records(self.record_space, self.capacity)
    test.test(("insert_records", observation), expected_outputs=None)

    # If we now fetch capacity elements, we expect to see exactly the
    # elements of the last insert.
    batch = test.test(("get_records", self.capacity), expected_outputs=None)

    # Assert every inserted element is contained, even if not in same order.
    # assertIn performs the same membership check as `assertTrue(x in y)` but
    # names the missing value and the container when it fails.
    retrieved_action = batch['actions']['action1']
    for action_value in observation['actions']['action1']:
        self.assertIn(action_value, retrieved_action)
def test_batch_retrieve(self):
    """
    Tests if retrieval from the replay memory correctly manages capacity.

    `get_records` is unpacked into three outputs here; only the first (the
    record batch) is inspected.
    """
    memory = ReplayMemory(capacity=self.capacity)
    test = ComponentTest(component=memory, input_spaces=self.input_spaces)

    # Insert 2 elements.
    observation = non_terminal_records(self.record_space, 2)
    test.test(("insert_records", observation), expected_outputs=None)

    # Assert we can now fetch 2 elements.
    num_records = 2
    batch, _, _ = test.test(("get_records", num_records), expected_outputs=None)
    print('Result batch = {}'.format(batch))
    self.assertEqual(2, len(batch['terminals']))

    # Assert the next-states key is there. assertIn reports the missing key
    # and the container on failure instead of a bare "False is not true".
    self.assertIn('next_states', batch)

    # Test duplicate sampling: request more records than were inserted.
    num_records = 5
    batch, _, _ = test.test(("get_records", num_records), expected_outputs=None)
    self.assertEqual(5, len(batch['terminals']))

    # Now insert over capacity.
    observation = non_terminal_records(self.record_space, self.capacity)
    test.test(("insert_records", observation), expected_outputs=None)

    # Assert we can fetch exactly capacity elements.
    num_records = self.capacity
    batch, _, _ = test.test(("get_records", num_records), expected_outputs=None)
    self.assertEqual(self.capacity, len(batch['terminals']))
def test_capacity_with_episodes(self):
    """
    Checks that inserting only non-terminal records works.

    Episode semantics themselves are not exercised here; other tests cover
    them.
    """
    ring_buffer = RingBuffer(capacity=self.capacity)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)
    variable_keys = ("size", "index", "num-episodes", "episode-indices")

    # Internal memory variables before anything has been inserted.
    snapshot = test.get_variable_values(ring_buffer, self.ring_buffer_variables)
    size, index, episode_count, episode_positions = (snapshot[key] for key in variable_keys)

    # Everything starts out at zero.
    self.assertEqual(size, 0)
    self.assertEqual(index, 0)
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_positions), 0)

    # Insert one record more than the capacity. Unlike the replay test, it
    # matters here whether the records are terminal: this checks that the
    # episode index update copes with an insert containing no terminals.
    inserted = non_terminal_records(self.record_space, self.capacity + 1)
    test.test(("insert_records", inserted), expected_outputs=None)

    snapshot = test.get_variable_values(ring_buffer, self.ring_buffer_variables)
    size, index, episode_count, episode_positions = (snapshot[key] for key in variable_keys)

    # A full buffer reports its capacity as size.
    self.assertEqual(size, self.capacity)
    # The write index is expected to be 1 after capacity + 1 inserts (modulo).
    self.assertEqual(index, 1)
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_positions), 0)

    # Fetching n records is expected to return the trailing n records.
    for n in range(1, 6):
        tail = slice(-n, None)
        fetched = test.test(("get_records", n), expected_outputs=None)
        recursive_assert_almost_equal(
            fetched["actions"]["action1"], inserted["actions"]["action1"][tail]
        )
        recursive_assert_almost_equal(
            fetched["states"]["state2"], inserted["states"]["state2"][tail]
        )
        recursive_assert_almost_equal(fetched["terminals"], inserted["terminals"][tail])
def test_episode_fetching(self):
    """
    Test if we can accurately fetch the most recent episodes.

    Uses a ring buffer created with `episode_semantics=True`, inserts records
    ending in a terminal, and inspects the output of `get_episodes`.
    """
    ring_buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Insert 2 non-terminals, 1 terminal.
    observation = non_terminal_records(self.record_space, 2)
    test.test(("insert_records", observation), expected_outputs=None)
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    # We should now be able to retrieve one episode.
    # NOTE(review): the original comment spoke of an episode "of length 3",
    # yet the check expects 2 entries under 'reward' - confirm which is
    # intended. The checked value is left unchanged here.
    # assertEqual is used so a failure shows the actual length.
    episode = test.test(("get_episodes", 1), expected_outputs=None)
    self.assertEqual(len(episode['reward']), 2)

    # We should not be able to retrieve two episodes, and still return just one.
    episode = test.test(("get_episodes", 2), expected_outputs=None)
    self.assertEqual(len(episode['reward']), 2)

    # Insert 7 non-terminals, 1 terminal -> last terminal is now at buffer
    # index 0 as we inserted 3 + 8 = 11 elements in total (this reasoning
    # presumes a capacity of 10 - TODO confirm against the test setup).
    observation = non_terminal_records(self.record_space, 7)
    test.test(("insert_records", observation), expected_outputs=None)
    observation = terminal_records(self.record_space, 1)
    test.test(("insert_records", observation), expected_outputs=None)

    # Check if we can fetch 2 episodes.
    episodes = test.test(("get_episodes", 2), expected_outputs=None)

    # We now expect to have retrieved:
    # - `capacity` time steps
    # - terminal flags set at positions 0 and 2 due to the insertion order
    self.assertEqual(len(episodes['terminals']), self.capacity)
    self.assertEqual(episodes['terminals'][0], True)
    self.assertEqual(episodes['terminals'][2], True)
def test_capacity_with_episodes(self):
    """
    Checks that inserting only non-terminal records works with episode
    semantics switched on.

    Episode semantics themselves are not exercised here; other tests cover
    them.
    """
    ring_buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Handles to the internal memory variables, in the order they are read.
    variables = ring_buffer.get_variables(self.ring_buffer_variables, global_scope=False)
    handles = tuple(
        variables[key] for key in ("size", "index", "num-episodes", "episode-indices")
    )

    # Everything starts out at zero.
    size, index, episode_count, episode_positions = test.read_variable_values(*handles)
    self.assertEqual(size, 0)
    self.assertEqual(index, 0)
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_positions), 0)

    # Insert one record more than the capacity. Unlike the replay test, it
    # matters here whether the records are terminal: this checks that the
    # episode index update copes with an insert containing no terminals.
    inserted = non_terminal_records(self.record_space, self.capacity + 1)
    test.test(("insert_records", inserted), expected_outputs=None)

    size, index, episode_count, episode_positions = test.read_variable_values(*handles)

    # A full buffer reports its capacity as size.
    self.assertEqual(size, self.capacity)
    # The write index is expected to be 1 after capacity + 1 inserts (modulo).
    self.assertEqual(index, 1)
    self.assertEqual(episode_count, 0)
    self.assertEqual(np.sum(episode_positions), 0)
def test_update_records(self):
    """
    Exercises the `update_records` call of the prioritized replay.

    Only checks that the call goes through; no output is compared.
    """
    replay = PrioritizedReplay(capacity=self.capacity)
    test = ComponentTest(component=replay, input_spaces=self.input_spaces)

    # Put a couple of elements in.
    inserted = non_terminal_records(self.record_space, 2)
    test.test(("insert_records", inserted), expected_outputs=None)

    # Fetch them; the indices are read from the second entry of the output.
    requested = 2
    fetched = test.test(("get_records", requested), expected_outputs=None)
    record_indices = fetched[1]
    self.assertEqual(requested, len(record_indices))

    # Send one update value per fetched index. Nothing is returned.
    update_values = np.asarray([0.1, 0.2])
    test.test(("update_records", [record_indices, update_values]), expected_outputs=None)
def test_episode_indices_when_inserting(self):
    """
    Checks the episode count and episode indices after inserting terminals
    into a ring buffer created with `episode_semantics=True`.
    """
    ring_buffer = RingBuffer(capacity=self.capacity, episode_semantics=True)
    test = ComponentTest(component=ring_buffer, input_spaces=self.input_spaces)

    # Handles to the internal memory variables, in the order they are read.
    variables = ring_buffer.get_variables(self.ring_buffer_variables, global_scope=False)
    handles = tuple(
        variables[key] for key in ("size", "index", "num-episodes", "episode-indices")
    )

    # Begin with a single terminal record.
    inserted = terminal_records(self.record_space, 1)
    test.test(("insert_records", inserted), expected_outputs=None)

    # Size and index are read along with the rest but not checked.
    _, _, episode_count, episode_positions = test.read_variable_values(*handles)

    # One episode is registered ...
    self.assertEqual(episode_count, 1)
    # ... but its index is 0, so the index values still sum to 0.
    self.assertEqual(sum(episode_positions), 0)

    # Follow up with 1 non-terminal and then 1 terminal record.
    inserted = non_terminal_records(self.record_space, 1)
    test.test(("insert_records", inserted), expected_outputs=None)
    inserted = terminal_records(self.record_space, 1)
    test.test(("insert_records", inserted), expected_outputs=None)

    # Two episodes are expected now, with episode indices 0 and 2.
    _, _, episode_count, episode_positions = test.read_variable_values(*handles)
    print('Episode indices after = {}'.format(episode_positions))
    self.assertEqual(episode_count, 2)
    self.assertEqual(episode_positions[1], 2)
def test_episode_indices_when_inserting(self):
    """
    Checks the episode count and episode indices after inserting terminals.

    The scenario is repeated for each backend setting in (None, "python").
    """
    for backend in (None, "python"):
        buffer_component = RingBuffer(capacity=self.capacity, backend=backend)
        test = ComponentTest(component=buffer_component, input_spaces=self.input_spaces)

        # Begin with a single terminal record.
        inserted = terminal_records(self.record_space, 1)
        test.test(("insert_records", inserted), expected_outputs=None)

        # Read the internal memory variables.
        snapshot = test.get_variable_values(self.ring_buffer_variables)
        episode_count = snapshot["num-episodes"]
        episode_positions = snapshot["episode-indices"]

        # One episode is registered ...
        self.assertEqual(episode_count, 1)
        # ... but its index is 0, so the index values still sum to 0.
        self.assertEqual(sum(episode_positions), 0)

        # Follow up with 1 non-terminal and then 1 terminal record.
        inserted = non_terminal_records(self.record_space, 1)
        test.test(("insert_records", inserted), expected_outputs=None)
        inserted = terminal_records(self.record_space, 1)
        test.test(("insert_records", inserted), expected_outputs=None)

        # Two episodes are expected now, with episode indices 0 and 2.
        snapshot = test.get_variable_values(self.ring_buffer_variables)
        episode_count = snapshot["num-episodes"]
        episode_positions = snapshot["episode-indices"]
        print('Episode indices after = {}'.format(episode_positions))
        self.assertEqual(episode_count, 2)
        self.assertEqual(episode_positions[1], 2)
def test_episode_fetching(self):
    """
    Checks that the most recent episodes can be fetched from the ring buffer.

    The scenario is repeated for each backend setting in (None, "python").
    """
    for backend in (None, "python"):
        buffer_component = RingBuffer(capacity=self.capacity, backend=backend)
        test = ComponentTest(component=buffer_component, input_spaces=self.input_spaces)

        # Insert 2 non-terminals followed by 1 terminal.
        inserted = non_terminal_records(self.record_space, 2)
        test.test(("insert_records", inserted), expected_outputs=None)
        inserted = terminal_records(self.record_space, 1)
        test.test(("insert_records", inserted), expected_outputs=None)

        snapshot = test.get_variable_values(self.ring_buffer_variables)
        episode_count = snapshot["num-episodes"]
        episode_positions = snapshot["episode-indices"]

        # Exactly one episode, recorded at index 2; all other slots are 0.
        self.assertEqual(episode_count, 1)
        expected_positions = [0] * self.capacity
        expected_positions[0] = 2
        recursive_assert_almost_equal(episode_positions, expected_positions)

        # One episode of length 3 can be retrieved.
        one_episode_terminals = [0, 0, 1]
        fetched = test.test(("get_episodes", 1), expected_outputs=None)
        recursive_assert_almost_equal(fetched["terminals"], one_episode_terminals)

        # Asking for two episodes still yields only the single one.
        one_episode_terminals = [0, 0, 1]
        fetched = test.test(("get_episodes", 2), expected_outputs=None)
        recursive_assert_almost_equal(fetched["terminals"], one_episode_terminals)

        # Insert 7 non-terminals.
        inserted = non_terminal_records(self.record_space, 7)
        test.test(("insert_records", inserted), expected_outputs=None)

        snapshot = test.get_variable_values(self.ring_buffer_variables)
        write_index = snapshot["index"]
        episode_positions = snapshot["episode-indices"]

        # The episode indices must be unchanged by the non-terminal insert.
        recursive_assert_almost_equal(episode_positions, expected_positions)
        # 2 non-terminals + 1 terminal + 7 non-terminals at capacity 10 puts
        # the write index back at 0.
        self.assertEqual(write_index, 0)

        # One more terminal gives the terminal buffer the layout
        # [1 0 1 0 0 0 0 0 0 0].
        inserted = terminal_records(self.record_space, 1)
        test.test(("insert_records", inserted), expected_outputs=None)

        # Two episodes are registered now.
        snapshot = test.get_variable_values(self.ring_buffer_variables)
        episode_count = snapshot["num-episodes"]
        recursive_assert_almost_equal(episode_count, 2)

        # Fetch both episodes. Expected:
        # - `capacity` time steps
        # - terminal flags set at positions 0 and 2 due to the insertion order
        fetched = test.test(("get_episodes", 2), expected_outputs=None)
        self.assertEqual(len(fetched['terminals']), self.capacity)
        self.assertEqual(fetched['terminals'][0], True)
        self.assertEqual(fetched['terminals'][2], True)
def test_segment_tree_insert_values(self):
    """
    Checks that inserts land at the expected positions of the sum and min
    segment trees of the prioritized replay.
    """
    replay = PrioritizedReplay(capacity=self.capacity, alpha=self.alpha, beta=self.beta)
    test = ComponentTest(component=replay, input_spaces=self.input_spaces)

    # Double from 1 until the memory capacity is reached or exceeded.
    priority_capacity = 1
    while priority_capacity < self.capacity:
        priority_capacity *= 2

    tree_variables = replay.get_variables(
        ["sum-segment-tree", "min-segment-tree"], global_scope=False
    )
    sum_tree = tree_variables['sum-segment-tree']
    min_tree = tree_variables['min-segment-tree']

    def read_trees():
        # Current contents of both trees as a (sum values, min values) pair.
        return test.read_variable_values(sum_tree, min_tree)

    # Before any insert: sum tree totals 0, min tree totals inf, and both
    # hold 2 * priority_capacity entries.
    sum_values, min_values = read_trees()
    self.assertEqual(sum(sum_values), 0)
    self.assertEqual(sum(min_values), float('inf'))
    self.assertEqual(len(sum_values), 2 * priority_capacity)
    self.assertEqual(len(min_values), 2 * priority_capacity)

    # Insert the first element and re-read the trees.
    inserted = non_terminal_records(self.record_space, 1)
    test.test(("insert_records", inserted), expected_outputs=None)
    sum_values, min_values = read_trees()
    print(sum_values)
    print(min_values)

    # The first insert sits at index priority_capacity; walking up by
    # halving the index, every visited node must hold 1.0 in both trees.
    node = priority_capacity
    while node >= 1:
        self.assertEqual(sum_values[node], 1.0)
        self.assertEqual(min_values[node], 1.0)
        node //= 2

    # Insert a second element and re-read the trees.
    inserted = non_terminal_records(self.record_space, 1)
    test.test(("insert_records", inserted), expected_outputs=None)
    sum_values, min_values = read_trees()
    print(sum_values)
    print(min_values)

    # The second insert sits one index further along.
    node = priority_capacity + 1
    self.assertEqual(sum_values[node], 1.0)
    self.assertEqual(min_values[node], 1.0)

    # Every node above it sums both inserts (1 + 1 = 2); the min stays 1.
    node //= 2
    while node >= 1:
        self.assertEqual(sum_values[node], 2.0)
        self.assertEqual(min_values[node], 1.0)
        node //= 2