Code Example #1
    def _create_action_adapters_and_distributions(self, action_space, action_adapter_spec):
        if action_space is None:
            adapter = ActionAdapter.from_spec(action_adapter_spec)
            self.action_space = adapter.action_space
            # Assert single component action space.
            assert len(self.action_space.flatten()) == 1,\
                "ERROR: Action space must not be ContainerSpace if no `action_space` is given in Policy c'tor!"
        else:
            self.action_space = Space.from_spec(action_space)

        # Figure out our Distributions.
        for i, (flat_key, action_component) in enumerate(self.action_space.flatten().items()):
            distribution = self.distributions[flat_key] = self._get_distribution(i, action_component)
            if distribution is None:
                raise RLGraphError("ERROR: `action_component` is of type {} and not allowed in {} Component!".
                                   format(type(action_component).__name__, self.name))
            action_adapter_type = distribution.get_action_adapter_type()
            # Spec dict.
            if isinstance(action_adapter_spec, dict):
                aa_spec = action_adapter_spec.get(flat_key, action_adapter_spec)
                aa_spec["type"] = action_adapter_type
                aa_spec["action_space"] = action_component
            # Simple type spec.
            elif not isinstance(action_adapter_spec, ActionAdapter):
                aa_spec = dict(type=action_adapter_type, action_space=action_component)
            # Direct object.
            else:
                aa_spec = action_adapter_spec
            self.action_adapters[flat_key] = ActionAdapter.from_spec(aa_spec, scope="action-adapter-{}".format(i))
Code Example #2
File: policy.py  Project: theSoenke/rlgraph
    def _create_action_adapters_and_distributions(self, action_space,
                                                  action_adapter_spec):
        if action_space is None:
            adapter = ActionAdapter.from_spec(action_adapter_spec)
            self.action_space = adapter.action_space
            # Assert single component action space.
            assert len(self.action_space.flatten()) == 1, \
                "ERROR: Action space must not be ContainerSpace if no `action_space` is given in Policy constructor!"
        else:
            self.action_space = Space.from_spec(action_space)

        # Figure out our Distributions.
        for i, (flat_key, action_component) in enumerate(
                self.action_space.flatten().items()):
            # Spec dict.
            if isinstance(action_adapter_spec, dict):
                aa_spec = flat_key_lookup(action_adapter_spec, flat_key,
                                          action_adapter_spec)
                aa_spec["action_space"] = action_component
            # Simple type spec.
            elif not isinstance(action_adapter_spec, ActionAdapter):
                aa_spec = dict(action_space=action_component)
            # Direct object.
            else:
                aa_spec = action_adapter_spec

            if isinstance(aa_spec, dict) and "type" not in aa_spec:
                dist_spec = get_default_distribution_from_space(
                    action_component, self.bounded_distribution_type,
                    self.discrete_distribution_type,
                    self.gumbel_softmax_temperature)

                self.distributions[flat_key] = Distribution.from_spec(
                    dist_spec, scope="{}-{}".format(dist_spec["type"], i))
                if self.distributions[flat_key] is None:
                    raise RLGraphError(
                        "ERROR: `action_component` is of type {} and not allowed in {} Component!"
                        .format(type(action_component).__name__, self.name))
                aa_spec["type"] = get_action_adapter_type_from_distribution_type(
                    type(self.distributions[flat_key]).__name__)
                self.action_adapters[flat_key] = ActionAdapter.from_spec(
                    aa_spec, scope="action-adapter-{}".format(i))
            else:
                self.action_adapters[flat_key] = ActionAdapter.from_spec(
                    aa_spec, scope="action-adapter-{}".format(i))
                dist_spec = get_distribution_spec_from_action_adapter(
                    self.action_adapters[flat_key])
                self.distributions[flat_key] = Distribution.from_spec(
                    dist_spec, scope="{}-{}".format(dist_spec["type"], i))
Code Example #3
    def test_simple_action_adapter(self):
        # Last NN layer.
        last_nn_layer_space = FloatBox(shape=(16, ), add_batch_rank=True)
        # Action Space.
        action_space = IntBox(2, shape=(3, 2))

        action_adapter = ActionAdapter(action_space=action_space,
                                       weights_spec=1.0,
                                       biases_spec=False,
                                       activation="relu")
        test = ComponentTest(component=action_adapter,
                             input_spaces=dict(nn_output=last_nn_layer_space),
                             action_space=action_space)
        action_adapter_params = test.read_variable_values(
            action_adapter.variables)

        # Batch of 2 samples.
        inputs = last_nn_layer_space.sample(2)

        expected_action_layer_output = np.matmul(
            inputs,
            action_adapter_params["action-adapter/action-layer/dense/kernel"])
        test.test(("get_action_layer_output", inputs),
                  expected_outputs=dict(output=expected_action_layer_output))

        expected_logits = np.reshape(expected_action_layer_output,
                                     newshape=(2, 3, 2, 2))
        expected_probabilities = softmax(expected_logits)
        expected_log_probs = np.log(expected_probabilities)
        test.test(("get_logits_probabilities_log_probs", inputs),
                  expected_outputs=dict(logits=expected_logits,
                                        probabilities=expected_probabilities,
                                        log_probs=expected_log_probs))
Code Example #4
    def test_action_adapter_with_complex_lstm_output(self):
        # Last NN layer (LSTM with time rank).
        last_nn_layer_space = FloatBox(shape=(4,), add_batch_rank=True, add_time_rank=True, time_major=True)
        # Action Space.
        action_space = IntBox(2, shape=(3, 2))

        action_adapter = ActionAdapter(action_space=action_space, biases_spec=False)
        test = ComponentTest(
            component=action_adapter, input_spaces=dict(
                nn_output=last_nn_layer_space,
                inputs=[last_nn_layer_space]
            ), action_space=action_space
        )
        action_adapter_params = test.read_variable_values(action_adapter.variables)

        # Batch of 2 samples, 3 timesteps.
        inputs = last_nn_layer_space.sample(size=(3, 2))
        # Fold time rank before the action layer pass through.
        inputs_reshaped = np.reshape(inputs, newshape=(6, -1))
        # Action layer pass through and unfolding of time rank.
        expected_action_layer_output = np.matmul(
            inputs_reshaped, action_adapter_params["action-adapter/action-network/action-layer/dense/kernel"]
        ).reshape((3, 2, -1))
        # Logits (already well reshaped (same as action space)).
        expected_logits = np.reshape(expected_action_layer_output, newshape=(3, 2, 3, 2, 2))
        test.test(("apply", inputs), expected_outputs=dict(output=expected_logits))
        test.test(("get_logits", inputs), expected_outputs=expected_logits)

        # Softmax (probs).
        expected_probabilities = softmax(expected_logits)
        # Log probs.
        expected_log_probs = np.log(expected_probabilities)
        test.test(("get_logits_probabilities_log_probs", inputs), expected_outputs=dict(
            logits=expected_logits, probabilities=expected_probabilities, log_probs=expected_log_probs
        ), decimals=5)
Code Example #5
    def test_exploration_with_continuous_action_space(self):
        # TODO not portable, redo with more general mean/stddev checks over a sample of distributed outputs.
        return
        # 2x2 continuous action-pick.
        action_space = FloatBox(shape=(2, 2), add_batch_rank=True)

        distribution = Normal()
        action_adapter = ActionAdapter(action_space=action_space)

        # Our distribution to go into the Exploration object.
        nn_output_space = FloatBox(shape=(13,), add_batch_rank=True)  # 13: Any flat nn-output should be ok.

        exploration = Exploration.from_spec(dict(noise_spec=dict(type="gaussian_noise", mean=10.0, stddev=2.0)))

        # The Component to test.
        exploration_pipeline = Component(scope="continuous-plus-noise")
        exploration_pipeline.add_components(action_adapter, distribution, exploration, scope="exploration-pipeline")

        @rlgraph_api(component=exploration_pipeline)
        def get_action(self_, nn_output):
            _, parameters, _ = action_adapter.get_logits_probabilities_log_probs(nn_output)
            sample_stochastic = distribution.sample_stochastic(parameters)
            sample_deterministic = distribution.sample_deterministic(parameters)
            action = exploration.get_action(sample_stochastic, sample_deterministic)
            return action

        @rlgraph_api(component=exploration_pipeline)
        def get_noise(self_):
            return exploration.noise_component.get_noise()

        test = ComponentTest(component=exploration_pipeline, input_spaces=dict(nn_output=nn_output_space),
                             action_space=action_space)

        # Collect outputs in `collected` list to compare moments.
        collected = list()
        for _ in range_(1000):
            test.test("get_noise", fn_test=lambda component_test, outs: collected.append(outs))

        self.assertAlmostEqual(10.0, np.mean(collected), places=1)
        self.assertAlmostEqual(2.0, np.std(collected), places=1)

        np.random.seed(10)
        input_ = nn_output_space.sample(size=3)
        expected = np.array([[[13.163095, 8.46925],
                              [10.375976, 5.4675055]],
                             [[13.239931, 7.990649],
                              [10.03761, 10.465796]],
                             [[10.280741, 7.2384844],
                              [10.040194, 8.248206]]], dtype=np.float32)
        test.test(("get_action", input_), expected_outputs=expected, decimals=3)
Code Example #6
    def test_simple_action_adapter_with_batch_apply(self):
        # Last NN layer.
        last_nn_layer_space = FloatBox(shape=(16, ),
                                       add_batch_rank=True,
                                       add_time_rank=True,
                                       time_major=True)
        # Action Space.
        action_space = IntBox(2, shape=(3, 2))

        action_adapter = ActionAdapter(action_space=action_space,
                                       weights_spec=1.0,
                                       biases_spec=False,
                                       fold_time_rank=True,
                                       unfold_time_rank=True,
                                       activation="relu")
        test = ComponentTest(component=action_adapter,
                             input_spaces=dict(nn_output=last_nn_layer_space,
                                               inputs=[last_nn_layer_space]),
                             action_space=action_space)
        action_adapter_params = test.read_variable_values(
            action_adapter.variables)

        # Batch of (4, 5).
        inputs = last_nn_layer_space.sample(size=(4, 5))
        inputs_folded = np.reshape(inputs, newshape=(20, -1))

        expected_action_layer_output = np.matmul(
            inputs_folded, action_adapter_params[
                "action-adapter/action-network/action-layer/dense/kernel"])
        expected_logits = np.reshape(expected_action_layer_output,
                                     newshape=(4, 5, 3, 2, 2))

        test.test(("apply", inputs),
                  expected_outputs=dict(output=expected_logits),
                  decimals=4)
        test.test(("get_logits", inputs),
                  expected_outputs=expected_logits,
                  decimals=4)

        expected_probabilities = softmax(expected_logits)
        expected_log_probs = np.log(expected_probabilities)
        test.test(("get_logits_probabilities_log_probs", inputs),
                  expected_outputs=dict(logits=expected_logits,
                                        probabilities=expected_probabilities,
                                        log_probs=expected_log_probs),
                  decimals=4)
Code Example #7
    def test_exploration_with_discrete_action_space(self):
        nn_output_space = FloatBox(shape=(13, ), add_batch_rank=True)
        time_step_space = IntBox(10000)
        # 2x2 action-pick, each composite action with 5 categories.
        action_space = IntBox(5, shape=(2, 2), add_batch_rank=True)

        # Our distribution to go into the Exploration object.
        distribution = Categorical()
        action_adapter = ActionAdapter(action_space=action_space)

        exploration = Exploration.from_spec(
            dict(epsilon_spec=dict(decay_spec=dict(type="linear_decay",
                                                   from_=1.0,
                                                   to_=0.0,
                                                   start_timestep=0,
                                                   num_timesteps=10000))))
        # The Component to test.
        exploration_pipeline = Component(action_adapter,
                                         distribution,
                                         exploration,
                                         scope="exploration-pipeline")

        @rlgraph_api(component=exploration_pipeline)
        def get_action(self_, nn_output, time_step):
            out = action_adapter.get_logits_probabilities_log_probs(nn_output)
            sample = distribution.sample_deterministic(out["probabilities"])
            action = exploration.get_action(sample, time_step)
            return action

        test = ComponentTest(component=exploration_pipeline,
                             input_spaces=dict(nn_output=nn_output_space,
                                               time_step=int),
                             action_space=action_space)

        # With exploration: check whether actions are roughly equally distributed.
        nn_outputs = nn_output_space.sample(2)
        time_steps = time_step_space.sample(30)
        # Collect action-batch-of-2 for each of our various random time steps.
        # Each action is an int box of shape=(2,2)
        actions = np.ndarray(shape=(30, 2, 2, 2), dtype=np.int)
        for i, time_step in enumerate(time_steps):
            actions[i] = test.test(("get_action", [nn_outputs, time_step]),
                                   expected_outputs=None)

        # Assert some distribution of the actions.
        mean_action = actions.mean()
        stddev_action = actions.std()
        self.assertAlmostEqual(mean_action, 2.0, places=0)
        self.assertAlmostEqual(stddev_action, 1.0, places=0)

        # Without exploration (epsilon is force-set to 0.0): check whether actions are always the same
        # (given the same nn_output every time).
        nn_outputs = nn_output_space.sample(2)
        time_steps = time_step_space.sample(30) + 10000
        # Collect action-batch-of-2 for each of our various random time steps.
        # Each action is an int box of shape=(2,2)
        actions = np.ndarray(shape=(30, 2, 2, 2), dtype=np.int)
        for i, time_step in enumerate(time_steps):
            actions[i] = test.test(("get_action", [nn_outputs, time_step]),
                                   expected_outputs=None)

        # Assert zero stddev of the single action components.
        # Batch item 0, action-component (0,0).
        stddev_action_a = actions[:, 0, 0, 0].std()
        self.assertAlmostEqual(stddev_action_a, 0.0, places=1)
        # Batch item 1, action-component (1,0).
        stddev_action_b = actions[:, 1, 1, 0].std()
        self.assertAlmostEqual(stddev_action_b, 0.0, places=1)
        # Batch item 0, action-component (0,1).
        stddev_action_c = actions[:, 0, 0, 1].std()
        self.assertAlmostEqual(stddev_action_c, 0.0, places=1)
        # Batch item 1, action-component (1,1).
        stddev_action_d = actions[:, 1, 1, 1].std()
        self.assertAlmostEqual(stddev_action_d, 0.0, places=1)
        self.assertAlmostEqual(actions.std(), 1.0, places=0)
Code Example #8
    def test_exploration_with_discrete_container_action_space(self):
        nn_output_space = FloatBox(shape=(12, ), add_batch_rank=True)
        time_step_space = IntBox(10000)
        # Some container action space.
        action_space = Dict(dict(a=IntBox(3), b=IntBox(2), c=IntBox(4)),
                            add_batch_rank=True)

        # Our distribution to go into the Exploration object.
        distribution_a = Categorical(scope="d_a")
        distribution_b = Categorical(scope="d_b")
        distribution_c = Categorical(scope="d_c")
        action_adapter_a = ActionAdapter(action_space=action_space["a"],
                                         scope="aa_a")
        action_adapter_b = ActionAdapter(action_space=action_space["b"],
                                         scope="aa_b")
        action_adapter_c = ActionAdapter(action_space=action_space["c"],
                                         scope="aa_c")

        exploration = Exploration.from_spec(
            dict(epsilon_spec=dict(decay_spec=dict(type="linear_decay",
                                                   from_=1.0,
                                                   to_=0.0,
                                                   start_timestep=0,
                                                   num_timesteps=10000))))
        # The Component to test.
        exploration_pipeline = Component(action_adapter_a,
                                         action_adapter_b,
                                         action_adapter_c,
                                         distribution_a,
                                         distribution_b,
                                         distribution_c,
                                         exploration,
                                         scope="exploration-pipeline")

        @rlgraph_api(component=exploration_pipeline)
        def get_action(self_, nn_output, time_step):
            out_a = action_adapter_a.get_logits_probabilities_log_probs(
                nn_output)
            out_b = action_adapter_b.get_logits_probabilities_log_probs(
                nn_output)
            out_c = action_adapter_c.get_logits_probabilities_log_probs(
                nn_output)
            sample_a = distribution_a.sample_deterministic(
                out_a["probabilities"])
            sample_b = distribution_b.sample_deterministic(
                out_b["probabilities"])
            sample_c = distribution_c.sample_deterministic(
                out_c["probabilities"])
            sample = self_._graph_fn_merge_actions(sample_a, sample_b,
                                                   sample_c)
            action = exploration.get_action(sample, time_step)
            return action

        @graph_fn(component=exploration_pipeline)
        def _graph_fn_merge_actions(self, a, b, c):
            return DataOpDict(a=a, b=b, c=c)

        test = ComponentTest(component=exploration_pipeline,
                             input_spaces=dict(nn_output=nn_output_space,
                                               time_step=int),
                             action_space=action_space)

        # With exploration: check whether actions are roughly equally distributed.
        batch_size = 2
        num_time_steps = 30
        nn_outputs = nn_output_space.sample(batch_size)
        time_steps = time_step_space.sample(num_time_steps)
        # Collect action-batch-of-2 for each of our various random time steps.
        actions_a = np.ndarray(shape=(num_time_steps, batch_size),
                               dtype=np.int)
        actions_b = np.ndarray(shape=(num_time_steps, batch_size),
                               dtype=np.int)
        actions_c = np.ndarray(shape=(num_time_steps, batch_size),
                               dtype=np.int)
        for i, t in enumerate(time_steps):
            a = test.test(("get_action", [nn_outputs, t]),
                          expected_outputs=None)
            actions_a[i] = a["a"]
            actions_b[i] = a["b"]
            actions_c[i] = a["c"]

        # Assert some distribution of the actions.
        mean_action_a = actions_a.mean()
        stddev_action_a = actions_a.std()
        self.assertAlmostEqual(mean_action_a, 1.0, places=0)
        self.assertAlmostEqual(stddev_action_a, 1.0, places=0)
        mean_action_b = actions_b.mean()
        stddev_action_b = actions_b.std()
        self.assertAlmostEqual(mean_action_b, 0.5, places=0)
        self.assertAlmostEqual(stddev_action_b, 0.5, places=0)
        mean_action_c = actions_c.mean()
        stddev_action_c = actions_c.std()
        self.assertAlmostEqual(mean_action_c, 1.5, places=0)
        self.assertAlmostEqual(stddev_action_c, 1.0, places=0)

        # Without exploration (epsilon is force-set to 0.0): check whether actions are always the same
        # (given the same nn_output every time).
        nn_outputs = nn_output_space.sample(batch_size)
        time_steps = time_step_space.sample(num_time_steps) + 10000
        # Collect action-batch-of-2 for each of our various random time steps.
        actions_a = np.ndarray(shape=(num_time_steps, batch_size),
                               dtype=np.int)
        actions_b = np.ndarray(shape=(num_time_steps, batch_size),
                               dtype=np.int)
        actions_c = np.ndarray(shape=(num_time_steps, batch_size),
                               dtype=np.int)
        for i, t in enumerate(time_steps):
            a = test.test(("get_action", [nn_outputs, t]),
                          expected_outputs=None)
            actions_a[i] = a["a"]
            actions_b[i] = a["b"]
            actions_c[i] = a["c"]

        # Assert zero stddev of the single action components.
        # Batch item 0, action-component "a".
        stddev_action = actions_a[:, 0].std()
        self.assertAlmostEqual(stddev_action, 0.0, places=1)
        # Batch item 1, action-component "a".
        stddev_action = actions_a[:, 1].std()
        self.assertAlmostEqual(stddev_action, 0.0, places=1)

        # Batch item 0, action-component "b".
        stddev_action = actions_b[:, 0].std()
        self.assertAlmostEqual(stddev_action, 0.0, places=1)
        # Batch item 1, action-component "b".
        stddev_action = actions_b[:, 1].std()
        self.assertAlmostEqual(stddev_action, 0.0, places=1)

        # Batch item 0, action-component "c".
        stddev_action = actions_c[:, 0].std()
        self.assertAlmostEqual(stddev_action, 0.0, places=1)
        # Batch item 1, action-component "c".
        stddev_action = actions_c[:, 1].std()
        self.assertAlmostEqual(stddev_action, 0.0, places=1)
Code Example #9
File: policy.py  Project: samialabed/rlgraph
    def __init__(self,
                 network_spec,
                 action_space=None,
                 action_adapter_spec=None,
                 deterministic=True,
                 scope="policy",
                 **kwargs):
        """
        Args:
            network_spec (Union[NeuralNetwork,dict]): The NeuralNetwork Component or a specification dict to build
                one.

            action_space (Space): The action Space within which this Component will create actions.

            action_adapter_spec (Optional[dict]): A spec-dict to create an ActionAdapter. Use None for the default
                ActionAdapter object.

            deterministic (bool): Whether to pick actions according to the max-likelihood value or via sampling.
                Default: True.

            batch_apply (bool): Whether to wrap both the NN and the ActionAdapter with a BatchApply Component in order
                to fold time rank into batch rank before a forward pass.
        """
        super(Policy, self).__init__(scope=scope, **kwargs)

        self.neural_network = NeuralNetwork.from_spec(
            network_spec)  # type: NeuralNetwork

        # Create the necessary action adapters for this Policy. One for each action space component.
        self.action_adapters = dict()
        if action_space is None:
            self.action_adapters[""] = ActionAdapter.from_spec(
                action_adapter_spec)
            self.action_space = self.action_adapters[""].action_space
            # Assert single component action space.
            assert len(self.action_space.flatten()) == 1,\
                "ERROR: Action space must not be ContainerSpace if no `action_space` is given in Policy c'tor!"
        else:
            self.action_space = Space.from_spec(action_space)
            for i, (flat_key, action_component) in enumerate(
                    self.action_space.flatten().items()):
                if action_adapter_spec is not None:
                    aa_spec = action_adapter_spec.get(flat_key,
                                                      action_adapter_spec)
                    aa_spec["action_space"] = action_component
                else:
                    aa_spec = dict(action_space=action_component)
                self.action_adapters[flat_key] = ActionAdapter.from_spec(
                    aa_spec, scope="action-adapter-{}".format(i))

        self.deterministic = deterministic

        # Figure out our Distributions.
        self.distributions = dict()
        for i, (flat_key, action_component) in enumerate(
                self.action_space.flatten().items()):
            if isinstance(action_component, IntBox):
                self.distributions[flat_key] = Categorical(
                    scope="categorical-{}".format(i))
            # Continuous action space -> Normal distribution (each action needs mean and variance from network).
            elif isinstance(action_component, FloatBox):
                self.distributions[flat_key] = Normal(
                    scope="normal-{}".format(i))
            else:
                raise RLGraphError(
                    "ERROR: `action_component` is of type {} and not allowed in {} Component!"
                    .format(type(action_component).__name__, self.name))

        self.add_components(*[self.neural_network] +
                            list(self.action_adapters.values()) +
                            list(self.distributions.values()))
Code Example #10
    def __init__(self,
                 network_spec,
                 action_space=None,
                 action_adapter_spec=None,
                 max_likelihood=True,
                 scope="policy",
                 **kwargs):
        """
        Args:
            network_spec (Union[NeuralNetwork,dict]): The NeuralNetwork Component or a specification dict to build
                one.

            action_space (Space): The action Space within which this Component will create actions.

            action_adapter_spec (Optional[dict]): A spec-dict to create an ActionAdapter. Use None for the default
                ActionAdapter object.

            max_likelihood (bool): Whether to pick actions according to the max-likelihood value or via sampling.
                Default: True.
        """
        super(Policy, self).__init__(scope=scope, **kwargs)

        self.neural_network = NeuralNetwork.from_spec(network_spec)
        if action_space is None:
            self.action_adapter = ActionAdapter.from_spec(action_adapter_spec)
            action_space = self.action_adapter.action_space
        else:
            self.action_adapter = ActionAdapter.from_spec(
                action_adapter_spec, action_space=action_space)
        self.action_space = action_space
        self.max_likelihood = max_likelihood

        # TODO: Hacky trick to implement IMPALA post-LSTM256 time-rank folding and unfolding.
        # TODO: Replace entirely via sonnet-like BatchApply Component.
        is_impala = "IMPALANetwork" in type(self.neural_network).__name__

        # Add API-method to get baseline output (if we use an extra value function baseline node).
        if isinstance(self.action_adapter, BaselineActionAdapter):
            # TODO: IMPALA attempt to speed up final pass after LSTM.
            if is_impala:
                self.time_rank_folder = ReShape(fold_time_rank=True,
                                                scope="time-rank-fold")
                self.time_rank_unfolder_v = ReShape(unfold_time_rank=True,
                                                    time_major=True,
                                                    scope="time-rank-unfold-v")
                self.time_rank_unfolder_a_probs = ReShape(
                    unfold_time_rank=True,
                    time_major=True,
                    scope="time-rank-unfold-a-probs")
                self.time_rank_unfolder_logits = ReShape(
                    unfold_time_rank=True,
                    time_major=True,
                    scope="time-rank-unfold-logits")
                self.time_rank_unfolder_log_probs = ReShape(
                    unfold_time_rank=True,
                    time_major=True,
                    scope="time-rank-unfold-log-probs")
                self.add_components(self.time_rank_folder,
                                    self.time_rank_unfolder_v,
                                    self.time_rank_unfolder_a_probs,
                                    self.time_rank_unfolder_log_probs,
                                    self.time_rank_unfolder_logits)

            @rlgraph_api(component=self)
            def get_state_values_logits_probabilities_log_probs(
                    self, nn_input, internal_states=None):
                nn_output = self.neural_network.apply(nn_input,
                                                      internal_states)
                last_internal_states = nn_output.get("last_internal_states")
                nn_output = nn_output["output"]

                # TODO: IMPALA attempt to speed up final pass after LSTM.
                if is_impala:
                    nn_output = self.time_rank_folder.apply(nn_output)

                out = self.action_adapter.get_logits_probabilities_log_probs(
                    nn_output)

                # TODO: IMPALA attempt to speed up final pass after LSTM.
                if is_impala:
                    state_values = self.time_rank_unfolder_v.apply(
                        out["state_values"], nn_output)
                    logits = self.time_rank_unfolder_logits.apply(
                        out["logits"], nn_output)
                    probs = self.time_rank_unfolder_a_probs.apply(
                        out["probabilities"], nn_output)
                    log_probs = self.time_rank_unfolder_log_probs.apply(
                        out["log_probs"], nn_output)
                else:
                    state_values = out["state_values"]
                    logits = out["logits"]
                    probs = out["probabilities"]
                    log_probs = out["log_probs"]

                return dict(state_values=state_values,
                            logits=logits,
                            probabilities=probs,
                            log_probs=log_probs,
                            last_internal_states=last_internal_states)

        # Figure out our Distribution.
        if isinstance(action_space, IntBox):
            self.distribution = Categorical()
        # Continuous action space -> Normal distribution (each action needs mean and variance from network).
        elif isinstance(action_space, FloatBox):
            self.distribution = Normal()
        else:
            raise RLGraphError(
                "ERROR: `action_space` is of type {} and not allowed in {} Component!"
                .format(type(action_space).__name__, self.name))

        self.add_components(self.neural_network, self.action_adapter,
                            self.distribution)

        if is_impala:
            self.add_components(self.time_rank_folder,
                                self.time_rank_unfolder_v,
                                self.time_rank_unfolder_a_probs,
                                self.time_rank_unfolder_log_probs,
                                self.time_rank_unfolder_logits)
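
The two Policy constructors above (Code Examples #9 and #10) differ mainly in how they map action-space components to ActionAdapters and Distributions and in the `deterministic` vs. `max_likelihood` flag. Below is a minimal construction sketch written against the Code Example #9 signature; the import paths, the list-of-layer-dicts `network_spec` format, and all concrete values are assumptions for illustration, not taken from either project.

    # Minimal sketch, assuming the rlgraph import paths and the list-of-layer-dicts
    # network-spec format of the snippet versions above; all values are illustrative.
    from rlgraph.components.policies import Policy  # import path is an assumption
    from rlgraph.spaces import IntBox

    # 2x2 composite IntBox action space with 5 categories per component, as in the
    # exploration tests above. The Code Example #9 constructor creates one
    # Categorical distribution and one ActionAdapter per flattened action component.
    action_space = IntBox(5, shape=(2, 2), add_batch_rank=True)

    policy = Policy(
        # NeuralNetwork spec: two dense hidden layers (hypothetical sizes).
        network_spec=[
            dict(type="dense", units=32, activation="relu"),
            dict(type="dense", units=16, activation="relu"),
        ],
        action_space=action_space,
        action_adapter_spec=None,  # use the default ActionAdapter per component
        deterministic=True,        # pick max-likelihood actions instead of sampling
    )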