예제 #1
0
    def test_latest_batch(self):
        """
        Checks that `get_records` on a ring buffer returns the newest records.

        The scenario is executed once for each backend in `(None, "python")`.
        """
        for backend in (None, "python"):
            buffer_component = RingBuffer(capacity=self.capacity, backend=backend)
            harness = ComponentTest(component=buffer_component, input_spaces=self.input_spaces)

            # Store 5 non-terminal records and read all 5 straight back.
            inserted = non_terminal_records(self.record_space, 5)
            harness.test(("insert_records", inserted), expected_outputs=None)
            fetched = harness.test(("get_records", 5), expected_outputs=None)
            recursive_assert_almost_equal(fetched, inserted)

            # Push a further `capacity` records into the buffer.
            inserted = non_terminal_records(self.record_space, self.capacity)
            harness.test(("insert_records", inserted), expected_outputs=None)

            # A fetch of `capacity` records must now equal the batch just inserted.
            fetched = harness.test(("get_records", self.capacity), expected_outputs=None)
            recursive_assert_almost_equal(fetched, inserted)

            # A fetch of n records must equal the trailing n of the last insert.
            for n in range(1, 6):
                tail = slice(-n, None)
                fetched = harness.test(("get_records", n), expected_outputs=None)
                recursive_assert_almost_equal(fetched["actions"]["action1"], inserted["actions"]["action1"][tail])
                recursive_assert_almost_equal(fetched["states"]["state2"], inserted["states"]["state2"][tail])
                recursive_assert_almost_equal(fetched["terminals"], inserted["terminals"][tail])
예제 #2
0
    def test_batch_retrieve(self):
        """
        Checks how many records `get_records` returns from a prioritized
        replay memory, both before and after inserting `capacity` records.
        """
        replay = PrioritizedReplay(
            capacity=self.capacity,
            alpha=self.alpha,
            beta=self.beta
        )
        harness = ComponentTest(component=replay, input_spaces=self.input_spaces)

        # Store two non-terminal records.
        inserted = non_terminal_records(self.record_space, 2)
        harness.test(("insert_records", inserted), expected_outputs=None)

        # Asking for 2 records must yield 2; element 0 of the result holds the records.
        result = harness.test(("get_records", 2), expected_outputs=None)
        sampled = result[0]
        print('Result batch = {}'.format(sampled))
        self.assertEqual(2, len(sampled['terminals']))

        # Asking for more than was stored still yields the requested count
        # (repeated indices are allowed when sampling).
        result = harness.test(("get_records", 5), expected_outputs=None)
        sampled = result[0]
        self.assertEqual(5, len(sampled['terminals']))

        # Insert `capacity` further records; all of them are non-terminal.
        inserted = non_terminal_records(self.record_space, self.capacity)
        harness.test(("insert_records", inserted), expected_outputs=None)

        # Exactly `capacity` records must be retrievable.
        result = harness.test(("get_records", self.capacity), expected_outputs=None)
        sampled = result[0]
        self.assertEqual(self.capacity, len(sampled['terminals']))
예제 #3
0
    def test_latest_batch(self):
        """
        Tests if we can fetch latest steps.

        After inserting more records than the buffer holds, every action value
        of the last inserted batch must be present in a fetch of `capacity`
        records. The order of the returned records is not checked.
        """
        ring_buffer = RingBuffer(capacity=self.capacity)
        test = ComponentTest(component=ring_buffer,
                             input_spaces=self.input_spaces)

        # Insert 5 random elements.
        observation = non_terminal_records(self.record_space, 5)
        test.test(("insert_records", observation), expected_outputs=None)

        # First, check that we get back as many records as we asked for.
        batch = test.test(("get_records", 5), expected_outputs=None)
        self.assertEqual(len(batch['terminals']), 5)

        # Next, insert capacity more elements:
        observation = non_terminal_records(self.record_space, self.capacity)
        test.test(("insert_records", observation), expected_outputs=None)

        # If we now fetch capacity elements, we expect to see exactly the
        # capacity elements inserted last.
        batch = test.test(("get_records", self.capacity),
                          expected_outputs=None)

        # Assert every inserted element is contained, even if not in same order.
        # assertIn names the missing value on failure, unlike assertTrue(x in y).
        retrieved_action = batch['actions']['action1']
        for action_value in observation['actions']['action1']:
            self.assertIn(action_value, retrieved_action)
