Exemple #1
0
    def test_get_batch(self):
        """
        Inserts 100 randomly sampled records into an Apex agent through a
        single observe() call (timing it), then fetches a batch via the
        agent's external get-batch call and prints the result.
        """
        random_env = RandomEnv(
            state_space=spaces.IntBox(2),
            action_space=spaces.IntBox(2),
            deterministic=True
        )
        env_states = random_env.state_space
        env_actions = random_env.action_space
        apex_agent = ApexAgent.from_spec(
            "configs/apex_agent_for_random_env.json",
            state_space=env_states,
            action_space=env_actions
        )
        reward_space = FloatBox()
        terminal_space = BoolBox()

        # Time one observe() call; the sampling of the 100 records happens
        # inside the timed section.
        t0 = time.monotonic()
        apex_agent.observe(
            states=env_states.sample(size=100),
            actions=env_actions.sample(size=100),
            internals=[],
            rewards=reward_space.sample(size=100),
            terminals=terminal_space.sample(size=100)
        )
        elapsed = time.monotonic() - t0
        print("Time to insert 100 elements {} s.".format(elapsed))

        # No assertions here: the fetched batch is only printed.
        print(apex_agent.get_batch())
Exemple #2
0
    def __init__(self,
                 discount=0.98,
                 memory_spec=None,
                 double_q=True,
                 dueling_q=True,
                 **kwargs):
        """
        Creates the memory, q-net policy, target policy, record
        merger/splitter and loss function, then assembles the meta graph
        from these components and builds the graph.

        Args:
            discount (float): The discount factor (gamma).
            memory_spec (Optional[dict,Memory]): The spec for the Memory to use for the DQN algorithm
                (handed to `Memory.from_spec`).
            double_q (bool): Whether to use the double DQN loss function (see [2]).
            dueling_q (bool): Whether to use a dueling layer in the ActionAdapter  (see [3]).
            **kwargs: Passed through unchanged to the parent constructor.
        """
        super(DQNAgent, self).__init__(**kwargs)

        self.discount = discount
        self.memory = Memory.from_spec(memory_spec)
        # Layout of a single record (no batch rank).
        self.record_space = Dict(
            states=self.state_space,
            actions=self.action_space,
            rewards=float,
            terminals=BoolBox(),
            add_batch_rank=False
        )
        self.double_q = double_q
        self.dueling_q = dueling_q

        # Placeholder; replaced below by a copy of the q-net policy.
        self.target_policy = None

        adapter_spec = dict(add_dueling_layer=self.dueling_q)
        self.policy = Policy(neural_network=self.neural_network, action_adapter_spec=adapter_spec)

        # The target-net is a copy of our Policy, with a Synchronizable
        # component attached to it.
        self.target_policy = self.policy.copy(scope="target-policy")
        self.target_policy.add_component(Synchronizable(), connections=CONNECT_ALL)

        self.merger = Merger(output_space=self.record_space)

        # The splitter's input is the record layout plus a "next_states" entry.
        space_with_next_states = copy.deepcopy(self.record_space)
        space_with_next_states["next_states"] = self.state_space
        self.splitter = Splitter(input_space=space_with_next_states)

        self.loss_function = DQNLossFunction(discount=self.discount, double_q=self.double_q)

        graph_components = (
            self.preprocessor_stack,
            self.memory,
            self.merger,
            self.splitter,
            self.policy,
            self.target_policy,
            self.exploration,
            self.loss_function,
            self.optimizer,
        )
        self.assemble_meta_graph(*graph_components)
        self.build_graph()
    def test_sampler_component(self):
        """
        Feeds 100 sampled records into a Sampler component, requests a
        sample of size 10 and checks the length of the returned actions,
        states and terminals entries.
        """
        record_space = Dict(
            states=dict(state1=float, state2=float),
            actions=dict(action1=float),
            reward=float,
            terminals=BoolBox(),
            add_batch_rank=True
        )

        component_test = ComponentTest(
            component=Sampler(),
            input_spaces=dict(sample_size=int, inputs=record_space)
        )

        requested = 10
        records = record_space.sample(size=100)
        result = component_test.test(
            out_socket_names="sample",
            inputs=dict(sample_size=requested, inputs=records),
            expected_outputs=None
        )

        # Every checked entry must contain exactly the requested number of items.
        self.assertEqual(len(result["actions"]["action1"]), requested)
        self.assertEqual(len(result["states"]["state1"]), requested)
        self.assertEqual(len(result["terminals"]), requested)

        print(result)
