Пример #1
0
    def test_episode_indices_when_inserting(self):
        """
        Checks that the episode counter and the episode-index variable are
        updated as expected when terminal records are inserted.
        """
        memory = RingBuffer(capacity=self.capacity, episode_semantics=True)
        harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

        def insert(records):
            # Send one batch through the "insert_records" API; outputs are not checked.
            harness.test(("insert_records", records), expected_outputs=None)

        # Handles to the buffer's internal variables, read back after each phase.
        variables = memory.get_variables(self.ring_buffer_variables, global_scope=False)
        tracked = (
            variables["size"],
            variables["index"],
            variables["num-episodes"],
            variables["episode-indices"],
        )

        # Phase 1: a single terminal record.
        insert(terminal_records(self.record_space, 1))
        # Size and index are fetched as well, but nothing is asserted on them.
        _, _, episode_count, episode_idx = harness.read_variable_values(*tracked)

        # One episode is counted, and the episode indices sum to 0.
        self.assertEqual(episode_count, 1)
        self.assertEqual(sum(episode_idx), 0)

        # Phase 2: one non-terminal record followed by one terminal record.
        insert(non_terminal_records(self.record_space, 1))
        insert(terminal_records(self.record_space, 1))

        # Two episodes are now counted; the second index entry must be 2.
        _, _, episode_count, episode_idx = harness.read_variable_values(*tracked)
        print('Episode indices after = {}'.format(episode_idx))
        self.assertEqual(episode_count, 2)
        self.assertEqual(episode_idx[1], 2)
Пример #2
0
    def test_episode_fetching(self):
        """
        Test if we can accurately fetch most recent episodes.

        Inserts records into a ring buffer created with episode_semantics=True
        and checks the output of the "get_episodes" API after each phase.
        """
        ring_buffer = RingBuffer(capacity=self.capacity,
                                 episode_semantics=True)
        test = ComponentTest(component=ring_buffer,
                             input_spaces=self.input_spaces)

        # Insert 2 non-terminals, 1 terminal.
        observation = non_terminal_records(self.record_space, 2)
        test.test(("insert_records", observation), expected_outputs=None)
        observation = terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        # Request a single episode.
        # NOTE(review): the earlier comment here spoke of an episode "of length 3",
        # yet the assertion expects 2 reward entries — confirm which is intended.
        episode = test.test(("get_episodes", 1), expected_outputs=None)
        # assertEqual (rather than assertTrue on a comparison) reports the actual
        # length when the check fails.
        self.assertEqual(len(episode['reward']), 2)

        # Requesting two episodes while only one is stored must still return just one.
        episode = test.test(("get_episodes", 2), expected_outputs=None)
        self.assertEqual(len(episode['reward']), 2)

        # Insert 7 non-terminals, 1 terminal: 3 + 8 = 11 records in total.
        # Presumably the capacity is 10, so the last terminal lands on buffer
        # index 0 — TODO confirm against the fixture.
        observation = non_terminal_records(self.record_space, 7)
        test.test(("insert_records", observation), expected_outputs=None)
        observation = terminal_records(self.record_space, 1)
        test.test(("insert_records", observation), expected_outputs=None)

        # Check if we can fetch 2 episodes:
        episodes = test.test(("get_episodes", 2), expected_outputs=None)

        # We now expect to have retrieved:
        # - as many time steps as the buffer's capacity
        # - terminal flags set at positions 0 and 2 of the returned batch
        self.assertEqual(len(episodes['terminals']), self.capacity)
        self.assertEqual(episodes['terminals'][0], True)
        self.assertEqual(episodes['terminals'][2], True)
Пример #3
0
    def test_episode_indices_when_inserting(self):
        """
        Checks the episode counter and the episode-index variable after
        terminal records are inserted, once per backend in (None, "python").
        """
        for backend in (None, "python"):
            memory = RingBuffer(capacity=self.capacity, backend=backend)
            harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

            # Phase 1: a single terminal record.
            harness.test(
                ("insert_records", terminal_records(self.record_space, 1)),
                expected_outputs=None
            )

            # Snapshot of the buffer's internal variables.
            snapshot = harness.get_variable_values(self.ring_buffer_variables)
            episode_count = snapshot["num-episodes"]
            episode_idx = snapshot["episode-indices"]

            # One episode is counted, and the episode indices sum to 0.
            self.assertEqual(episode_count, 1)
            self.assertEqual(sum(episode_idx), 0)

            # Phase 2: one non-terminal record followed by one terminal record.
            harness.test(
                ("insert_records", non_terminal_records(self.record_space, 1)),
                expected_outputs=None
            )
            harness.test(
                ("insert_records", terminal_records(self.record_space, 1)),
                expected_outputs=None
            )

            # Two episodes are now counted; the second index entry must be 2.
            snapshot = harness.get_variable_values(self.ring_buffer_variables)
            episode_count = snapshot["num-episodes"]
            episode_idx = snapshot["episode-indices"]

            print('Episode indices after = {}'.format(episode_idx))
            self.assertEqual(episode_count, 2)
            self.assertEqual(episode_idx[1], 2)