예제 #4
0
    def test_batch_retrieve(self):
        """
        Checks how many records `get_records` returns from a replay memory,
        both before and after inserting `capacity` records.
        """
        replay = ReplayMemory(capacity=self.capacity)
        harness = ComponentTest(component=replay, input_spaces=self.input_spaces)

        # Store two non-terminal records.
        inserted = non_terminal_records(self.record_space, 2)
        harness.test(("insert_records", inserted), expected_outputs=None)

        # Asking for 2 records must yield 2; the call returns a 3-tuple whose
        # first element is the record batch.
        sampled, _, _ = harness.test(("get_records", 2), expected_outputs=None)
        print('Result batch = {}'.format(sampled))
        self.assertEqual(2, len(sampled['terminals']))
        # The batch must also carry a 'next_states' entry.
        self.assertTrue('next_states' in sampled)

        # Asking for more than was stored still yields the requested count
        # (duplicate sampling).
        sampled, _, _ = harness.test(("get_records", 5), expected_outputs=None)
        self.assertEqual(5, len(sampled['terminals']))

        # Insert `capacity` further records.
        inserted = non_terminal_records(self.record_space, self.capacity)
        harness.test(("insert_records", inserted), expected_outputs=None)

        # Exactly `capacity` records must be retrievable.
        sampled, _, _ = harness.test(("get_records", self.capacity), expected_outputs=None)
        self.assertEqual(self.capacity, len(sampled['terminals']))
예제 #5
0
    def test_capacity_with_episodes(self):
        """
        Checks the ring buffer's internal counters when only non-terminal
        records are inserted, including one record more than `capacity`.

        Episode semantics themselves are not covered here.
        """
        buffer_component = RingBuffer(capacity=self.capacity)
        harness = ComponentTest(component=buffer_component, input_spaces=self.input_spaces)
        tracked_keys = ("size", "index", "num-episodes", "episode-indices")

        # Snapshot of the internal variables on the freshly built buffer.
        state = harness.get_variable_values(buffer_component, self.ring_buffer_variables)
        size_now, index_now, episodes_now, indices_now = (state[key] for key in tracked_keys)

        # Everything must be zero before the first insert.
        self.assertEqual(size_now, 0)
        self.assertEqual(index_now, 0)
        self.assertEqual(episodes_now, 0)
        self.assertEqual(np.sum(indices_now), 0)

        # Insert capacity + 1 non-terminal records. None of them is terminal,
        # so this exercises episode index updating without any episode ending.
        inserted = non_terminal_records(self.record_space, self.capacity + 1)
        harness.test(("insert_records", inserted), expected_outputs=None)

        state = harness.get_variable_values(buffer_component, self.ring_buffer_variables)
        size_now, index_now, episodes_now, indices_now = (state[key] for key in tracked_keys)

        # A full buffer reports `capacity` as its size.
        self.assertEqual(size_now, self.capacity)

        # The write index has wrapped around to 1.
        self.assertEqual(index_now, 1)
        self.assertEqual(episodes_now, 0)
        self.assertEqual(np.sum(indices_now), 0)

        # A fetch of n records must equal the trailing n of the insert.
        for n in range(1, 6):
            tail = slice(-n, None)
            fetched = harness.test(("get_records", n), expected_outputs=None)
            recursive_assert_almost_equal(fetched["actions"]["action1"], inserted["actions"]["action1"][tail])
            recursive_assert_almost_equal(fetched["states"]["state2"], inserted["states"]["state2"][tail])
            recursive_assert_almost_equal(fetched["terminals"], inserted["terminals"][tail])