Exemple #4
0
    def _assemble_meta_graph(self, core, *params):
        """
        Declares the input/output sockets on `core` and registers the
        agent's sub-components with it. No connections are made here.

        Args:
            core: The component on which sockets are defined and to which
                the sub-components are added.
            *params: Not used by this method.
        """
        # Input sockets, grouped by the space they share.
        batched_states = self.state_space.with_batch_rank()
        core.define_inputs(
            "states_from_env",
            "external_batch_states",
            "external_batch_next_states",
            "states_for_memory",
            space=batched_states
        )
        batched_actions = self.action_space.with_batch_rank()
        core.define_inputs("actions_for_memory", "external_batch_actions", space=batched_actions)
        core.define_inputs(
            "rewards_for_memory", "external_batch_rewards",
            space=FloatBox(add_batch_rank=True)
        )
        core.define_inputs(
            "terminals_for_memory", "external_batch_terminals",
            space=BoolBox(add_batch_rank=True)
        )

        # Note: a "deterministic" (bool) input is currently not defined.
        core.define_inputs("time_step", space=int)

        # Output sockets.
        core.define_outputs(
            "get_actions",
            "insert_records",
            "update_from_memory",
            "update_from_external_batch",
            "get_batch",
            "get_indices",
            "loss"
        )

        # Sub-components, added group by group:
        # policy / exploration (q-net only, the target-net doesn't need one) /
        # memory with merger and splitter / loss function with optimizer.
        component_groups = (
            (self.policy,),
            (self.exploration,),
            (self.memory, self.merger, self.splitter),
            (self.loss_function, self.optimizer),
        )
        for group in component_groups:
            core.add_components(*group)