Пример #4
0
    def test_only_terminal_with_episodes(self):
        """
        Edge case: the buffer is filled with terminal records only, so every
        record should be counted as its own episode.

        NOTE(review): no episode_semantics argument is passed to RingBuffer
        here; presumably the default enables it — confirm.
        """
        memory = RingBuffer(capacity=self.capacity)
        harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

        # Handles to the two internal variables this test inspects.
        variables = memory.get_variables(self.ring_buffer_variables, global_scope=False)
        episode_count_var = variables["num-episodes"]
        episode_indices_var = variables["episode-indices"]

        # Insert exactly `capacity` terminal records in a single call.
        all_terminals = terminal_records(self.record_space, self.capacity)
        harness.test(("insert_records", all_terminals), expected_outputs=None)

        episode_count, episode_idx = harness.read_variable_values(
            episode_count_var, episode_indices_var
        )
        self.assertEqual(episode_count, self.capacity)

        # The i-th stored episode index must equal i.
        for position in range_(self.capacity):
            self.assertEqual(episode_idx[position], position)
Пример #5
0
    def test_episode_fetching(self):
        """
        Checks that the most recent episodes can be fetched via "get_episodes",
        once per backend in (None, "python").
        """
        for backend in (None, "python"):
            memory = RingBuffer(capacity=self.capacity, backend=backend)
            harness = ComponentTest(component=memory, input_spaces=self.input_spaces)

            # Phase 1: two non-terminal records, then one terminal record.
            harness.test(
                ("insert_records", non_terminal_records(self.record_space, 2)),
                expected_outputs=None
            )
            harness.test(
                ("insert_records", terminal_records(self.record_space, 1)),
                expected_outputs=None
            )

            snapshot = harness.get_variable_values(self.ring_buffer_variables)
            episode_count = snapshot["num-episodes"]
            episode_idx = snapshot["episode-indices"]

            # Exactly one episode, recorded at index 2; all other slots stay 0.
            self.assertEqual(episode_count, 1)
            expected_indices = [0] * self.capacity
            expected_indices[0] = 2
            recursive_assert_almost_equal(episode_idx, expected_indices)

            # Fetching one episode yields the three records just inserted.
            fetched = harness.test(("get_episodes", 1), expected_outputs=None)
            recursive_assert_almost_equal(fetched["terminals"], [0, 0, 1])

            # Asking for two episodes while only one exists returns the same result.
            fetched = harness.test(("get_episodes", 2), expected_outputs=None)
            recursive_assert_almost_equal(fetched["terminals"], [0, 0, 1])

            # Phase 2: seven more non-terminal records.
            harness.test(
                ("insert_records", non_terminal_records(self.record_space, 7)),
                expected_outputs=None
            )

            snapshot = harness.get_variable_values(self.ring_buffer_variables)
            write_index = snapshot["index"]
            episode_idx = snapshot["episode-indices"]

            # The episode indices must be unchanged; re-state the expectation
            # before comparing again.
            expected_indices[0] = 2
            recursive_assert_almost_equal(episode_idx, expected_indices)
            # 2 + 1 + 7 records have been inserted; the insert index is expected
            # to be back at 0 (presumably capacity is 10 — TODO confirm).
            self.assertEqual(write_index, 0)

            # Phase 3: one more terminal record, which (with capacity 10) should
            # give the terminal buffer the layout [1 0 1 0 0 0 0 0 0 0].
            harness.test(
                ("insert_records", terminal_records(self.record_space, 1)),
                expected_outputs=None
            )

            snapshot = harness.get_variable_values(self.ring_buffer_variables)
            episode_count = snapshot["num-episodes"]
            recursive_assert_almost_equal(episode_count, 2)

            # Fetch two episodes. The result is expected to hold:
            # - as many time steps as the buffer's capacity
            # - terminal flags set at positions 0 and 2
            fetched_pair = harness.test(("get_episodes", 2), expected_outputs=None)
            self.assertEqual(len(fetched_pair['terminals']), self.capacity)
            self.assertEqual(fetched_pair['terminals'][0], True)
            self.assertEqual(fetched_pair['terminals'][2], True)