예제 #6
0
    def test_episode_fetching(self):
        """
        Test if we can accurately fetch most recent episodes.

        Builds a ring buffer with `episode_semantics=True`, inserts two
        terminated episodes (3 and 8 records) and checks what `get_episodes`
        returns for 1 and 2 requested episodes.
        """
        ring_buffer = RingBuffer(capacity=self.capacity,
                                 episode_semantics=True)
        test = ComponentTest(component=ring_buffer,
                             input_spaces=self.input_spaces)
        # Insert 2 non-terminals, 1 terminal
        observation = non_terminal_records(self.record_space, 2)
        test.test(("insert_records", observation), expected_outputs=None)
        observation = terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        # We should now be able to retrieve one episode of length 3.
        # NOTE(review): the assertion below expects 2 entries under 'reward',
        # not 3 - confirm which of comment and assertion is intended.
        episode = test.test(("get_episodes", 1), expected_outputs=None)
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(len(episode['reward']), 2)

        # We should not be able to retrieve two episodes, and still return just one.
        episode = test.test(("get_episodes", 2), expected_outputs=None)
        self.assertEqual(len(episode['reward']), 2)

        # Insert 7 non-terminals, 1 terminal -> last terminal is now at buffer index 0 as
        # we inserted 3 + 8 = 11 elements in total.
        observation = non_terminal_records(self.record_space, 7)
        test.test(("insert_records", observation), expected_outputs=None)
        observation = terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        # Check if we can fetch 2 episodes:
        episodes = test.test(("get_episodes", 2), expected_outputs=None)

        # We now expect to have retrieved:
        # - capacity time steps
        # - 2 terminal values 1
        # - Terminal values spaced apart 1 index due to the insertion order
        self.assertEqual(len(episodes['terminals']), self.capacity)
        self.assertEqual(episodes['terminals'][0], True)
        self.assertEqual(episodes['terminals'][2], True)
예제 #7
0
    def test_capacity_with_episodes(self):
        """
        Checks the ring buffer's internal counters with episode semantics
        switched on, when only non-terminal records are inserted.

        Episode semantics themselves are not covered here.
        """
        buffer_component = RingBuffer(capacity=self.capacity, episode_semantics=True)
        harness = ComponentTest(component=buffer_component, input_spaces=self.input_spaces)

        # Handles to the internal variables, read again after each step.
        handles = buffer_component.get_variables(self.ring_buffer_variables, global_scope=False)
        watched = (
            handles["size"],
            handles["index"],
            handles["num-episodes"],
            handles["episode-indices"],
        )

        size_now, index_now, episodes_now, indices_now = harness.read_variable_values(*watched)

        # Everything must be zero before the first insert.
        self.assertEqual(size_now, 0)
        self.assertEqual(index_now, 0)
        self.assertEqual(episodes_now, 0)
        self.assertEqual(np.sum(indices_now), 0)

        # Insert capacity + 1 non-terminal records. None of them is terminal,
        # so this exercises episode index updating without any episode ending.
        inserted = non_terminal_records(self.record_space, self.capacity + 1)
        harness.test(("insert_records", inserted), expected_outputs=None)
        size_now, index_now, episodes_now, indices_now = harness.read_variable_values(*watched)

        # A full buffer reports `capacity` as its size.
        self.assertEqual(size_now, self.capacity)

        # The write index has wrapped around to 1.
        self.assertEqual(index_now, 1)
        self.assertEqual(episodes_now, 0)
        self.assertEqual(np.sum(indices_now), 0)
예제 #8
0
    def test_update_records(self):
        """
        Smoke test for `update_records` on a prioritized replay memory.

        Only checks that the call runs; no output is compared.
        """
        replay = PrioritizedReplay(capacity=self.capacity)
        harness = ComponentTest(component=replay, input_spaces=self.input_spaces)

        # Store two non-terminal records.
        inserted = non_terminal_records(self.record_space, 2)
        harness.test(("insert_records", inserted), expected_outputs=None)

        # Sample them back; element 1 of the result holds the sampled indices.
        requested = 2
        result = harness.test(("get_records", requested), expected_outputs=None)
        sampled_indices = result[1]
        self.assertEqual(requested, len(sampled_indices))

        # Pass one new value per sampled index; the call returns nothing to check.
        new_values = np.asarray([0.1, 0.2])
        harness.test(("update_records", [sampled_indices, new_values]), expected_outputs=None)