class TestPrioritizedReplay(unittest.TestCase):
    """
    Tests sampling and insertion behaviour of the prioritized_replay module.
    """
    record_space = Dict(
        states=dict(state1=float, state2=float),
        actions=dict(action1=float),
        reward=float,
        terminals=BoolBox(),
        add_batch_rank=True
    )
    memory_variables = ["size", "index", "max-priority"]

    capacity = 10
    alpha = 1.0
    beta = 1.0

    # Value the "max-priority" variable is expected to hold before any insert.
    max_priority = 1.0

    def _build_memory_test(self, **memory_kwargs):
        """
        Creates a PrioritizedReplay with `self.capacity` and wraps it in a
        ComponentTest using the input spaces shared by all tests here.

        Args:
            **memory_kwargs: Additional keyword arguments forwarded to the
                PrioritizedReplay constructor (e.g. next_states, alpha, beta).

        Returns:
            tuple: The memory component and the ComponentTest wrapping it.
        """
        memory = PrioritizedReplay(capacity=self.capacity, **memory_kwargs)
        test = ComponentTest(component=memory, input_spaces=dict(
            records=self.record_space,
            num_records=int,
            indices=IntBox(shape=(), add_batch_rank=True),
            update=FloatBox(shape=(), add_batch_rank=True)
        ))
        return memory, test

    def test_insert(self):
        """
        Simply tests insert op without checking internal logic.
        """
        _, test = self._build_memory_test(next_states=True, alpha=self.alpha, beta=self.beta)

        observation = self.record_space.sample(size=1)
        test.test(out_socket_names="insert_records", inputs=observation, expected_outputs=None)

    def test_capacity(self):
        """
        Tests if insert correctly manages capacity.
        """
        memory, test = self._build_memory_test(next_states=True, alpha=self.alpha, beta=self.beta)

        # Internal state variables.
        memory_variables = memory.get_variables(self.memory_variables, global_scope=False)
        buffer_size = memory_variables['size']
        buffer_index = memory_variables['index']
        max_priority = memory_variables['max-priority']

        size_value, index_value, max_priority_value = test.get_variable_values(buffer_size, buffer_index, max_priority)

        # Assert indices 0 before insert.
        self.assertEqual(size_value, 0)
        self.assertEqual(index_value, 0)
        self.assertEqual(max_priority_value, self.max_priority)

        # Insert one more element than capacity
        observation = self.record_space.sample(size=self.capacity + 1)
        test.test(out_socket_names="insert_records", inputs=observation, expected_outputs=None)

        size_value, index_value = test.get_variable_values(buffer_size, buffer_index)
        # Size should be equivalent to capacity when full.
        self.assertEqual(size_value, self.capacity)

        # Index wraps around due to modulo: capacity + 1 inserts leave it at 1.
        self.assertEqual(index_value, 1)

    def test_batch_retrieve(self):
        """
        Tests if retrieval correctly manages capacity.
        """
        _, test = self._build_memory_test(next_states=True, alpha=self.alpha, beta=self.beta)

        # Insert 2 Elements.
        observation = non_terminal_records(self.record_space, 2)
        test.test(out_socket_names="insert_records", inputs=observation, expected_outputs=None)

        # Assert we can now fetch 2 elements.
        num_records = 2
        batch = test.test(out_socket_names="get_records", inputs=num_records, expected_outputs=None)
        print('Result batch = {}'.format(batch))
        self.assertEqual(2, len(batch['terminals']))
        # Assert next states key is there
        self.assertIn('next_states', batch)

        # We allow repeat indices in sampling.
        num_records = 5
        batch = test.test(out_socket_names="get_records", inputs=num_records, expected_outputs=None)
        self.assertEqual(5, len(batch['terminals']))

        # Now insert over capacity, note all elements here are non-terminal.
        observation = non_terminal_records(self.record_space, self.capacity)
        test.test(out_socket_names="insert_records", inputs=observation, expected_outputs=None)

        # Assert we can fetch exactly capacity elements.
        num_records = self.capacity
        batch = test.test(out_socket_names="get_records", inputs=num_records, expected_outputs=None)
        self.assertEqual(self.capacity, len(batch['terminals']))

    def test_without_next_state(self):
        """
        Tests retrieval works if next state option is deactivated and
        that no next_states key is present.
        """
        # alpha/beta are deliberately left at the component's defaults here.
        _, test = self._build_memory_test(next_states=False)

        # Insert 2 Elements.
        observation = non_terminal_records(self.record_space, 2)
        test.test(out_socket_names="insert_records", inputs=observation, expected_outputs=None)

        # Assert we can now fetch 2 elements.
        num_records = 2
        batch = test.test(out_socket_names="get_records", inputs=num_records, expected_outputs=None)
        self.assertNotIn('next_states', batch)

    def test_update_records(self):
        """
        Tests update records logic.
        """
        # alpha/beta are deliberately left at the component's defaults here.
        _, test = self._build_memory_test(next_states=True)

        # Insert a few Elements.
        observation = non_terminal_records(self.record_space, 5)
        test.test(out_socket_names="insert_records", inputs=observation, expected_outputs=None)

        # Fetch elements and their indices.
        num_records = 5
        batch = test.test(
            out_socket_names=["get_records", "record_indices"],
            inputs=dict(num_records=num_records),
            expected_outputs=None
        )
        # Second requested socket ("record_indices").
        indices = batch[1]
        self.assertEqual(num_records, len(indices))

        input_params = dict(
            indices=indices,
            update=np.asarray([0.1, 0.2, 0.3, 0.5, 1.0])
        )
        # Does not return anything
        test.test(out_socket_names=["update_records"], inputs=input_params, expected_outputs=None)

    def test_segment_tree_insert_values(self):
        """
        Tests if segment tree inserts into correct positions.
        """
        memory, test = self._build_memory_test(next_states=True, alpha=self.alpha, beta=self.beta)

        # Smallest power of two that is >= capacity.
        priority_capacity = 1
        while priority_capacity < self.capacity:
            priority_capacity *= 2

        memory_variables = memory.get_variables(["sum-segment-tree", "min-segment-tree"], global_scope=False)
        sum_segment_tree = memory_variables['sum-segment-tree']
        min_segment_tree = memory_variables['min-segment-tree']
        sum_segment_values, min_segment_values = test.get_variable_values(sum_segment_tree, min_segment_tree)

        self.assertEqual(sum(sum_segment_values), 0)
        self.assertEqual(sum(min_segment_values), float('inf'))
        self.assertEqual(len(sum_segment_values), 2 * priority_capacity)
        self.assertEqual(len(min_segment_values), 2 * priority_capacity)
        # Insert 1 Element.
        observation = non_terminal_records(self.record_space, 1)
        test.test(out_socket_names="insert_records", inputs=observation, expected_outputs=None)

        # Fetch segment tree.
        sum_segment_values, min_segment_values = test.get_variable_values(sum_segment_tree, min_segment_tree)

        # Check insert positions
        # Initial insert is at priority capacity
        print(sum_segment_values)
        print(min_segment_values)
        start = priority_capacity

        # Walk from the inserted position up to index 1, halving each step.
        while start >= 1:
            self.assertEqual(sum_segment_values[start], 1.0)
            self.assertEqual(min_segment_values[start], 1.0)
            start //= 2

        # Insert another Element.
        observation = non_terminal_records(self.record_space, 1)
        test.test(out_socket_names="insert_records", inputs=observation, expected_outputs=None)

        # Fetch segment tree.
        sum_segment_values, min_segment_values = test.get_variable_values(sum_segment_tree, min_segment_tree)
        print(sum_segment_values)
        print(min_segment_values)

        # Index shifted 1
        start = priority_capacity + 1
        self.assertEqual(sum_segment_values[start], 1.0)
        self.assertEqual(min_segment_values[start], 1.0)
        start //= 2
        while start >= 1:
            # 1 + 1 is 2 on the segment.
            self.assertEqual(sum_segment_values[start], 2.0)
            # min is still 1.
            self.assertEqual(min_segment_values[start], 1.0)
            start //= 2
