Example #1
class DummyNNWithDictInput(NeuralNetwork):
    """
    Dummy NN with dict input: takes a dict with keys "a" and "b", passes each key's value through its own
    (parallel, not connected in any way) dense layer, and then concatenates the two outputs to yield the final output.
    """
    def __init__(self,
                 num_units_a=3,
                 num_units_b=2,
                 scope="dummy-nn-with-dict-input",
                 **kwargs):
        super(DummyNNWithDictInput, self).__init__(scope=scope, **kwargs)

        self.num_units_a = num_units_a
        self.num_units_b = num_units_b

        # Splits the input into two streams.
        self.splitter = ContainerSplitter("a", "b")
        self.stack_a = DenseLayer(units=self.num_units_a, scope="dense-a")
        self.stack_b = DenseLayer(units=self.num_units_b, scope="dense-b")
        self.concat_layer = ConcatLayer()

        # Add all sub-components to this one.
        self.add_components(self.splitter, self.stack_a, self.stack_b,
                            self.concat_layer)

    @rlgraph_api
    def apply(self, input_dict):
        # Split the input dict into two streams.
        input_a, input_b = self.splitter.split(input_dict)

        # Get the two stack outputs.
        output_a = self.stack_a.apply(input_a)
        output_b = self.stack_b.apply(input_b)

        # Concat everything together, that's the output.
        concatenated_data = self.concat_layer.apply(output_a, output_b)

        return dict(output=concatenated_data)
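
For intuition, here is a minimal NumPy sketch (not RLGraph code; the input dict, the weight shapes and their values are made up for illustration) of the data flow this component wires up: split the dict into its two streams, apply two independent dense layers, and concatenate the results.

import numpy as np

rng = np.random.default_rng(0)
inputs = {"a": rng.normal(size=(4, 5)), "b": rng.normal(size=(4, 7))}

# Hypothetical weights for the two parallel dense layers (num_units_a=3, num_units_b=2).
W_a, b_a = rng.normal(size=(5, 3)), np.zeros(3)
W_b, b_b = rng.normal(size=(7, 2)), np.zeros(2)

out_a = inputs["a"] @ W_a + b_a                    # "dense-a" stream
out_b = inputs["b"] @ W_b + b_b                    # "dense-b" stream
output = np.concatenate([out_a, out_b], axis=-1)   # concatenated final output, shape (4, 5)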
Example #2
class ActionAdapter(Component):
    """
    A Component that cleans up a neural network's flat output and gets it ready for parameterizing a
    Distribution Component.
    Processing steps include:
    - Sending the raw, flattened NN output through a Dense layer whose number of units matches the flattened
    action space.
    - Reshaping (according to the action Space).
    - Translating the reshaped outputs (logits) into probabilities (by softmaxing) and log-probabilities (log).
    """
    def __init__(self,
                 action_space,
                 add_units=0,
                 units=None,
                 weights_spec=None,
                 biases_spec=None,
                 activation=None,
                 scope="action-adapter",
                 **kwargs):
        """
        Args:
            action_space (Space): The action Space within which this Component will create actions.

            add_units (Optional[int]): An optional number of units to add to the auto-calculated number of action-
                layer nodes. Can be negative to subtract units from the auto-calculated value.
                NOTE: Only one of either `add_units` or `units` must be provided.

            units (Optional[int]): An optional number of units to use for the action-layer. If None, will calculate
                the number of units automatically from the given action_space.
                NOTE: Only one of either `add_units` or `units` must be provided.

            weights_spec (Optional[any]): An optional RLGraph Initializer spec that will be used to initialize the
                weights of `self.action_layer`. Default: None (use default initializer).

            biases_spec (Optional[any]): An optional RLGraph Initializer spec that will be used to initialize the
                biases of `self.action_layer`. Default: None (use default initializer, which is usually 0.0).

            activation (Optional[str]): The activation function to use for `self.action_layer`.
                Default: None (=linear).
        """
        super(ActionAdapter, self).__init__(scope=scope, **kwargs)

        self.action_space = action_space.with_batch_rank()
        self.weights_spec = weights_spec
        self.biases_spec = biases_spec
        self.activation = activation

        # Our (dense) action layer representing the flattened action space.
        self.action_layer = None

        # Calculate the number of nodes in the action layer (DenseLayer object) depending on our action Space
        # or using a given fixed number (`units`).
        # Also generate the ReShape sub-Component and give it the new_shape.
        if isinstance(self.action_space, IntBox):
            if units is None:
                units = add_units + self.action_space.flat_dim_with_categories
            self.reshape = ReShape(
                new_shape=self.action_space.get_shape(with_category_rank=True),
                flatten_categories=False)
        else:
            if units is None:
                units = add_units + 2 * self.action_space.flat_dim  # Those two dimensions are the mean and log sd
            # Manually add moments after batch/time ranks.
            new_shape = tuple([2] + list(self.action_space.shape))
            self.reshape = ReShape(new_shape=new_shape)

        assert units > 0, "ERROR: Number of nodes for action-layer calculated as {}! Must be larger than 0.".format(
            units)

        # Create the action-layer and add it to this component.
        self.action_layer = DenseLayer(units=units,
                                       activation=self.activation,
                                       weights_spec=self.weights_spec,
                                       biases_spec=self.biases_spec,
                                       scope="action-layer")

        self.add_components(self.action_layer, self.reshape)

    def check_input_spaces(self, input_spaces, action_space=None):
        # Check the input Space.
        last_nn_layer_space = input_spaces["nn_output"]  # type: Space
        sanity_check_space(last_nn_layer_space,
                           non_allowed_types=[ContainerSpace])

        # Check the action Space.
        sanity_check_space(self.action_space, must_have_batch_rank=True)
        if isinstance(self.action_space, IntBox):
            sanity_check_space(self.action_space, must_have_categories=True)
        else:
            # Fixme: Are there other restraints on continuous action spaces? E.g. no dueling layers?
            pass

    @rlgraph_api
    def get_action_layer_output(self, nn_output):
        """
        Returns the raw, non-reshaped output of the action-layer (DenseLayer) after passing the raw
        nn_output (coming from the previous Component) through it.

        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            DataOpRecord: The output of the action layer (a DenseLayer) after passing `nn_output` through it.
        """
        out = self.action_layer.apply(nn_output)
        return dict(output=out)

    @rlgraph_api
    def get_logits(self, nn_output):
        """
        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            SingleDataOp: The logits (raw nn_output, BUT reshaped).
        """
        aa_output = self.get_action_layer_output(nn_output)
        logits = self.reshape.apply(aa_output["output"])
        return logits

    @rlgraph_api
    def get_logits_probabilities_log_probs(self, nn_output):
        """
        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            Tuple[SingleDataOp]:
                - logits (raw nn_output, BUT reshaped)
                - probabilities (softmaxed(logits))
                - log(probabilities)
        """
        logits = self.get_logits(nn_output)
        probabilities, log_probs = self._graph_fn_get_probabilities_log_probs(
            logits)
        return dict(logits=logits,
                    probabilities=probabilities,
                    log_probs=log_probs)

    # TODO: Use a SoftMax Component instead (uses the same code as the one below).
    @graph_fn
    def _graph_fn_get_probabilities_log_probs(self, logits):
        """
        Creates properties/parameters and log-probs from some reshaped output.

        Args:
            logits (SingleDataOp): The output of some layer that is already reshaped
                according to our action Space.

        Returns:
            tuple (2x SingleDataOp):
                parameters (DataOp): The parameters, ready to be passed to a Distribution object's
                    get_distribution API-method (usually some probabilities or loc/scale pairs).
                log_probs (DataOp): Simply the log(parameters).
        """
        if get_backend() == "tf":
            if isinstance(self.action_space, IntBox):
                # Discrete actions.
                parameters = tf.maximum(x=tf.nn.softmax(logits=logits,
                                                        axis=-1),
                                        y=SMALL_NUMBER)
                # Log probs.
                log_probs = tf.log(x=parameters)
            elif isinstance(self.action_space, FloatBox):
                # Continuous actions.
                mean, log_sd = tf.split(value=logits,
                                        num_or_size_splits=2,
                                        axis=1)
                # Remove moments rank.
                mean = tf.squeeze(input=mean, axis=1)
                log_sd = tf.squeeze(input=log_sd, axis=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = tf.clip_by_value(t=log_sd,
                                          clip_value_min=log(SMALL_NUMBER),
                                          clip_value_max=-log(SMALL_NUMBER))

                # Turn log sd into sd.
                sd = tf.exp(x=log_sd)

                parameters = DataOpTuple(mean, sd)
                log_probs = DataOpTuple(tf.log(x=mean), log_sd)
            else:
                raise NotImplementedError

            return parameters, log_probs

        elif get_backend() == "pytorch":
            if isinstance(self.action_space, IntBox):
                # Discrete actions.
                softmax_logits = torch.softmax(logits, dim=-1)
                parameters = torch.max(softmax_logits, SMALL_NUMBER_TORCH)
                # Log probs.
                log_probs = torch.log(parameters)
            elif isinstance(self.action_space, FloatBox):
                # Continuous actions.
                mean, log_sd = torch.split(logits,
                                           split_size_or_sections=2,
                                           dim=1)
                # Remove moments rank.
                mean = torch.squeeze(mean, dim=1)
                log_sd = torch.squeeze(log_sd, dim=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = torch.clamp(log_sd,
                                     min=LOG_SMALL_NUMBER,
                                     max=-LOG_SMALL_NUMBER)

                # Turn log sd into sd.
                sd = torch.exp(log_sd)

                parameters = DataOpTuple(mean, sd)
                log_probs = DataOpTuple(torch.log(mean), log_sd)
            else:
                raise NotImplementedError

            return parameters, log_probs
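
As a rough, framework-free illustration, the discrete (IntBox) branch of `_graph_fn_get_probabilities_log_probs` boils down to the NumPy computation below (the concrete value of SMALL_NUMBER is an assumption here):

import numpy as np

SMALL_NUMBER = 1e-6  # assumed stand-in for RLGraph's SMALL_NUMBER constant

def probabilities_and_log_probs(logits):
    # Softmax over the last (category) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    # Clamp at SMALL_NUMBER so the subsequent log never sees an exact 0.
    probs = np.maximum(probs, SMALL_NUMBER)
    return probs, np.log(probs)

probs, log_probs = probabilities_and_log_probs(np.array([[2.0, 1.0, -1.0]]))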
Example #3
class DuelingActionAdapter(ActionAdapter):
    """
    An ActionAdapter that adds a dueling Q calculation to the flattened output of a neural network.

    API:
        get_action_layer_output(nn_output): The state-value node output and the (still flat, not yet reshaped)
            advantage-layer output.
        get_logits_probabilities_log_probs(nn_output): The state values, the (reshaped) q-values, the softmaxed
            probabilities and the log-probs.
    """
    def __init__(self,
                 units_state_value_stream,
                 units_advantage_stream,
                 weights_spec_state_value_stream=None,
                 biases_spec_state_value_stream=None,
                 activation_state_value_stream="relu",
                 weights_spec_advantage_stream=None,
                 biases_spec_advantage_stream=None,
                 activation_advantage_stream="relu",
                 scope="dueling-action-adapter",
                 **kwargs):
        # TODO: change add_units=-1 once we have a true base class for action-adapters.
        super(DuelingActionAdapter, self).__init__(add_units=0,
                                                   scope=scope,
                                                   **kwargs)

        # The state-value stream.
        self.units_state_value_stream = units_state_value_stream
        self.weights_spec_state_value_stream = weights_spec_state_value_stream
        self.biases_spec_state_value_stream = biases_spec_state_value_stream
        self.activation_state_value_stream = activation_state_value_stream

        # The advantage stream.
        self.units_advantage_stream = units_advantage_stream
        self.weights_spec_advantage_stream = weights_spec_advantage_stream
        self.biases_spec_advantage_stream = biases_spec_advantage_stream
        self.activation_advantage_stream = activation_advantage_stream

        # Create the three extra DenseLayers (the inherited self.action_layer acts as the advantage output layer).
        self.dense_layer_state_value_stream = DenseLayer(
            units=self.units_state_value_stream,
            weights_spec=self.weights_spec_state_value_stream,
            biases_spec=self.biases_spec_state_value_stream,
            activation=self.activation_state_value_stream,
            scope="dense-layer-state-value-stream")
        self.dense_layer_advantage_stream = DenseLayer(
            units=self.units_advantage_stream,
            weights_spec=self.weights_spec_advantage_stream,
            biases_spec=self.biases_spec_advantage_stream,
            activation=self.activation_advantage_stream,
            scope="dense-layer-advantage-stream")
        self.state_value_node = DenseLayer(units=1,
                                           activation="linear",
                                           scope="state-value-node")
        # self.action_layer is our advantage layer

        self.add_components(self.dense_layer_state_value_stream,
                            self.dense_layer_advantage_stream,
                            self.state_value_node)

    @rlgraph_api
    def get_action_layer_output(self, nn_output):
        """
        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            tuple:
                DataOpRecord: The output of the state-value stream (a DenseLayer) after passing `nn_output` through it.

                DataOpRecord: The output of the advantage-value stream (a DenseLayer) after passing `nn_output` through
                    it. Note: These will be flat advantage nodes that have not been reshaped yet according to the
                    action_space.
        """
        output_state_value_dense = self.dense_layer_state_value_stream.apply(
            nn_output)
        output_advantage_dense = self.dense_layer_advantage_stream.apply(
            nn_output)
        state_value_node = self.state_value_node.apply(
            output_state_value_dense)
        advantage_nodes = self.action_layer.apply(output_advantage_dense)
        return dict(state_value_node=state_value_node, output=advantage_nodes)

    @rlgraph_api
    def get_logits_probabilities_log_probs(self, nn_output):
        """
        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            tuple (4x DataOpRecord):
                - The single state value node output.
                - The (already reshaped) q-values (the logits).
                - The probabilities obtained by softmaxing the q-values.
                - The log-probs.
        """
        out = self.get_action_layer_output(nn_output)
        advantage_values_reshaped = self.reshape.apply(out["output"])
        q_values = self._graph_fn_calculate_q_values(
            out["state_value_node"], advantage_values_reshaped)
        probabilities, log_probs = self._graph_fn_get_probabilities_log_probs(
            q_values)
        return dict(state_values=out["state_value_node"],
                    logits=q_values,
                    probabilities=probabilities,
                    log_probs=log_probs)

    @graph_fn
    def _graph_fn_calculate_q_values(self, state_value, advantage_values):
        """
        Args:
            state_value (SingleDataOp): The single node state-value output.
            advantage_values (SingleDataOp): The already reshaped advantage-values.

        Returns:
            SingleDataOp: The calculated, reshaped Q values (for each composite action) based on:
                Q = V + [A - mean(A)]
        """
        # Use the very first node as value function output.
        # Use all following nodes as advantage function output.
        if get_backend() == "tf":
            # Calculate the q-values according to [1] and return.
            mean_advantages = tf.reduce_mean(input_tensor=advantage_values,
                                             axis=-1,
                                             keepdims=True)

            # Make sure we broadcast the state_value correctly for the upcoming q_value calculation.
            state_value_expanded = state_value
            for _ in range(get_rank(advantage_values) - 2):
                state_value_expanded = tf.expand_dims(state_value_expanded,
                                                      axis=1)
            q_values = state_value_expanded + advantage_values - mean_advantages

            # q-values
            return q_values
        elif get_backend() == "pytorch":
            mean_advantages = torch.mean(advantage_values,
                                         dim=-1,
                                         keepdim=True)

            # Make sure we broadcast the state_value correctly for the upcoming q_value calculation.
            state_value_expanded = state_value
            for _ in range(get_rank(advantage_values) - 2):
                state_value_expanded = torch.unsqueeze(state_value_expanded,
                                                       dim=1)
            q_values = state_value_expanded + advantage_values - mean_advantages

            # q-values
            return q_values

    # TODO: Use a SoftMax Component instead (uses the same code as the one below).
    @graph_fn
    def _graph_fn_get_probabilities_log_probs(self, logits):
        """
        Creates properties/parameters and log-probs from some reshaped output.

        Args:
            logits (SingleDataOp): The output of some layer that is already reshaped
                according to our action Space.

        Returns:
            tuple (2x SingleDataOp):
                parameters (DataOp): The parameters, ready to be passed to a Distribution object's
                    get_distribution API-method (usually some probabilities or loc/scale pairs).

                log_probs (DataOp): Simply the log(parameters).
        """
        if get_backend() == "tf":
            if isinstance(self.action_space, IntBox):
                # Discrete actions.
                parameters = tf.maximum(x=tf.nn.softmax(logits=logits,
                                                        axis=-1),
                                        y=SMALL_NUMBER)
                # Log probs.
                log_probs = tf.log(x=parameters)
            elif isinstance(self.action_space, FloatBox):
                # Continuous actions.
                mean, log_sd = tf.split(value=logits,
                                        num_or_size_splits=2,
                                        axis=1)
                # Remove moments rank.
                mean = tf.squeeze(input=mean, axis=1)
                log_sd = tf.squeeze(input=log_sd, axis=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = tf.clip_by_value(
                    t=log_sd,
                    clip_value_min=math.log(SMALL_NUMBER),
                    clip_value_max=-math.log(SMALL_NUMBER))

                # Turn log sd into sd.
                sd = tf.exp(x=log_sd)

                parameters = DataOpTuple(mean, sd)
                log_probs = DataOpTuple(tf.log(x=mean), log_sd)
            else:
                raise NotImplementedError
            return parameters, log_probs

        elif get_backend() == "pytorch":
            if isinstance(self.action_space, IntBox):
                # Discrete actions.
                parameters = torch.max(torch.softmax(logits, dim=-1),
                                       torch.tensor(SMALL_NUMBER))
                # Log probs.
                log_probs = torch.log(parameters)
            elif isinstance(self.action_space, FloatBox):
                # Continuous actions.
                mean, log_sd = torch.split(logits,
                                           split_size_or_sections=2,
                                           dim=1)
                # Remove moments rank.
                mean = torch.squeeze(mean, dim=1)
                log_sd = torch.squeeze(log_sd, dim=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = torch.clamp(log_sd,
                                     min=math.log(SMALL_NUMBER),
                                     max=-math.log(SMALL_NUMBER))

                # Turn log sd into sd.
                sd = torch.exp(log_sd)

                parameters = DataOpTuple(mean, sd)
                log_probs = DataOpTuple(torch.log(mean), log_sd)
            else:
                raise NotImplementedError

            return parameters, log_probs
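
Ignoring RLGraph's op bookkeeping, the dueling aggregation in `_graph_fn_calculate_q_values` amounts to the small NumPy sketch below (a batch of 4 states and 3 discrete actions; the numbers are made up):

import numpy as np

state_value = np.array([[0.5], [1.0], [-0.2], [2.0]])       # V(s), shape (4, 1)
advantages = np.random.default_rng(1).normal(size=(4, 3))   # A(s, a), shape (4, 3)

# Q = V + (A - mean_a A): advantages are centered per state before the state value is added.
q_values = state_value + advantages - advantages.mean(axis=-1, keepdims=True)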
Example #4
class DuelingPolicy(Policy):
    def __init__(self,
                 network_spec,
                 units_state_value_stream,
                 weights_spec_state_value_stream=None,
                 biases_spec_state_value_stream=None,
                 activation_state_value_stream="relu",
                 scope="dueling-policy",
                 **kwargs):
        super(DuelingPolicy, self).__init__(network_spec,
                                            scope=scope,
                                            **kwargs)

        self.action_space_flattened = self.action_space.flatten()

        # The state-value stream.
        self.units_state_value_stream = units_state_value_stream
        self.weights_spec_state_value_stream = weights_spec_state_value_stream
        self.biases_spec_state_value_stream = biases_spec_state_value_stream
        self.activation_state_value_stream = activation_state_value_stream

        # Our softmax component to produce probabilities.
        self.softmax = Softmax()

        # Create all state value extra Layers.
        # TODO: Make this a NN-spec as well (right now it's one layer fixed plus the final value node).
        self.dense_layer_state_value_stream = DenseLayer(
            units=self.units_state_value_stream,
            weights_spec=self.weights_spec_state_value_stream,
            biases_spec=self.biases_spec_state_value_stream,
            activation=self.activation_state_value_stream,
            scope="dense-layer-state-value-stream")
        self.state_value_node = DenseLayer(units=1,
                                           activation="linear",
                                           scope="state-value-node")

        self.add_components(self.dense_layer_state_value_stream,
                            self.state_value_node)

    @rlgraph_api
    def get_state_values(self, nn_input, internal_states=None):
        """
        Returns the state-value node's output after passing some NN input through the policy's network and the
        state-value stream.

        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            Dict:
                state_values: The single (but batched) value function node output.
        """
        nn_output = self.get_nn_output(nn_input, internal_states)
        state_values_tmp = self.dense_layer_state_value_stream.apply(
            nn_output["output"])
        state_values = self.state_value_node.apply(state_values_tmp)

        return dict(state_values=state_values,
                    last_internal_states=nn_output.get("last_internal_states"))

    @rlgraph_api
    def get_state_values_logits_parameters_log_probs(self,
                                                     nn_input,
                                                     internal_states=None):
        """
        Similar to `get_values_logits_probabilities_log_probs`, but additionally returns the output of our
        state-value function node under the key `state_values`.

        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            Dict:
                state_values: The single (but batched) value function node output.
                logits: The (reshaped) logits from the ActionAdapter.
                parameters: The parameters for the distribution (gained from the softmaxed logits or interpreting
                    logits as mean and stddev for a normal distribution).
                log_probs: The log(probabilities) values.
                last_internal_states: The last internal states (if network is RNN-based).
        """
        nn_output = self.get_nn_output(nn_input, internal_states)
        advantages, _, _ = self._graph_fn_get_action_adapter_logits_parameters_log_probs(
            nn_output["output"], nn_input)
        state_values_tmp = self.dense_layer_state_value_stream.apply(
            nn_output["output"])
        state_values = self.state_value_node.apply(state_values_tmp)

        q_values = self._graph_fn_calculate_q_values(state_values, advantages)

        parameters, log_probs = self._graph_fn_get_parameters_log_probs(
            q_values)

        return dict(state_values=state_values,
                    logits=q_values,
                    parameters=parameters,
                    log_probs=log_probs,
                    last_internal_states=nn_output.get("last_internal_states"),
                    advantages=advantages,
                    q_values=q_values)

    @rlgraph_api
    def get_state_values_logits_probabilities_log_probs(
            self, nn_input, internal_states=None):
        """
        Similar to `get_values_logits_probabilities_log_probs`, but additionally returns the output of our
        state-value function node under the key `state_values`.

        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            Dict:
                state_values: The single (but batched) value function node output.
                logits: The (reshaped) logits from the ActionAdapter.
                probabilities: The probabilities gained from the softmaxed logits.
                log_probs: The log(probabilities) values.
                last_internal_states: The last internal states (if network is RNN-based).
        """
        self.logger.warn(
            "Deprecated API method `get_state_values_logits_probabilities_log_probs` used! "
            "Use `get_state_values_logits_parameters_log_probs` instead.")

        nn_output = self.get_nn_output(nn_input, internal_states)
        advantages, _, _ = self._graph_fn_get_action_adapter_logits_parameters_log_probs(
            nn_output["output"], nn_input)
        state_values_tmp = self.dense_layer_state_value_stream.apply(
            nn_output["output"])
        state_values = self.state_value_node.apply(state_values_tmp)

        q_values = self._graph_fn_calculate_q_values(state_values, advantages)

        parameters, log_probs = self._graph_fn_get_parameters_log_probs(
            q_values)

        return dict(state_values=state_values,
                    logits=q_values,
                    probabilities=parameters,
                    parameters=parameters,
                    log_probs=log_probs,
                    last_internal_states=nn_output.get("last_internal_states"),
                    advantages=advantages,
                    q_values=q_values)

    @rlgraph_api
    def get_logits_parameters_log_probs(self, nn_input, internal_states=None):
        """
        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            Dict:
                logits: The q-values after adding advantages to state values (and subtracting the mean advantage).
                parameters: The parameters for the distribution (gained from the softmaxed logits or interpreting
                    logits as mean and stddev for a normal distribution).
                log_probs: The log(probabilities) values.
                last_internal_states: The final internal states after passing through a possible RNN.
        """
        out = self.get_state_values_logits_parameters_log_probs(
            nn_input, internal_states)
        return dict(logits=out["logits"],
                    parameters=out["parameters"],
                    log_probs=out["log_probs"],
                    last_internal_states=out.get("last_internal_states"))

    @rlgraph_api
    def get_logits_probabilities_log_probs(self,
                                           nn_input,
                                           internal_states=None):
        """
        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            Dict:
                logits: The q-values after adding advantages to state values (and subtracting the mean advantage).
                probabilities: The probabilities gained from the softmaxed logits.
                log_probs: The log(probabilities) values.
                last_internal_states: The final internal states after passing through a possible RNN.
        """
        self.logger.warn(
            "Deprecated API method `get_logits_probabilities_log_probs` used! "
            "Use `get_logits_parameters_log_probs` instead.")
        out = self.get_state_values_logits_parameters_log_probs(
            nn_input, internal_states)
        return dict(logits=out["logits"],
                    probabilities=out["parameters"],
                    parameters=out["parameters"],
                    log_probs=out["log_probs"],
                    last_internal_states=out.get("last_internal_states"))

    @graph_fn(flatten_ops=True, split_ops=True)
    def _graph_fn_calculate_q_values(self, state_value, advantage_values):
        """
        Args:
            state_value (SingleDataOp): The single node state-value output.
            advantage_values (SingleDataOp): The already reshaped advantage-values.

        Returns:
            SingleDataOp: The calculated, reshaped Q values (for each composite action) based on:
                Q = V + [A - mean(A)]
        """
        # Use the very first node as value function output.
        # Use all following nodes as advantage function output.
        if get_backend() == "tf":
            # Calculate the q-values according to [1] and return.
            mean_advantages = tf.reduce_mean(input_tensor=advantage_values,
                                             axis=-1,
                                             keepdims=True)

            # Make sure we broadcast the state_value correctly for the upcoming q_value calculation.
            state_value_expanded = state_value
            for _ in range(get_rank(advantage_values) - 2):
                state_value_expanded = tf.expand_dims(state_value_expanded,
                                                      axis=1)
            q_values = state_value_expanded + advantage_values - mean_advantages

            # q-values
            return q_values

        elif get_backend() == "pytorch":
            mean_advantages = torch.mean(advantage_values,
                                         dim=-1,
                                         keepdim=True)

            # Make sure we broadcast the state_value correctly for the upcoming q_value calculation.
            state_value_expanded = state_value
            for _ in range(get_rank(advantage_values) - 2):
                state_value_expanded = torch.unsqueeze(state_value_expanded,
                                                       dim=1)
            q_values = state_value_expanded + advantage_values - mean_advantages

            # q-values
            return q_values

    @graph_fn(flatten_ops=True,
              split_ops=True,
              add_auto_key_as_first_param=True)
    def _graph_fn_get_parameters_log_probs(self, key, logits):
        """
        Creates parameters and log-probs from some reshaped output.

        Args:
            logits (SingleDataOp): The output of some layer that is already reshaped
                according to our action Space.

        Returns:
            tuple (2x SingleDataOp):
                parameters (DataOp): The parameters, ready to be passed to a Distribution object's
                    get_distribution API-method (usually some probabilities or loc/scale pairs).

                log_probs (DataOp): Simply the log(parameters).
        """

        if get_backend() == "tf":
            if isinstance(self.action_space_flattened[key], IntBox):
                # Discrete actions.
                parameters = tf.maximum(x=tf.nn.softmax(logits=logits,
                                                        axis=-1),
                                        y=SMALL_NUMBER)
                # Log probs.
                log_probs = tf.log(x=parameters)
            elif isinstance(self.action_space_flattened[key], FloatBox):
                # Continuous actions.
                mean, log_sd = tf.split(value=logits,
                                        num_or_size_splits=2,
                                        axis=1)
                # Remove moments rank.
                mean = tf.squeeze(input=mean, axis=1)
                log_sd = tf.squeeze(input=log_sd, axis=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = tf.clip_by_value(
                    t=log_sd,
                    clip_value_min=math.log(SMALL_NUMBER),
                    clip_value_max=-math.log(SMALL_NUMBER))

                # Turn log sd into sd.
                sd = tf.exp(x=log_sd)

                parameters = DataOpTuple(mean, sd)
                log_probs = DataOpTuple(tf.log(x=mean), log_sd)
            else:
                raise NotImplementedError
            return parameters, log_probs

        elif get_backend() == "pytorch":
            if isinstance(self.action_space_flattened[key], IntBox):
                # Discrete actions.
                parameters = torch.max(torch.softmax(logits, dim=-1),
                                       torch.tensor(SMALL_NUMBER))
                # Log probs.
                log_probs = torch.log(parameters)
            elif isinstance(self.action_space_flattened[key], FloatBox):
                # Continuous actions.
                mean, log_sd = torch.split(logits,
                                           split_size_or_sections=2,
                                           dim=1)
                # Remove moments rank.
                mean = torch.squeeze(mean, dim=1)
                log_sd = torch.squeeze(log_sd, dim=1)

                # Clip log_sd. log(SMALL_NUMBER) is negative.
                log_sd = torch.clamp(log_sd,
                                     min=math.log(SMALL_NUMBER),
                                     max=-math.log(SMALL_NUMBER))

                # Turn log sd into sd.
                sd = torch.exp(log_sd)

                parameters = DataOpTuple(mean, sd)
                log_probs = DataOpTuple(torch.log(mean), log_sd)
            else:
                raise NotImplementedError

            return parameters, log_probs
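
For the continuous (FloatBox) branch, the parameter extraction in `_graph_fn_get_parameters_log_probs` corresponds roughly to the NumPy steps below (the moments-rank handling is simplified and the value of SMALL_NUMBER is assumed):

import numpy as np

SMALL_NUMBER = 1e-6  # assumed stand-in for RLGraph's SMALL_NUMBER constant
logits = np.random.default_rng(2).normal(size=(4, 2, 3))    # (batch, 2 moments, action dims)

mean, log_sd = logits[:, 0, :], logits[:, 1, :]             # split off and squeeze the moments rank
log_sd = np.clip(log_sd, np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER))
sd = np.exp(log_sd)                                         # (mean, sd) then parameterizes a Normal distribution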
Example #5
class DuelingPolicy(Policy):
    def __init__(self,
                 network_spec,
                 units_state_value_stream,
                 weights_spec_state_value_stream=None,
                 biases_spec_state_value_stream=None,
                 activation_state_value_stream="relu",
                 scope="dueling-policy",
                 **kwargs):
        super(DuelingPolicy, self).__init__(network_spec,
                                            scope=scope,
                                            **kwargs)

        self.action_space_flattened = self.action_space.flatten()

        # The state-value stream.
        self.units_state_value_stream = units_state_value_stream
        self.weights_spec_state_value_stream = weights_spec_state_value_stream
        self.biases_spec_state_value_stream = biases_spec_state_value_stream
        self.activation_state_value_stream = activation_state_value_stream

        # Our softmax component to produce probabilities.
        self.softmax = Softmax()

        # Create all state value extra Layers.
        # TODO: Make this a NN-spec as well (right now it's one layer fixed plus the final value node).
        self.dense_layer_state_value_stream = DenseLayer(
            units=self.units_state_value_stream,
            weights_spec=self.weights_spec_state_value_stream,
            biases_spec=self.biases_spec_state_value_stream,
            activation=self.activation_state_value_stream,
            scope="dense-layer-state-value-stream")
        self.state_value_node = DenseLayer(units=1,
                                           activation="linear",
                                           scope="state-value-node")

        self.add_components(self.dense_layer_state_value_stream,
                            self.state_value_node)

    @rlgraph_api
    def get_state_values(self, nn_input, internal_states=None):
        """
        Returns the state-value node's output after passing some NN input through the policy's network and the
        state-value stream.

        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            Dict:
                state_values: The single (but batched) value function node output.
        """
        nn_output = self.get_nn_output(nn_input, internal_states)
        state_values_tmp = self.dense_layer_state_value_stream.apply(
            nn_output["output"])
        state_values = self.state_value_node.apply(state_values_tmp)

        return dict(state_values=state_values,
                    last_internal_states=nn_output.get("last_internal_states"))

    @rlgraph_api
    def get_state_values_logits_parameters_log_probs(self,
                                                     nn_input,
                                                     internal_states=None):
        """
        Similar to `get_values_logits_probabilities_log_probs`, but additionally returns the output of our
        state-value function node under the key `state_values`.

        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            Dict:
                state_values: The single (but batched) value function node output.
                logits: The (reshaped) logits from the ActionAdapter.
                parameters: The parameters for the distribution (gained from the softmaxed logits or interpreting
                    logits as mean and stddev for a normal distribution).
                log_probs: The log(probabilities) values.
                last_internal_states: The last internal states (if network is RNN-based).
        """
        nn_output = self.get_nn_output(nn_input, internal_states)
        advantages, _, _ = self._graph_fn_get_action_adapter_logits_parameters_log_probs(
            nn_output["output"], nn_input)
        state_values_tmp = self.dense_layer_state_value_stream.apply(
            nn_output["output"])
        state_values = self.state_value_node.apply(state_values_tmp)

        q_values = self._graph_fn_calculate_q_values(state_values, advantages)

        parameters, log_probs = self._graph_fn_get_action_adapter_parameters_log_probs(
            q_values)

        return dict(state_values=state_values,
                    logits=q_values,
                    parameters=parameters,
                    log_probs=log_probs,
                    last_internal_states=nn_output.get("last_internal_states"),
                    advantages=advantages,
                    q_values=q_values)

    def get_state_values_logits_probabilities_log_probs(
            self, nn_input, internal_states=None):
        raise RLGraphObsoletedError(
            "API method", "get_state_values_logits_probabilities_log_probs",
            "get_state_values_logits_parameters_log_probs")

    @rlgraph_api
    def get_logits_parameters_log_probs(self, nn_input, internal_states=None):
        """
        Args:
            nn_input (any): The input to our neural network.
            internal_states (Optional[any]): The initial internal states going into an RNN-based neural network.

        Returns:
            Dict:
                logits: The q-values after adding advantages to state values (and subtracting the mean advantage).
                parameters: The parameters for the distribution (gained from the softmaxed logits or interpreting
                    logits as mean and stddev for a normal distribution).
                log_probs: The log(probabilities) values.
                last_internal_states: The final internal states after passing through a possible RNN.
        """
        out = self.get_state_values_logits_parameters_log_probs(
            nn_input, internal_states)
        return dict(logits=out["logits"],
                    parameters=out["parameters"],
                    log_probs=out["log_probs"],
                    last_internal_states=out.get("last_internal_states"))

    def get_logits_probabilities_log_probs(self,
                                           nn_input,
                                           internal_states=None):
        raise RLGraphObsoletedError("API method",
                                    "get_logits_probabilities_log_probs",
                                    "get_logits_parameters_log_probs")

    @graph_fn(flatten_ops=True, split_ops=True)
    def _graph_fn_calculate_q_values(self, state_value, advantage_values):
        """
        Args:
            state_value (SingleDataOp): The single node state-value output.
            advantage_values (SingleDataOp): The already reshaped advantage-values.

        Returns:
            SingleDataOp: The calculated, reshaped Q values (for each composite action) based on:
                Q = V + [A - mean(A)]
        """
        # Use the very first node as value function output.
        # Use all following nodes as advantage function output.
        if get_backend() == "tf":
            # Calculate the q-values according to [1] and return.
            mean_advantages = tf.reduce_mean(input_tensor=advantage_values,
                                             axis=-1,
                                             keepdims=True)

            # Make sure we broadcast the state_value correctly for the upcoming q_value calculation.
            state_value_expanded = state_value
            for _ in range(get_rank(advantage_values) - 2):
                state_value_expanded = tf.expand_dims(state_value_expanded,
                                                      axis=1)
            q_values = state_value_expanded + advantage_values - mean_advantages

            # q-values
            return q_values

        elif get_backend() == "pytorch":
            mean_advantages = torch.mean(advantage_values,
                                         dim=-1,
                                         keepdim=True)

            # Make sure we broadcast the state_value correctly for the upcoming q_value calculation.
            state_value_expanded = state_value
            for _ in range(get_rank(advantage_values) - 2):
                state_value_expanded = torch.unsqueeze(state_value_expanded,
                                                       dim=1)
            q_values = state_value_expanded + advantage_values - mean_advantages

            # q-values
            return q_values

    @graph_fn(flatten_ops=True,
              split_ops=True,
              add_auto_key_as_first_param=True)
    def _graph_fn_get_action_adapter_parameters_log_probs(self, key, q_values):
        """
        """
        out = self.action_adapters[key].get_parameters_log_probs(q_values)
        return out["parameters"], out["log_probs"]
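
Conceptually, the per-key dispatch that `flatten_ops=True, split_ops=True, add_auto_key_as_first_param=True` performs for `_graph_fn_get_action_adapter_parameters_log_probs` behaves like the plain-Python loop below (the adapter objects and keys are hypothetical stand-ins, not the real RLGraph components):

class _FakeAdapter:
    def get_parameters_log_probs(self, q_values):
        # A real adapter would softmax (discrete) or split into mean/sd (continuous);
        # here we simply echo the input to illustrate the call pattern.
        return {"parameters": q_values, "log_probs": q_values}

action_adapters = {"/discrete": _FakeAdapter(), "/continuous": _FakeAdapter()}

def parameters_log_probs_per_key(flat_q_values):
    # flat_q_values maps each flat action-space key to its (already reshaped) q-values.
    out = {}
    for key, q in flat_q_values.items():
        result = action_adapters[key].get_parameters_log_probs(q)
        out[key] = (result["parameters"], result["log_probs"])
    return out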