예제 #9
0
    def test_episode_indices_when_inserting(self):
        """
        Verifies the episode count and the stored episode indices of a ring
        buffer with episode semantics as terminal records are inserted.
        """
        buffer_component = RingBuffer(capacity=self.capacity, episode_semantics=True)
        harness = ComponentTest(component=buffer_component, input_spaces=self.input_spaces)

        # Handles to the internal variables, read again after each step.
        handles = buffer_component.get_variables(self.ring_buffer_variables, global_scope=False)
        watched = (
            handles["size"],
            handles["index"],
            handles["num-episodes"],
            handles["episode-indices"],
        )

        # A lone terminal record goes in first.
        harness.test(("insert_records", terminal_records(self.record_space, 1)), expected_outputs=None)
        _, _, episode_count, stored_indices = harness.read_variable_values(*watched)

        # That is one episode ...
        self.assertEqual(episode_count, 1)
        # ... whose index is 0, so the index array still sums to zero.
        self.assertEqual(sum(stored_indices), 0)

        # Follow up with one non-terminal and then one terminal record.
        harness.test(("insert_records", non_terminal_records(self.record_space, 1)), expected_outputs=None)
        harness.test(("insert_records", terminal_records(self.record_space, 1)), expected_outputs=None)

        # Two episodes are expected now, with episode indices 0 and 2.
        _, _, episode_count, stored_indices = harness.read_variable_values(*watched)
        print('Episode indices after = {}'.format(stored_indices))
        self.assertEqual(episode_count, 2)
        self.assertEqual(stored_indices[1], 2)
예제 #10
0
    def test_episode_indices_when_inserting(self):
        """
        Verifies the episode count and the stored episode indices of a ring
        buffer as terminal records are inserted.

        The scenario is executed once for each backend in `(None, "python")`.
        """
        for backend in (None, "python"):
            buffer_component = RingBuffer(capacity=self.capacity, backend=backend)
            harness = ComponentTest(component=buffer_component, input_spaces=self.input_spaces)

            # A lone terminal record goes in first.
            harness.test(("insert_records", terminal_records(self.record_space, 1)), expected_outputs=None)

            # Read the internal bookkeeping variables.
            state = harness.get_variable_values(self.ring_buffer_variables)
            episode_count = state["num-episodes"]
            stored_indices = state["episode-indices"]

            # That is one episode ...
            self.assertEqual(episode_count, 1)
            # ... whose index is 0, so the index array still sums to zero.
            self.assertEqual(sum(stored_indices), 0)

            # Follow up with one non-terminal and then one terminal record.
            harness.test(("insert_records", non_terminal_records(self.record_space, 1)), expected_outputs=None)
            harness.test(("insert_records", terminal_records(self.record_space, 1)), expected_outputs=None)

            # Two episodes are expected now, with episode indices 0 and 2.
            state = harness.get_variable_values(self.ring_buffer_variables)
            episode_count = state["num-episodes"]
            stored_indices = state["episode-indices"]

            print('Episode indices after = {}'.format(stored_indices))
            self.assertEqual(episode_count, 2)
            self.assertEqual(stored_indices[1], 2)