Exemple #6
0
    def _assemble_meta_graph(self, core, *params):
        """
        Defines the in/out sockets of `core`, adds the agent's sub-components
        to it and connects everything.

        As a side effect, (re-)creates `self.target_policy` as a copy of
        `self.policy` (scope "target-policy") with a Synchronizable component
        attached.

        Args:
            core: The component on which sockets are defined, sub-components
                are added and connections are made.
            *params: Not used by this method.
        """
        # Define our interface.
        core.define_inputs("states_from_env",
                           "external_batch_states",
                           "external_batch_next_states",
                           "states_for_memory",
                           space=self.state_space.with_batch_rank())
        core.define_inputs("actions_for_memory",
                           "external_batch_actions",
                           space=self.action_space.with_batch_rank())
        core.define_inputs("rewards_for_memory",
                           "external_batch_rewards",
                           space=FloatBox(add_batch_rank=True))
        core.define_inputs("terminals_for_memory",
                           "external_batch_terminals",
                           space=BoolBox(add_batch_rank=True))

        #core.define_inputs("deterministic", space=bool)
        core.define_inputs("time_step", space=int)
        core.define_outputs("get_actions", "insert_records",
                            "update_from_memory", "update_from_external_batch",
                            "sync_target_qnet", "get_batch", "loss")

        # Add the Q-net, copy it (target-net) and add the target-net.
        self.target_policy = self.policy.copy(scope="target-policy")
        # Make target_policy writable
        self.target_policy.add_component(Synchronizable(),
                                         connections=CONNECT_ALL)
        core.add_components(self.policy, self.target_policy)
        # Add an Exploration for the q-net (target-net doesn't need one).
        core.add_components(self.exploration)

        # Add our Memory Component plus merger and splitter.
        core.add_components(self.memory, self.merger, self.splitter)

        # Add the loss function and optimizer.
        core.add_components(self.loss_function, self.optimizer)

        # NOTE(review): self.preprocessor_stack is connected below but is not
        # passed to core.add_components in this method -- presumably it is
        # added elsewhere; confirm.

        # Now connect everything ...

        # All external/env states into preprocessor (memory already preprocessed).
        core.connect("states_from_env", (self.preprocessor_stack, "input"),
                     label="env,s")
        core.connect("external_batch_states",
                     (self.preprocessor_stack, "input"),
                     label="ext,s")
        core.connect("external_batch_next_states",
                     (self.preprocessor_stack, "input"),
                     label="ext,sp")
        # The ",sp" label is only included for double-Q, where the q-net
        # (not just the target-net) also receives next states (see the
        # double_q connections further down).
        core.connect((self.preprocessor_stack, "output"),
                     (self.policy, "nn_input"),
                     label="s" + (",sp" if self.double_q else ""))

        # Timestep into Exploration.
        core.connect("time_step", (self.exploration, "time_step"))

        # Policy output into Exploration -> into "actions".
        core.connect((self.policy, "sample_deterministic"),
                     (self.exploration, "sample_deterministic"),
                     label="env")
        core.connect((self.policy, "sample_stochastic"),
                     (self.exploration, "sample_stochastic"),
                     label="env")
        core.connect((self.exploration, "action"), "get_actions")
        #core.connect((self.exploration, "do_explore"), "do_explore")

        # Insert records into memory via merger.
        core.connect("states_for_memory", (self.preprocessor_stack, "input"),
                     label="to_mem")
        core.connect((self.preprocessor_stack, "output"),
                     (self.merger, "/states"),
                     label="to_mem")
        # "<name>_for_memory" input socket -> merger socket "/<name>".
        for in_ in ["actions", "rewards", "terminals"]:
            core.connect(in_ + "_for_memory", (self.merger, "/" + in_))
        core.connect((self.merger, "output"), (self.memory, "records"))
        core.connect((self.memory, "insert_records"), "insert_records")

        # Learn from Memory via get_batch and Splitter.
        # The configured batch size (a plain value, not a socket) is fed
        # into the memory's "num_records".
        core.connect(self.update_spec["batch_size"],
                     (self.memory, "num_records"))
        core.connect((self.memory, "get_records"), (self.splitter, "input"),
                     label="mem")
        core.connect((self.memory, "get_records"), "get_batch")
        core.connect((self.splitter, "/states"), (self.policy, "nn_input"),
                     label="mem,s")
        core.connect((self.splitter, "/actions"),
                     (self.loss_function, "actions"))
        core.connect((self.splitter, "/rewards"),
                     (self.loss_function, "rewards"))
        core.connect((self.splitter, "/terminals"),
                     (self.loss_function, "terminals"))
        core.connect((self.splitter, "/next_states"),
                     (self.target_policy, "nn_input"),
                     label="mem,sp")
        # Double-Q additionally sends the next states through the q-net itself.
        if self.double_q:
            core.connect((self.splitter, "/next_states"),
                         (self.policy, "nn_input"),
                         label="mem,sp")

        # Only send ext and mem labelled ops into loss function.
        # The policy socket to read Q-values from depends on whether a
        # dueling layer is used.
        q_values_socket = "q_values" if self.dueling_q is True else "action_layer_output_reshaped"
        core.connect((self.policy, q_values_socket),
                     (self.loss_function, "q_values"),
                     label="ext,mem,s")
        #core.connect((self.policy, q_values_socket), "q_values")
        core.connect((self.target_policy, q_values_socket),
                     (self.loss_function, "qt_values_s_"),
                     label="ext,mem")
        if self.double_q:
            core.connect((self.policy, q_values_socket),
                         (self.loss_function, "q_values_s_"),
                         label="ext,mem,sp")

        # Connect the Optimizer.
        core.connect((self.loss_function, "loss"), (self.optimizer, "loss"))
        core.connect((self.loss_function, "loss"), "loss")
        core.connect((self.policy, "_variables"), (self.optimizer, "vars"))
        # The same optimizer "step" socket feeds both update outputs; the
        # two connections differ only in their label ("mem" vs "ext").
        core.connect((self.optimizer, "step"),
                     "update_from_memory",
                     label="mem")
        core.connect((self.optimizer, "step"),
                     "update_from_external_batch",
                     label="ext")

        # Add syncing capability for target-net.
        core.connect((self.policy, "_variables"),
                     (self.target_policy, "_values"))
        core.connect((self.target_policy, "sync"), "sync_target_qnet")
