Example #1
class BatchApply(Component):
    """
    Takes an input with batch and time ranks, folds the time rank into the batch rank,
    calls a given API-method on an arbitrary child Component, and then unfolds the time rank again.
    """
    def __init__(self,
                 sub_component,
                 api_method_name,
                 scope="batch-apply",
                 **kwargs):
        """
        Args:
            sub_component (Component): The sub-Component to apply the batch to.
            api_method_name (str): The name of the API-method to call on the sub-component.
        """
        super(BatchApply, self).__init__(scope=scope, **kwargs)

        self.sub_component = sub_component
        self.api_method_name = api_method_name

        # Create the necessary reshape components.
        self.folder = ReShape(fold_time_rank=True, scope="folder")
        self.unfolder = ReShape(unfold_time_rank=True, scope="unfolder")

        self.add_components(self.sub_component, self.folder, self.unfolder)

    @rlgraph_api
    def apply(self, input_):
        folded = self._graph_fn_fold(input_)
        applied = self._graph_fn_apply(folded)
        unfolded = self._graph_fn_unfold(applied, input_)
        return unfolded

    @graph_fn(flatten_ops=True, split_ops=True)
    def _graph_fn_fold(self, input_):
        if get_backend() == "tf":
            # Fold the time rank.
            input_folded = self.folder.apply(input_)
            return input_folded

    @graph_fn
    def _graph_fn_apply(self, input_folded):
        if get_backend() == "tf":
            # Send the folded input through the sub-component.
            sub_component_out = getattr(self.sub_component,
                                        self.api_method_name)(input_folded)
            return sub_component_out

    @graph_fn(flatten_ops=True, split_ops=True)
    def _graph_fn_unfold(self, sub_component_out, orig_input):
        if get_backend() == "tf":
            # Un-fold the time rank again.
            output = self.unfolder.apply(
                sub_component_out, input_before_time_rank_folding=orig_input)
            return output
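
The fold/apply/unfold pattern that BatchApply wires into the graph can be illustrated without any RLGraph machinery: merge the time rank into the batch rank, run an operation that only understands a batch rank, then restore the original shape. A minimal NumPy sketch (hypothetical helper names, not part of RLGraph):

import numpy as np

def fold_time_rank(x):
    # Merge [batch, time, ...] into [batch * time, ...].
    b, t = x.shape[0], x.shape[1]
    return x.reshape((b * t,) + x.shape[2:]), (b, t)

def unfold_time_rank(x, batch_and_time):
    # Restore the original [batch, time, ...] shape after processing.
    b, t = batch_and_time
    return x.reshape((b, t) + x.shape[1:])

def sub_component_apply(x):
    # Stand-in for the wrapped sub-component's API-method (here just an affine op).
    return x * 2.0 + 1.0

inputs = np.random.rand(4, 10, 8)         # [batch=4, time=10, feature=8]
folded, bt = fold_time_rank(inputs)       # [40, 8] -> the sub-component only sees a batch rank
outputs = unfold_time_rank(sub_component_apply(folded), bt)
assert outputs.shape == inputs.shape      # [4, 10, 8] again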
Example #2
class IMPALANetwork(NeuralNetwork):
    """
    The base class for both the "large architecture" and "small architecture" versions of the networks used in [1].

    [1] IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures - Espeholt, Soyer,
        Munos et al. - 2018 (https://arxiv.org/abs/1802.01561)
    """
    def __init__(self,
                 worker_sample_size=100,
                 scope="impala-network",
                 **kwargs):
        """
        Args:
            worker_sample_size (int): How many time-steps an IMPALA actor will have performed in one rollout.
        """
        super(IMPALANetwork, self).__init__(scope=scope, **kwargs)

        self.worker_sample_size = worker_sample_size

        # Create all needed sub-components.

        # ContainerSplitter for the Env signal (a dict with 4 keys: env image, env text, previous action, and previous reward).
        self.splitter = ContainerSplitter("RGB_INTERLEAVED",
                                          "INSTR",
                                          "previous_action",
                                          "previous_reward",
                                          scope="input-splitter")

        # Fold the time rank into the batch rank.
        self.time_rank_fold_before_lstm = ReShape(
            fold_time_rank=True, scope="time-rank-fold-before-lstm")
        self.time_rank_unfold_before_lstm = ReShape(
            unfold_time_rank=True,
            time_major=True,
            scope="time-rank-unfold-before-lstm")

        # The Image Processing Stack (left side of "Large Architecture" Figure 3 in [1]).
        # Conv2D column + ReLU + fc(256) + ReLU.
        self.image_processing_stack = self.build_image_processing_stack()

        # The text processing pipeline: takes a batch of string tensors as input, hashes the strings into buckets,
        # passes the bucket indices through an embedding-lookup(20) layer, and then sends the embedding output
        # through an LSTM(64).
        self.text_processing_stack = self.build_text_processing_stack()

        #self.debug_slicer = Slice(scope="internal-states-slicer", squeeze=True)

        # The concatenation layer (concatenates outputs from image/text processing stacks, previous action/reward).
        self.concat_layer = ConcatLayer()

        # The main LSTM, whose output goes into the ActionAdapter (the next Component in the Policy that uses this NN).
        # Use time-major, as it is faster (according to the TF docs).
        self.main_lstm = LSTMLayer(units=256,
                                   scope="lstm-256",
                                   time_major=True,
                                   static_loop=self.worker_sample_size)

        # Add all sub-components to this one.
        self.add_components(
            self.splitter,
            self.image_processing_stack,
            self.text_processing_stack,
            self.concat_layer,
            self.main_lstm,
            self.time_rank_fold_before_lstm,
            self.time_rank_unfold_before_lstm,
            #self.debug_slicer
        )

    @staticmethod
    def build_image_processing_stack():
        """
        Builds the image processing pipeline for IMPALA and returns it.
        """
        raise NotImplementedError

    @staticmethod
    def build_text_processing_stack():
        """
        Helper function to build the text processing pipeline for both the large and small architectures, consisting of:
        - ReShape preprocessor to fold the incoming time rank into the batch rank.
        - StringToHashBucket Layer taking a batch of sentences and converting them to an indices-table of dimensions:
          cols = length of the longest sentence in the input
          rows = number of items in the batch
          The cols dimension can be interpreted as the time rank for a subsequent LSTM. The StringToHashBucket
          Component returns the sequence length of each batch item for exactly that purpose.
        - Embedding Lookup Layer of embedding size 20 and number of rows == num_hash_buckets (see previous layer).
        - LSTM processing the batched word sequences coming from the embedding layer.
        """
        num_hash_buckets = 1000

        # Create a hash bucket from the sentences and use that bucket to do an embedding lookup (instead of
        # a vocabulary).
        string_to_hash_bucket = StringToHashBucket(
            num_hash_buckets=num_hash_buckets)
        embedding = EmbeddingLookup(embed_dim=20,
                                    vocab_size=num_hash_buckets,
                                    pad_empty=True)
        # The time rank for the LSTM is now the sequence of words in a sentence, NOT the original env time rank.
        # We will only use the last output of the LSTM-64 for further processing as that is the output after having
        # seen all words in the sentence.
        # The original env stepping time rank is currently folded into the batch rank and must be unfolded again before
        # passing it into the main LSTM.
        lstm64 = LSTMLayer(units=64, scope="lstm-64", time_major=False)

        tuple_splitter = ContainerSplitter(tuple_length=2,
                                           scope="tuple-splitter")

        def custom_apply(self, inputs):
            hash_bucket, lengths = self.sub_components[
                "string-to-hash-bucket"].apply(inputs)

            embedding_output = self.sub_components["embedding-lookup"].apply(
                hash_bucket)

            # Return only the last output of the LSTM: intermediate outputs (produced before the LSTM has seen
            # the entire sentence) are not of interest.
            # The last output is the final internal h-state (slot 1 of the returned LSTM state tuple; slot 0 is the final c-state).
            lstm_output = self.sub_components["lstm-64"].apply(
                embedding_output, sequence_length=lengths)
            lstm_final_internals = lstm_output["last_internal_states"]

            # Need to split once more because the LSTM state is always a tuple of final c- and h-states.
            _, lstm_final_h_state = self.sub_components[
                "tuple-splitter"].split(lstm_final_internals)

            return lstm_final_h_state

        text_processing_stack = Stack(string_to_hash_bucket,
                                      embedding,
                                      lstm64,
                                      tuple_splitter,
                                      api_methods={("apply", custom_apply)},
                                      scope="text-stack")

        return text_processing_stack

    @rlgraph_api
    def apply(self, input_dict, internal_states=None):
        # Split the input dict coming directly from the Env. Keep only the original (still time-ranked) previous
        # reward; it serves later as the pre-fold shape reference when unfolding the time rank again.
        _, _, _, orig_previous_reward = self.splitter.split(input_dict)

        folded_input = self.time_rank_fold_before_lstm.apply(input_dict)
        image, text, previous_action, previous_reward = self.splitter.split(
            folded_input)

        # Get the left-stack (image) and right-stack (text) output (see [1] for details).
        text_processing_output = self.text_processing_stack.apply(text)
        image_processing_output = self.image_processing_stack.apply(image)

        # Concat everything together.
        concatenated_data = self.concat_layer.apply(image_processing_output,
                                                    text_processing_output,
                                                    previous_action,
                                                    previous_reward)

        unfolded_concatenated_data = self.time_rank_unfold_before_lstm.apply(
            concatenated_data, orig_previous_reward)

        # Feed concat'd input into main LSTM(256).
        lstm_output = self.main_lstm.apply(unfolded_concatenated_data,
                                           internal_states)

        return lstm_output
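
The text-processing stack described above (hash-bucketing of words, an embedding-lookup(20), and an LSTM(64) from which only the final h-state is kept) can be sketched in plain TensorFlow. This is only an illustration of the data flow, not the RLGraph Components themselves; the toy sentences and padding scheme below are made up for the example:

import tensorflow as tf  # TF 2.x assumed

num_hash_buckets = 1000

# A toy batch of 2 sentences, already tokenized and padded to the same number of words.
words = tf.constant([["go", "to", "the", "red", "door"],
                     ["pick", "up", "the", "key", ""]])

# 1) Hash each word into one of `num_hash_buckets` indices (no vocabulary required).
hash_ids = tf.strings.to_hash_bucket_fast(words, num_hash_buckets)   # [2, 5], int64

# 2) Embedding lookup of size 20; the embedding matrix has num_hash_buckets rows.
embed_matrix = tf.random.normal([num_hash_buckets, 20])
embedded = tf.nn.embedding_lookup(embed_matrix, hash_ids)            # [2, 5, 20]

# 3) LSTM(64) over the word sequence; keep only the final hidden (h) state, i.e. the
#    state after the whole sentence has been read.
_, final_h, final_c = tf.keras.layers.LSTM(64, return_state=True)(embedded)
print(final_h.shape)                                                  # (2, 64)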
Example #3
class ActionAdapter(Component):
    """
    A Component that cleans up a neural network's flat output and gets it ready for parameterizing a
    Distribution Component.
    Processing steps include:
    - Sending the raw, flattened NN output through a Dense layer whose number of units matches the flattened
    action space.
    - Reshaping (according to the action Space).
    - Translating the reshaped outputs (logits) into probabilities (by softmaxing) and log-probabilities (log).
    """
    def __init__(self,
                 action_space,
                 add_units=0,
                 units=None,
                 weights_spec=None,
                 biases_spec=None,
                 activation=None,
                 scope="action-adapter",
                 **kwargs):
        """
        Args:
            action_space (Space): The action Space within which this Component will create actions.

            add_units (Optional[int]): An optional number of units to add to the auto-calculated number of action-
                layer nodes. Can be negative to subtract units from the auto-calculated value.
                NOTE: Provide at most one of `add_units` or `units`, not both.

            units (Optional[int]): An optional number of units to use for the action-layer. If None, will calculate
                the number of units automatically from the given action_space.
                NOTE: Provide at most one of `add_units` or `units`, not both.

            weights_spec (Optional[any]): An optional RLGraph Initializer spec that will be used to initialize the
                weights of `self.action_layer`. Default: None (use default initializer).

            biases_spec (Optional[any]): An optional RLGraph Initializer spec that will be used to initialize the
                biases of `self.action_layer`. Default: None (use default initializer, which is usually 0.0).

            activation (Optional[str]): The activation function to use for `self.action_layer`.
                Default: None (=linear).
        """
        super(ActionAdapter, self).__init__(scope=scope, **kwargs)

        self.action_space = action_space.with_batch_rank()
        self.weights_spec = weights_spec
        self.biases_spec = biases_spec
        self.activation = activation

        # Our (dense) action layer representing the flattened action space.
        self.action_layer = None

        # Calculate the number of nodes in the action layer (DenseLayer object) depending on our action Space
        # or using a given fixed number (`units`).
        # Also generate the ReShape sub-Component and give it the new_shape.
        if isinstance(self.action_space, IntBox):
            if units is None:
                units = add_units + self.action_space.flat_dim_with_categories
            self.reshape = ReShape(
                new_shape=self.action_space.get_shape(with_category_rank=True),
                flatten_categories=False)
        else:
            if units is None:
                units = add_units + 2 * self.action_space.flat_dim  # The factor 2 covers the two moments: mean and log-sd.
            # Manually add the moments rank (size 2) after the batch/time ranks.
            new_shape = tuple([2] + list(self.action_space.shape))
            self.reshape = ReShape(new_shape=new_shape)

        assert units > 0, "ERROR: Number of nodes for action-layer calculated as {}! Must be larger than 0.".format(
            units)

        # Create the action-layer and add it to this component.
        self.action_layer = DenseLayer(units=units,
                                       activation=self.activation,
                                       weights_spec=self.weights_spec,
                                       biases_spec=self.biases_spec,
                                       scope="action-layer")

        self.add_components(self.action_layer, self.reshape)

    def check_input_spaces(self, input_spaces, action_space=None):
        # Check the input Space.
        last_nn_layer_space = input_spaces["nn_output"]  # type: Space
        sanity_check_space(last_nn_layer_space,
                           non_allowed_types=[ContainerSpace])

        # Check the action Space.
        sanity_check_space(self.action_space, must_have_batch_rank=True)
        if isinstance(self.action_space, IntBox):
            sanity_check_space(self.action_space, must_have_categories=True)
        else:
            # FixMe: Are there other constraints on continuous action spaces? E.g. no dueling layers?
            pass

    @rlgraph_api
    def get_action_layer_output(self, nn_output):
        """
        Returns the raw, non-reshaped output of the action-layer (DenseLayer) after passing the raw
        nn_output (coming from the previous Component) through it.

        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            DataOpRecord: The output of the action layer (a DenseLayer) after passing `nn_output` through it.
        """
        out = self.action_layer.apply(nn_output)
        return dict(output=out)

    @rlgraph_api
    def get_logits(self, nn_output):
        """
        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            SingleDataOp: The logits (raw nn_output, BUT reshaped).
        """
        aa_output = self.get_action_layer_output(nn_output)
        logits = self.reshape.apply(aa_output["output"])
        return logits

    @rlgraph_api
    def get_logits_probabilities_log_probs(self, nn_output):
        """
        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            Tuple[SingleDataOp]:
                - logits (raw nn_output, BUT reshaped)
                - probabilities (softmaxed(logits))
                - log(probabilities)
        """
        logits = self.get_logits(nn_output)
        probabilities, log_probs = self._graph_fn_get_probabilities_log_probs(
            logits)
        return dict(logits=logits,
                    probabilities=probabilities,
                    log_probs=log_probs)

    # TODO: Use a SoftMax Component instead (uses the same code as the one below).
    @graph_fn
    def _graph_fn_get_probabilities_log_probs(self, logits):
        """
        Creates properties/parameters and log-probs from some reshaped output.

        Args:
            logits (SingleDataOp): The output of some layer that is already reshaped
                according to our action Space.

        Returns:
            tuple (2x SingleDataOp):
                parameters (DataOp): The parameters, ready to be passed to a Distribution object's
                    get_distribution API-method (usually some probabilities or loc/scale pairs).
                log_probs (DataOp): Simply the log(parameters).
        """
        if get_backend() == "tf":
            if isinstance(self.action_space, IntBox):
                # Discrete actions.
                parameters = tf.maximum(x=tf.nn.softmax(logits=logits,
                                                        axis=-1),
                                        y=SMALL_NUMBER)
                # Log probs.
                log_probs = tf.log(x=parameters)
            elif isinstance(self.action_space, FloatBox):
                # Continuous actions.
                mean, log_sd = tf.split(value=logits,
                                        num_or_size_splits=2,
                                        axis=1)
                # Remove moments rank.
                mean = tf.squeeze(input=mean, axis=1)
                log_sd = tf.squeeze(input=log_sd, axis=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = tf.clip_by_value(t=log_sd,
                                          clip_value_min=log(SMALL_NUMBER),
                                          clip_value_max=-log(SMALL_NUMBER))

                # Turn log sd into sd.
                sd = tf.exp(x=log_sd)

                parameters = DataOpTuple(mean, sd)
                log_probs = DataOpTuple(tf.log(x=mean), log_sd)
            else:
                raise NotImplementedError

            return parameters, log_probs

        elif get_backend() == "pytorch":
            if isinstance(self.action_space, IntBox):
                # Discrete actions.
                softmax_logits = torch.softmax(logits, dim=-1)
                parameters = torch.max(softmax_logits, SMALL_NUMBER_TORCH)
                # Log probs.
                log_probs = torch.log(parameters)
            elif isinstance(self.action_space, FloatBox):
                # Continuous actions.
                # NOTE: torch.split takes the chunk size (not the number of chunks): use 1 so that we get two
                # [batch, 1, ...] halves, mirroring tf.split(num_or_size_splits=2) in the tf branch above.
                mean, log_sd = torch.split(logits,
                                           split_size_or_sections=1,
                                           dim=1)
                # Remove moments rank.
                mean = torch.squeeze(mean, dim=1)
                log_sd = torch.squeeze(log_sd, dim=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = torch.clamp(log_sd,
                                     min=LOG_SMALL_NUMBER,
                                     max=-LOG_SMALL_NUMBER)

                # Turn log sd into sd.
                sd = torch.exp(log_sd)

                parameters = DataOpTuple(mean, sd)
                log_probs = DataOpTuple(torch.log(mean), log_sd)
            else:
                raise NotImplementedError

            return parameters, log_probs
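
For a FloatBox action space the adapter's dense layer has 2 * flat_dim units; after reshaping to [batch, 2, action-shape], slot 0 holds the mean and slot 1 the log standard deviation. A NumPy sketch of that split/clip/exp step, mirroring the branches of `_graph_fn_get_probabilities_log_probs` (hypothetical helper name, assumed SMALL_NUMBER value):

import numpy as np

SMALL_NUMBER = 1e-6  # assumed value; RLGraph defines its own constant

def continuous_parameters(reshaped_logits):
    # Split the moments rank into mean and log-sd, clip log-sd to a sane range, and exponentiate.
    mean, log_sd = np.split(reshaped_logits, 2, axis=1)          # two [batch, 1, ...] halves
    mean, log_sd = mean.squeeze(axis=1), log_sd.squeeze(axis=1)
    log_sd = np.clip(log_sd, np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER))
    return mean, np.exp(log_sd)

logits = np.random.randn(4, 2, 3)         # batch=4, moments rank=2, 3-dimensional action
mean, sd = continuous_parameters(logits)  # each of shape [4, 3]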
Example #4
class Policy(Component):
    """
    A Policy is a wrapper Component that contains a NeuralNetwork, an ActionAdapter and a Distribution Component.
    """
    def __init__(self,
                 network_spec,
                 action_space=None,
                 action_adapter_spec=None,
                 max_likelihood=True,
                 scope="policy",
                 **kwargs):
        """
        Args:
            network_spec (Union[NeuralNetwork,dict]): The NeuralNetwork Component or a specification dict to build
                one.

            action_space (Space): The action Space within which this Component will create actions.

            action_adapter_spec (Optional[dict]): A spec-dict to create an ActionAdapter. Use None for the default
                ActionAdapter object.

            max_likelihood (bool): Whether to pick actions according to the max-likelihood value or via sampling.
                Default: True.
        """
        super(Policy, self).__init__(scope=scope, **kwargs)

        self.neural_network = NeuralNetwork.from_spec(network_spec)
        if action_space is None:
            self.action_adapter = ActionAdapter.from_spec(action_adapter_spec)
            action_space = self.action_adapter.action_space
        else:
            self.action_adapter = ActionAdapter.from_spec(
                action_adapter_spec, action_space=action_space)
        self.action_space = action_space
        self.max_likelihood = max_likelihood

        # TODO: Hacky trick to implement IMPALA post-LSTM256 time-rank folding and unfolding.
        # TODO: Replace entirely via sonnet-like BatchApply Component.
        is_impala = "IMPALANetwork" in type(self.neural_network).__name__

        # Add API-method to get baseline output (if we use an extra value function baseline node).
        if isinstance(self.action_adapter, BaselineActionAdapter):
            # TODO: IMPALA attempt to speed up final pass after LSTM.
            if is_impala:
                self.time_rank_folder = ReShape(fold_time_rank=True,
                                                scope="time-rank-fold")
                self.time_rank_unfolder_v = ReShape(unfold_time_rank=True,
                                                    time_major=True,
                                                    scope="time-rank-unfold-v")
                self.time_rank_unfolder_a_probs = ReShape(
                    unfold_time_rank=True,
                    time_major=True,
                    scope="time-rank-unfold-a-probs")
                self.time_rank_unfolder_logits = ReShape(
                    unfold_time_rank=True,
                    time_major=True,
                    scope="time-rank-unfold-logits")
                self.time_rank_unfolder_log_probs = ReShape(
                    unfold_time_rank=True,
                    time_major=True,
                    scope="time-rank-unfold-log-probs")
                self.add_components(self.time_rank_folder,
                                    self.time_rank_unfolder_v,
                                    self.time_rank_unfolder_a_probs,
                                    self.time_rank_unfolder_log_probs,
                                    self.time_rank_unfolder_logits)

            @rlgraph_api(component=self)
            def get_state_values_logits_probabilities_log_probs(
                    self, nn_input, internal_states=None):
                nn_output = self.neural_network.apply(nn_input,
                                                      internal_states)
                last_internal_states = nn_output.get("last_internal_states")
                nn_output = nn_output["output"]

                # TODO: IMPALA attempt to speed up final pass after LSTM.
                if is_impala:
                    nn_output = self.time_rank_folder.apply(nn_output)

                out = self.action_adapter.get_logits_probabilities_log_probs(
                    nn_output)

                # TODO: IMPALA attempt to speed up final pass after LSTM.
                if is_impala:
                    state_values = self.time_rank_unfolder_v.apply(
                        out["state_values"], nn_output)
                    logits = self.time_rank_unfolder_logits.apply(
                        out["logits"], nn_output)
                    probs = self.time_rank_unfolder_a_probs.apply(
                        out["probabilities"], nn_output)
                    log_probs = self.time_rank_unfolder_log_probs.apply(
                        out["log_probs"], nn_output)
                else:
                    state_values = out["state_values"]
                    logits = out["logits"]
                    probs = out["probabilities"]
                    log_probs = out["log_probs"]

                return dict(state_values=state_values,
                            logits=logits,
                            probabilities=probs,
                            log_probs=log_probs,
                            last_internal_states=last_internal_states)

        # Figure out our Distribution.
        if isinstance(action_space, IntBox):
            self.distribution = Categorical()
        # Continuous action space -> Normal distribution (each action needs a mean and standard deviation from the network).
        elif isinstance(action_space, FloatBox):
            self.distribution = Normal()
        else:
            raise RLGraphError(
                "ERROR: `action_space` is of type {} and not allowed in {} Component!"
                .format(type(action_space).__name__, self.name))

        self.add_components(self.neural_network, self.action_adapter,
                            self.distribution)

        if is_impala:
            self.add_components(self.time_rank_folder,
                                self.time_rank_unfolder_v,
                                self.time_rank_unfolder_a_probs,
                                self.time_rank_unfolder_log_probs,
                                self.time_rank_unfolder_logits)

    # Define our interface.
    @rlgraph_api
    def get_nn_output(self, nn_input, internal_states=None):
        """
        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            any: The raw output of the neural network (before it's cleaned-up and passed through the ActionAdapter).
        """
        out = self.neural_network.apply(nn_input, internal_states)
        return dict(output=out["output"],
                    last_internal_states=out.get("last_internal_states"))

    @rlgraph_api
    def get_action(self, nn_input, internal_states=None, max_likelihood=None):
        """
        Returns an action based on NN output, action adapter output and distribution sampling.

        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.
            max_likelihood (Optional[bool]): If not None, use this to determine whether actions should be drawn
                from the distribution in max-likelihood or stochastic fashion.

        Returns:
            any: The drawn action.
        """
        max_likelihood = self.max_likelihood if max_likelihood is None else max_likelihood

        nn_output = self.get_nn_output(nn_input, internal_states)

        # Skip our distribution if we have a discrete action space and act greedily (max-likelihood).
        # In that case, there is no need to build a distribution in the graph on every act: the argmax over the
        # logits is the same as the argmax over the probabilities (or log-probabilities).
        if max_likelihood is True and isinstance(self.action_space, IntBox):
            out = self.action_adapter.get_logits_probabilities_log_probs(
                nn_output["output"])
            action = self._graph_fn_get_max_likelihood_action_wo_distribution(
                out["logits"])
        else:
            out = self.action_adapter.get_logits_probabilities_log_probs(
                nn_output["output"])
            action = self.distribution.draw(out["probabilities"],
                                            max_likelihood)
        return dict(action=action,
                    last_internal_states=nn_output["last_internal_states"])

    @rlgraph_api
    def get_max_likelihood_action(self, nn_input, internal_states=None):
        """
        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            any: See `get_action`, but with max_likelihood force set to True.
        """
        out = self.get_logits_probabilities_log_probs(nn_input,
                                                      internal_states)

        if isinstance(self.action_space, IntBox):
            action = self._graph_fn_get_max_likelihood_action_wo_distribution(
                out["logits"])
        else:
            action = self.distribution.sample_deterministic(
                out["probabilities"])

        return dict(action=action,
                    last_internal_states=out["last_internal_states"])

    @rlgraph_api
    def get_stochastic_action(self, nn_input, internal_states=None):
        """
        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            any: See `get_action`, but with max_likelihood force set to False.
        """
        out = self.get_logits_probabilities_log_probs(nn_input,
                                                      internal_states)
        action = self.distribution.sample_stochastic(out["probabilities"])
        return dict(action=action,
                    last_internal_states=out["last_internal_states"])

    @rlgraph_api
    def get_action_layer_output(self, nn_input, internal_states=None):
        """
        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            any: The raw output of the action layer of the ActionAdapter (including possibly the last internal states
                of a RNN-based NN).
        """
        nn_output = self.get_nn_output(nn_input, internal_states)
        action_layer_output = self.action_adapter.get_action_layer_output(
            nn_output["output"])
        # Add last internal states to return value.
        return dict(output=action_layer_output["output"],
                    last_internal_states=nn_output["last_internal_states"])

    @rlgraph_api
    def get_logits_probabilities_log_probs(self,
                                           nn_input,
                                           internal_states=None):
        """
        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            Dict:
                logits: The (reshaped) logits from the ActionAdapter.
                probabilities: The probabilities gained from the softmaxed logits.
                log_probs: The log(probabilities) values.
        """
        nn_output = self.get_nn_output(nn_input, internal_states)
        aa_output = self.action_adapter.get_logits_probabilities_log_probs(
            nn_output["output"])
        return dict(logits=aa_output["logits"],
                    probabilities=aa_output["probabilities"],
                    log_probs=aa_output["log_probs"],
                    last_internal_states=nn_output["last_internal_states"])

    @rlgraph_api
    def get_entropy(self, nn_input, internal_states=None):
        """
        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            any: See Distribution component.
        """
        out = self.get_logits_probabilities_log_probs(nn_input,
                                                      internal_states)
        entropy = self.distribution.entropy(out["probabilities"])
        return dict(entropy=entropy,
                    last_internal_states=out["last_internal_states"])

    @graph_fn
    def _graph_fn_get_max_likelihood_action_wo_distribution(self, logits):
        """
        Use this function only for discrete action spaces to circumvent using a full-blown
        backend-specific distribution object (e.g. tf.distributions.Multinomial).

        Args:
            logits (SingleDataOp): Logits over which to pick the argmax (greedy action).

        Returns:
            SingleDataOp: The argmax over the last rank of the input logits.
        """
        if get_backend() == "tf":
            return tf.argmax(logits, axis=-1, output_type=tf.int32)
        elif get_backend() == "pytorch":
            return torch.argmax(logits, dim=-1).int()
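
The greedy shortcut in `get_action` and `_graph_fn_get_max_likelihood_action_wo_distribution` rests on softmax being strictly monotonic: for a discrete (IntBox) action space, the argmax over the raw logits equals the argmax over the softmaxed probabilities, so no Categorical distribution has to be built in the graph. A small NumPy check of that equivalence (illustrative only):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

logits = np.random.randn(5, 4)  # batch of 5 states, 4 discrete actions

# Softmax preserves the ordering of the logits, so the greedy action is identical
# whether we take the argmax over logits or over probabilities.
assert np.array_equal(np.argmax(logits, axis=-1),
                      np.argmax(softmax(logits), axis=-1))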