예제 #11
0
    def test_episode_fetching(self):
        """
        Verifies that `get_episodes` returns the most recent episodes and that
        the buffer's episode bookkeeping matches the inserted terminals.

        The scenario is executed once for each backend in `(None, "python")`.
        """
        for backend in (None, "python"):
            buffer_component = RingBuffer(capacity=self.capacity, backend=backend)
            harness = ComponentTest(component=buffer_component, input_spaces=self.input_spaces)

            # First episode: two non-terminal records followed by one terminal.
            harness.test(("insert_records", non_terminal_records(self.record_space, 2)), expected_outputs=None)
            harness.test(("insert_records", terminal_records(self.record_space, 1)), expected_outputs=None)

            state = harness.get_variable_values(self.ring_buffer_variables)
            episode_count = state["num-episodes"]
            stored_indices = state["episode-indices"]

            # Exactly one episode, recorded at buffer position 2.
            self.assertEqual(episode_count, 1)
            expected_indices = [0] * self.capacity
            expected_indices[0] = 2
            recursive_assert_almost_equal(stored_indices, expected_indices)

            # Requesting one episode yields the 3 records with the terminal last.
            fetched = harness.test(("get_episodes", 1), expected_outputs=None)
            recursive_assert_almost_equal(fetched["terminals"], [0, 0, 1])

            # Requesting two episodes still yields just that single episode.
            fetched = harness.test(("get_episodes", 2), expected_outputs=None)
            recursive_assert_almost_equal(fetched["terminals"], [0, 0, 1])

            # Add 7 non-terminal records.
            harness.test(("insert_records", non_terminal_records(self.record_space, 7)), expected_outputs=None)

            state = harness.get_variable_values(self.ring_buffer_variables)
            write_index = state["index"]
            stored_indices = state["episode-indices"]

            # No terminal was added, so the episode indices are unchanged.
            expected_indices[0] = 2
            recursive_assert_almost_equal(stored_indices, expected_indices)
            # 2 + 1 + 7 inserts at capacity 10 put the write index back at 0.
            self.assertEqual(write_index, 0)

            # One more terminal gives the terminal layout [1 0 1 0 0 0 0 0 0 0].
            harness.test(("insert_records", terminal_records(self.record_space, 1)), expected_outputs=None)

            # The buffer must now report two episodes.
            state = harness.get_variable_values(self.ring_buffer_variables)
            recursive_assert_almost_equal(state["num-episodes"], 2)

            # Fetch both episodes.
            fetched = harness.test(("get_episodes", 2), expected_outputs=None)

            # Expected result:
            # - `capacity` time steps
            # - terminal flags set at positions 0 and 2 of the returned sequence
            self.assertEqual(len(fetched['terminals']), self.capacity)
            self.assertEqual(fetched['terminals'][0], True)
            self.assertEqual(fetched['terminals'][2], True)
예제 #12
0
    def test_segment_tree_insert_values(self):
        """
        Tests if segment tree inserts into correct positions.

        The sum and min segment trees of a prioritized replay memory are read
        before any insert, after one insert and after a second insert. Each
        time the leaf and every node on its path to the root (index 1) are
        checked.
        """
        memory = PrioritizedReplay(capacity=self.capacity,
                                   alpha=self.alpha,
                                   beta=self.beta)
        test = ComponentTest(component=memory, input_spaces=self.input_spaces)
        # Smallest power of two that is >= capacity.
        priority_capacity = 1
        while priority_capacity < self.capacity:
            priority_capacity *= 2

        memory_variables = memory.get_variables(
            ["sum-segment-tree", "min-segment-tree"], global_scope=False)
        sum_segment_tree = memory_variables['sum-segment-tree']
        min_segment_tree = memory_variables['min-segment-tree']
        sum_segment_values, min_segment_values = test.read_variable_values(
            sum_segment_tree, min_segment_tree)

        # Before any insert: sum tree totals 0, min tree totals inf, and both
        # trees have 2 * priority_capacity entries.
        self.assertEqual(sum(sum_segment_values), 0)
        self.assertEqual(sum(min_segment_values), float('inf'))
        self.assertEqual(len(sum_segment_values), 2 * priority_capacity)
        self.assertEqual(len(min_segment_values), 2 * priority_capacity)
        # Insert 1 Element.
        observation = non_terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        # Fetch segment tree.
        sum_segment_values, min_segment_values = test.read_variable_values(
            sum_segment_tree, min_segment_tree)

        # Check insert positions
        # Initial insert is at priority capacity
        print(sum_segment_values)
        print(min_segment_values)
        start = priority_capacity

        # Walk from the leaf up to the root; every node on the path must be 1.0.
        while start >= 1:
            self.assertEqual(sum_segment_values[start], 1.0)
            self.assertEqual(min_segment_values[start], 1.0)
            # Floor division keeps the parent index an exact int.
            start //= 2

        # Insert another Element.
        observation = non_terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        # Fetch segment tree.
        sum_segment_values, min_segment_values = test.read_variable_values(
            sum_segment_tree, min_segment_tree)
        print(sum_segment_values)
        print(min_segment_values)

        # Index shifted 1
        start = priority_capacity + 1
        self.assertEqual(sum_segment_values[start], 1.0)
        self.assertEqual(min_segment_values[start], 1.0)
        start //= 2
        while start >= 1:
            # 1 + 1 is 2 on the segment.
            self.assertEqual(sum_segment_values[start], 2.0)
            # min is still 1.
            self.assertEqual(min_segment_values[start], 1.0)
            start //= 2