Exemple #7
0
    def _assemble_meta_graph_test(self, core, preprocessor, memory, merger,
                                  splitter, policy, target_policy, exploration,
                                  loss_function, optimizer):
        """
        Assembles the agent's graph by calling the sub-components directly:
        defines the inputs on `core`, adds all sub-components and builds the
        acting, memory-insert, target-sync and both update pathways, each
        registered as an output on `core`.

        Args:
            core: The component on which inputs/outputs are defined and to
                which all sub-components are added.
            preprocessor: Called on state inputs before they reach a policy or the merger.
            memory: Called to insert merged records and to fetch a batch of records.
            merger: Called to merge states/actions/rewards/terminals into records.
            splitter: Called to split memory records into (s, a, r, t, s').
            policy: The q-net policy.
            target_policy: The target-net policy.
            exploration: Called with the time step and the policy's two sample outputs.
            loss_function: Called to produce the "loss_per_item" output.
            optimizer: Called on the per-item loss to produce the update op.
        """
        # Define our Spaces.
        state_space = self.state_space.with_batch_rank()
        action_space = self.action_space.with_batch_rank()
        reward_space = FloatBox(add_batch_rank=True)
        terminal_space = BoolBox(add_batch_rank=True)

        # Define our inputs.
        inputs = dict(
            states_from_env=state_space,
            states_to_memory=state_space,
            actions_to_memory=action_space,
            rewards_to_memory=reward_space,
            terminals_to_memory=terminal_space,
            states_from_external=state_space,
            next_states_from_external=state_space,
            actions_from_external=action_space,
            rewards_from_external=reward_space,
            terminals_from_external=terminal_space,
            # Step counter fed into the exploration component: an int, not a bool.
            time_step=int,
        )
        core.define_inputs(inputs)
        # Add all sub-components.
        core.add_components(preprocessor, memory, merger, splitter, policy,
                            target_policy, exploration, loss_function,
                            optimizer)

        # Env pathway.
        preprocessed_states_from_env = preprocessor("states_from_env")
        sample_deterministic, sample_stochastic = policy(
            preprocessed_states_from_env,
            ["sample_deterministic", "sample_stochastic"])
        action = exploration(
            ["time_step", sample_deterministic, sample_stochastic])
        core.define_outputs("get_actions", action)

        # Insert into memory pathway.
        preprocessed_states_to_mem = preprocessor("states_to_memory")
        records = merger([
            preprocessed_states_to_mem, "actions_to_memory",
            "rewards_to_memory", "terminals_to_memory"
        ])
        insert_records_op = memory(records, "insert_records")
        core.define_outputs("insert_records", insert_records_op)

        # Syncing target-net.
        policy_vars = policy(None, "_variables")
        sync_op = target_policy(policy_vars, "sync")
        core.define_outputs("sync_target_qnet", sync_op)

        # Learn from memory.
        # The policy socket to read Q-values from depends on whether a
        # dueling layer is used.
        q_values_socket_name = "q_values" if self.dueling_q is True else "action_layer_output_reshaped"
        records_from_memory = memory(self.update_spec["batch_size"],
                                     "get_records")
        s_mem, a_mem, r_mem, t_mem, sp_mem = splitter(records_from_memory)
        q_values_s = policy(s_mem, q_values_socket_name)
        qt_values_sp = target_policy(sp_mem, q_values_socket_name)
        loss_inputs = [q_values_s, a_mem, r_mem, t_mem, qt_values_sp]
        if self.double_q:
            # Double-Q additionally passes the q-net's values for the next states.
            loss_inputs.append(policy(sp_mem, q_values_socket_name))
        loss_per_item = loss_function(loss_inputs, "loss_per_item")
        update_from_mem = optimizer(loss_per_item)
        optimizer(policy_vars, None)  # TODO: this will probably not work
        core.define_outputs("update_from_memory", update_from_mem)

        # Learn from external batch.
        preprocessed_s_from_external = preprocessor("states_from_external")
        preprocessed_sp_from_external = preprocessor(
            "next_states_from_external")
        q_values_s = policy(preprocessed_s_from_external, q_values_socket_name)
        qt_values_sp = target_policy(preprocessed_sp_from_external,
                                     q_values_socket_name)
        loss_inputs = [
            q_values_s, "actions_from_external", "rewards_from_external",
            "terminals_from_external", qt_values_sp
        ]
        if self.double_q:
            loss_inputs.append(
                policy(preprocessed_sp_from_external, q_values_socket_name))
        loss_per_item = loss_function(loss_inputs, "loss_per_item")
        update_from_external = optimizer(loss_per_item)
        core.define_outputs("update_from_external_batch", update_from_external)
Exemple #8
0
class TestRingBufferMemory(unittest.TestCase):
    """
    Tests the ring buffer. The ring buffer has very similar tests to
    the replay memory as it supports similar insertion and retrieval semantics,
    but needs additional tests on episode indexing and its latest semantics.
    """

    record_space = Dict(states=dict(state1=float, state2=float),
                        actions=dict(action1=float),
                        reward=float,
                        terminals=BoolBox(),
                        add_batch_rank=True)
    # Generic memory variables.
    memory_variables = ["size", "index"]

    # Ring buffer variables (order matters: `_episode_variables` returns the
    # variable handles in exactly this order).
    ring_buffer_variables = [
        "size", "index", "num-episodes", "episode-indices"
    ]
    capacity = 10

    def _build_test(self, episode_semantics):
        """
        Creates a RingBuffer with `self.capacity` slots and a ComponentTest around it.

        Args:
            episode_semantics (bool): Passed through to the RingBuffer. If True, the
                `num_episodes` input space is declared in addition to `records`
                and `num_records`.

        Returns:
            tuple: The RingBuffer and the ComponentTest built for it.
        """
        ring_buffer = RingBuffer(capacity=self.capacity,
                                 episode_semantics=episode_semantics)
        input_spaces = dict(records=self.record_space, num_records=int)
        if episode_semantics:
            input_spaces["num_episodes"] = int
        test = ComponentTest(component=ring_buffer, input_spaces=input_spaces)
        return ring_buffer, test

    def _episode_variables(self, ring_buffer):
        """
        Looks up the variables named in `self.ring_buffer_variables`.

        Args:
            ring_buffer (RingBuffer): The component to query.

        Returns:
            tuple: Variable handles for size, index, num-episodes and
                episode-indices (in that order).
        """
        variables = ring_buffer.get_variables(self.ring_buffer_variables,
                                              global_scope=False)
        return tuple(variables[name] for name in self.ring_buffer_variables)

    @staticmethod
    def _insert(test, records):
        """
        Pushes `records` through the "insert_records" socket without checking outputs.
        """
        test.test(out_socket_names="insert_records",
                  inputs=records,
                  expected_outputs=None)

    @staticmethod
    def _fetch(test, socket_name, count):
        """
        Requests `count` items from the given out-socket and returns the result.
        """
        return test.test(out_socket_names=socket_name,
                         inputs=count,
                         expected_outputs=None)

    def test_insert_no_episodes(self):
        """
        Simply tests insert op without checking internal logic, episode
        semantics disabled.
        """
        _, test = self._build_test(episode_semantics=False)

        # A single record, then far more records than the buffer can hold.
        for num_samples in (1, 100):
            self._insert(test, self.record_space.sample(size=num_samples))

    def test_capacity_no_episodes(self):
        """
        Tests if insert correctly manages capacity, no episode indices updated.
        """
        ring_buffer, test = self._build_test(episode_semantics=False)

        # Internal state variables.
        variables = ring_buffer.get_variables(self.memory_variables,
                                              global_scope=False)
        size_var = variables['size']
        index_var = variables['index']

        # Size and index must both be 0 before the first insert.
        size_value, index_value = test.get_variable_values(size_var, index_var)
        self.assertEqual(size_value, 0)
        self.assertEqual(index_value, 0)

        # Insert one more element than capacity.
        self._insert(test, self.record_space.sample(size=self.capacity + 1))

        size_value, index_value = test.get_variable_values(size_var, index_var)
        # Size should be equivalent to capacity when full.
        self.assertEqual(size_value, self.capacity)
        # capacity + 1 inserts are expected to wrap the index around to 1.
        self.assertEqual(index_value, 1)

    def test_capacity_with_episodes(self):
        """
        Tests if inserts of non-terminals work when turning
        on episode semantics.

        Note that this does not test episode semantics itself, which are tested below.
        """
        ring_buffer, test = self._build_test(episode_semantics=True)
        # Internal memory variables.
        state_vars = self._episode_variables(ring_buffer)

        size_value, index_value, num_episodes_value, episode_index_values = \
            test.get_variable_values(*state_vars)

        # Everything must be 0 before the first insert.
        self.assertEqual(size_value, 0)
        self.assertEqual(index_value, 0)
        self.assertEqual(num_episodes_value, 0)
        self.assertEqual(np.sum(episode_index_values), 0)

        # Insert one more element than capacity. Note: this is different than
        # replay test because due to episode semantics, it matters if
        # these are terminal or not. This tests if episode index updating
        # causes problems if none of the inserted elements are terminal.
        self._insert(test,
                     non_terminal_records(self.record_space, self.capacity + 1))
        size_value, index_value, num_episodes_value, episode_index_values = \
            test.get_variable_values(*state_vars)

        # Size should be equivalent to capacity when full.
        self.assertEqual(size_value, self.capacity)
        # capacity + 1 inserts are expected to wrap the index around to 1.
        self.assertEqual(index_value, 1)
        # No terminals were inserted, so no episode bookkeeping is expected.
        self.assertEqual(num_episodes_value, 0)
        self.assertEqual(np.sum(episode_index_values), 0)

    def test_episode_indices_when_inserting(self):
        """
        Tests if episodes indices and counts are set correctly when inserting
        terminals.
        """
        ring_buffer, test = self._build_test(episode_semantics=True)
        # Internal memory variables.
        state_vars = self._episode_variables(ring_buffer)

        # First, we insert a single terminal record.
        self._insert(test, terminal_records(self.record_space, 1))
        _, _, num_episodes_value, episode_index_values = \
            test.get_variable_values(*state_vars)

        # One episode should be present.
        self.assertEqual(num_episodes_value, 1)
        # That episode's index is 0, so the stored indices still sum to 0.
        self.assertEqual(np.sum(episode_index_values), 0)

        # Next, we insert 1 non-terminal, then 1 terminal element.
        self._insert(test, non_terminal_records(self.record_space, 1))
        self._insert(test, terminal_records(self.record_space, 1))

        # Now, we expect to have 2 episodes with episode indices at 0 and 2.
        _, _, num_episodes_value, episode_index_values = \
            test.get_variable_values(*state_vars)
        print('Episode indices after = {}'.format(episode_index_values))
        self.assertEqual(num_episodes_value, 2)
        self.assertEqual(episode_index_values[1], 2)

    def test_only_terminal_with_episodes(self):
        """
        Edge case: What if only terminals are inserted when episode
        semantics are enabled?
        """
        ring_buffer, test = self._build_test(episode_semantics=True)
        _, _, num_episodes_var, episode_indices_var = \
            self._episode_variables(ring_buffer)

        self._insert(test, terminal_records(self.record_space, self.capacity))
        num_episodes_value, episode_index_values = test.get_variable_values(
            num_episodes_var, episode_indices_var)
        self.assertEqual(num_episodes_value, self.capacity)
        # Every episode index should correspond to its position.
        # `range` (not `xrange`) so the loop runs on Python 2 and Python 3 alike.
        for position in range(self.capacity):
            self.assertEqual(episode_index_values[position], position)

    def test_episode_fetching(self):
        """
        Test if we can accurately fetch most recent episodes.
        """
        _, test = self._build_test(episode_semantics=True)

        # Insert 2 non-terminals, 1 terminal.
        self._insert(test, non_terminal_records(self.record_space, 2))
        self._insert(test, terminal_records(self.record_space, 1))

        # We should now be able to retrieve one episode.
        # NOTE(review): three records were inserted above and the original
        # comment spoke of an episode "of length 3", yet the assertion expects
        # 2 entries. The expected value is kept as-is - confirm which is intended.
        episode = self._fetch(test, "get_episodes", 1)
        self.assertEqual(len(episode['reward']), 2)

        # We should not be able to retrieve two episodes, and still return just one.
        episode = self._fetch(test, "get_episodes", 2)
        self.assertEqual(len(episode['reward']), 2)

        # Insert 7 non-terminals, 1 terminal -> last terminal is now at buffer index 0 as
        # we inserted 3 + 8 = 11 elements in total.
        self._insert(test, non_terminal_records(self.record_space, 7))
        self._insert(test, terminal_records(self.record_space, 1))

        # Check if we can fetch 2 episodes:
        episodes = self._fetch(test, "get_episodes", 2)

        # We now expect to have retrieved:
        # - 10 time steps
        # - 2 terminal values 1
        # - Terminal values spaced apart 1 index due to the insertion order
        self.assertEqual(len(episodes['terminals']), self.capacity)
        self.assertEqual(episodes['terminals'][0], True)
        self.assertEqual(episodes['terminals'][2], True)

    def test_latest_batch(self):
        """
        Tests if we can fetch latest steps.
        """
        _, test = self._build_test(episode_semantics=True)

        # Insert 5 random elements.
        self._insert(test, non_terminal_records(self.record_space, 5))

        # First, test if the basic computation works.
        batch = self._fetch(test, "get_records", 5)
        self.assertEqual(len(batch['terminals']), 5)

        # Next, insert capacity more elements:
        observation = non_terminal_records(self.record_space, self.capacity)
        self._insert(test, observation)

        # If we now fetch capacity elements, we expect to see exactly the last 10.
        batch = self._fetch(test, "get_records", self.capacity)

        # Assert every inserted element is contained, even if not in same order:
        retrieved_action = batch['actions']['action1']
        for action_value in observation['actions']['action1']:
            self.assertIn(action_value, retrieved_action)
class TestReplayMemory(unittest.TestCase):
    """
    Tests sampling and insertion behaviour of the replay_memory module.
    """
    record_space = Dict(states=dict(state1=float, state2=float),
                        actions=dict(action1=float),
                        reward=float,
                        terminals=BoolBox(),
                        add_batch_rank=True)
    # Names of the memory's internal variables inspected by the capacity test.
    memory_variables = ["size", "index"]
    capacity = 10

    def _build_test(self, next_states):
        """
        Creates a ReplayMemory with `self.capacity` slots and a ComponentTest around it.

        Args:
            next_states (bool): Passed through to the ReplayMemory constructor.

        Returns:
            tuple: The ReplayMemory and the ComponentTest built for it.
        """
        memory = ReplayMemory(capacity=self.capacity, next_states=next_states)
        test = ComponentTest(component=memory,
                             input_spaces=dict(records=self.record_space,
                                               num_records=int))
        return memory, test

    @staticmethod
    def _insert(test, records):
        """
        Pushes `records` through the "insert_records" socket without checking outputs.
        """
        test.test(out_socket_names="insert_records",
                  inputs=records,
                  expected_outputs=None)

    @staticmethod
    def _sample(test, num_records):
        """
        Requests `num_records` records via the "get_records" socket and returns them.
        """
        return test.test(out_socket_names="get_records",
                         inputs=num_records,
                         expected_outputs=None)

    def test_insert(self):
        """
        Simply tests insert op without checking internal logic.
        """
        _, test = self._build_test(next_states=True)

        # A single record, then far more records than the memory can hold.
        for num_samples in (1, 100):
            self._insert(test, self.record_space.sample(size=num_samples))

    def test_capacity(self):
        """
        Tests if insert correctly manages capacity.
        """
        memory, test = self._build_test(next_states=True)

        # Internal state variables.
        variables = memory.get_variables(self.memory_variables,
                                         global_scope=False)
        size_var = variables['size']
        index_var = variables['index']

        # Size and index must both be 0 before the first insert.
        size_value, index_value = test.get_variable_values(size_var, index_var)
        self.assertEqual(size_value, 0)
        self.assertEqual(index_value, 0)

        # Insert one more element than capacity.
        self._insert(test, self.record_space.sample(size=self.capacity + 1))

        size_value, index_value = test.get_variable_values(size_var, index_var)
        # Size should be equivalent to capacity when full.
        self.assertEqual(size_value, self.capacity)
        # capacity + 1 inserts are expected to wrap the index around to 1.
        self.assertEqual(index_value, 1)

    def test_batch_retrieve(self):
        """
        Tests if retrieval correctly manages capacity.
        """
        _, test = self._build_test(next_states=True)

        # Insert 2 Elements.
        self._insert(test, non_terminal_records(self.record_space, 2))

        # Assert we can now fetch 2 elements.
        batch = self._sample(test, 2)
        print('Result batch = {}'.format(batch))
        self.assertEqual(2, len(batch['terminals']))
        # Assert next states key is there
        self.assertIn('next_states', batch)

        # Test duplicate sampling: ask for more records than were inserted.
        batch = self._sample(test, 5)
        self.assertEqual(5, len(batch['terminals']))

        # Now insert over capacity (2 + capacity records in total).
        self._insert(test, non_terminal_records(self.record_space, self.capacity))

        # Assert we can fetch exactly capacity elements.
        batch = self._sample(test, self.capacity)
        self.assertEqual(self.capacity, len(batch['terminals']))

    def test_with_terminals_no_next_states(self):
        """
        Tests that terminal records can be inserted and retrieved if the next
        state option is deactivated, and that no next_states key is present.
        """
        _, test = self._build_test(next_states=False)

        # Insert 2 terminal Elements.
        self._insert(test, terminal_records(self.record_space, 2))

        # Assert we can now fetch 2 elements.
        num_records = 2
        batch = self._sample(test, num_records)

        # Sampled 2 elements
        self.assertEqual(num_records, len(batch['terminals']))
        # Both are terminal (checked separately so a failure names the record).
        self.assertTrue(batch['terminals'][0])
        self.assertTrue(batch['terminals'][1])
        # No next state key.
        self.assertNotIn('next_states', batch)

    def test_without_next_state(self):
        """
        Tests retrieval works if next state option is deactivated and
        that no next_states key is present.
        """
        _, test = self._build_test(next_states=False)

        # Insert 2 Elements.
        self._insert(test, non_terminal_records(self.record_space, 2))

        # Assert we can now fetch 2 elements.
        batch = self._sample(test, 2)
        self.assertNotIn('next_states', batch)