Example #1
    def test_one_hot(self):
        """
        Tests the PyTorch one-hot helper function.
        """
        if get_backend() == "pytorch":
            # Flat action array.
            inputs = torch.tensor([0, 1], dtype=torch.int32)
            one_hot = pytorch_one_hot(inputs, depth=2)

            expected = torch.tensor([[1., 0.], [0., 1.]])
            recursive_assert_almost_equal(one_hot, expected)

            # Container space.
            inputs = torch.tensor([[0, 3, 2], [1, 2, 0]], dtype=torch.int32)
            one_hot = pytorch_one_hot(inputs, depth=4)

            expected = torch.tensor(
                [[[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
                 [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]],
                dtype=torch.int32
            )
            recursive_assert_almost_equal(one_hot, expected)
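
For reference, here is a minimal standalone sketch of a one-hot helper with the same call signature (`inputs`, `depth`); rlgraph's actual `pytorch_one_hot` may be implemented differently and may return a different dtype:

import torch

def one_hot_sketch(tensor, depth):
    """Hypothetical helper: one-hot encode integer indices along a new last axis."""
    # torch.nn.functional.one_hot expects int64 indices and appends the category axis.
    return torch.nn.functional.one_hot(tensor.long(), num_classes=depth)

# Mirrors the first assertion above: indices [0, 1] with depth 2.
print(one_hot_sketch(torch.tensor([0, 1], dtype=torch.int32), depth=2))
# tensor([[1, 0],
#         [0, 1]])
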
Example #2
    def _graph_fn_apply(self, key, preprocessing_inputs, input_before_time_rank_folding=None):
        """
        Reshapes the input to the specified new shape.

        Args:
            preprocessing_inputs (SingleDataOp): The input to reshape.
            input_before_time_rank_folding (Optional[SingleDataOp]): The original input from before the time rank
                was folded (the folding was done in a different ReShape component). Used when `self.unfold_time_rank`
                is True to figure out the exact time-rank dimension to unfold.

        Returns:
            SingleDataOp: The reshaped input.
        """
        assert self.unfold_time_rank is False or input_before_time_rank_folding is not None

        #preprocessing_inputs = tf.Print(preprocessing_inputs, [tf.shape(preprocessing_inputs)], summarize=1000,
        #                                message="input shape for {} (key={}): {}".format(preprocessing_inputs.name, key, self.scope))

        if self.backend == "python" or get_backend() == "python":
            # Create a one-hot axis for the categories at the end?
            if self.num_categories.get(key, 0) > 1:
                preprocessing_inputs = one_hot(preprocessing_inputs, depth=self.num_categories[key])

            new_shape = self.output_spaces[key].get_shape(
                with_batch_rank=-1, with_time_rank=-1, time_major=self.time_major
            )
            # Dynamic new shape inference:
            # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
            # batch+time 0th rank, get these two dynamically.
            # Note: We may still flip the two, if input space has a different `time_major` than output space.
            if len(new_shape) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
                # Time rank unfolding. Get the time rank from original input.
                if self.unfold_time_rank is True:
                    original_shape = input_before_time_rank_folding.shape
                    new_shape = (original_shape[0], original_shape[1]) + new_shape[2:]
                # No time-rank unfolding, but we do have both batch- and time-rank.
                else:
                    input_shape = preprocessing_inputs.shape
                    # Batch and time rank stay as is.
                    if self.time_major is None or self.time_major is self.in_space_time_majors[key]:
                        new_shape = (input_shape[0], input_shape[1]) + new_shape[2:]
                    # Batch and time rank need to be flipped around: Do a transpose.
                    else:
                        preprocessing_inputs = np.transpose(preprocessing_inputs, axes=(1, 0) + input_shape[2:])
                        new_shape = (input_shape[1], input_shape[0]) + new_shape[2:]

            return np.reshape(preprocessing_inputs, newshape=new_shape)
        elif get_backend() == "pytorch":
            # Create a one-hot axis for the categories at the end?
            if self.num_categories.get(key, 0) > 1:
                preprocessing_inputs = pytorch_one_hot(preprocessing_inputs, depth=self.num_categories[key])
            new_shape = self.output_spaces[key].get_shape(
                with_batch_rank=-1, with_time_rank=-1, time_major=self.time_major
            )
            # Dynamic new shape inference:
            # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
            # batch+time 0th rank, get these two dynamically.
            # Note: We may still flip the two, if input space has a different `time_major` than output space.
            if len(new_shape) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
                # Time rank unfolding. Get the time rank from original input.
                if self.unfold_time_rank is True:
                    original_shape = input_before_time_rank_folding.shape
                    new_shape = (original_shape[0], original_shape[1]) + new_shape[2:]
                # No time-rank unfolding, but we do have both batch- and time-rank.
                else:
                    input_shape = preprocessing_inputs.shape
                    # Batch and time rank stay as is.
                    if self.time_major is None or self.time_major is self.in_space_time_majors[key]:
                        new_shape = (input_shape[0], input_shape[1]) + new_shape[2:]
                    # Batch and time rank need to be flipped around: Do a transpose.
                    else:
                        # torch.transpose() swaps exactly two dims; flipping batch and time rank
                        # means swapping dims 0 and 1.
                        preprocessing_inputs = torch.transpose(preprocessing_inputs, 0, 1)
                        new_shape = (input_shape[1], input_shape[0]) + new_shape[2:]

            # print("Reshaping input of shape {} to new shape {} ".format(preprocessing_inputs.shape, new_shape))

            # The problem here is the following: the input has a shape like [4, 256, 1, 1].
            # If shape inference in the Spaces failed, the output dim is not correct and the
            # reshape would attempt something like reshaping to [256].
            if self.flatten or (preprocessing_inputs.size(0) > 1 and preprocessing_inputs.dim() > 1):
                return preprocessing_inputs.squeeze()
            else:
                return torch.reshape(preprocessing_inputs, new_shape)

        elif get_backend() == "tf":
            # Create a one-hot axis for the categories at the end?
            if self.num_categories.get(key, 0) > 1:
                preprocessing_inputs = tf.one_hot(
                    preprocessing_inputs, depth=self.num_categories[key], axis=-1, dtype="float32"
                )

            new_shape = self.output_spaces[key].get_shape(
                with_batch_rank=-1, with_time_rank=-1, time_major=self.time_major
            )
            # Dynamic new shape inference:
            # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
            # batch+time 0th rank, get these two dynamically.
            # Note: We may still flip the two, if input space has a different `time_major` than output space.
            flip_after_reshape = False
            if len(new_shape) >= 2 and new_shape[0] == -1 and new_shape[1] == -1:
                # Time rank unfolding. Get the time rank from original input (and maybe flip).
                if self.unfold_time_rank is True:
                    original_shape = tf.shape(input_before_time_rank_folding)
                    new_shape = (original_shape[0], original_shape[1]) + new_shape[2:]
                    flip_after_reshape = self.flip_batch_and_time_rank
                # No time-rank unfolding, but we do have both batch- and time-rank.
                else:
                    input_shape = tf.shape(preprocessing_inputs)
                    # Batch and time rank stay as is.
                    if self.time_major is None or self.time_major is self.in_space_time_majors[key]:
                        new_shape = (input_shape[0], input_shape[1]) + new_shape[2:]
                    # Batch and time rank need to be flipped around: Do a transpose.
                    else:
                        assert self.flip_batch_and_time_rank is True
                        preprocessing_inputs = tf.transpose(
                            preprocessing_inputs, perm=(1, 0) + tuple(i for i in range(
                                2, input_shape.shape.as_list()[0]
                            )), name="transpose-flip-batch-time-ranks"
                        )
                        new_shape = (input_shape[1], input_shape[0]) + new_shape[2:]

            reshaped = tf.reshape(tensor=preprocessing_inputs, shape=new_shape, name="reshaped")

            if flip_after_reshape and self.flip_batch_and_time_rank:
                reshaped = tf.transpose(
                    reshaped, (1, 0) + tuple(i for i in range(2, len(new_shape))),
                    name="transpose-flip-batch-time-ranks-after-reshape"
                )

            #reshaped = tf.Print(reshaped, [tf.shape(reshaped)], summarize=1000,
            #                    message="output shape for {} (key={}): {}".format(reshaped, key, self.scope))

            # Have to place the time rank back in as unknown (for the auto Space inference).
            if type(self.unfold_time_rank) == int:
                # TODO: replace placeholder with default value by _batch_rank/_time_rank properties.
                return tf.placeholder_with_default(reshaped, shape=(None, None) + new_shape[2:])
            else:
                # TODO: add other cases of reshaping and fix batch/time rank hints.
                if self.fold_time_rank:
                    reshaped._batch_rank = 0
                elif self.unfold_time_rank or self.flip_batch_and_time_rank:
                    reshaped._batch_rank = 0 if self.time_major is False else 1
                    reshaped._time_rank = 0 if self.time_major is True else 1

                return reshaped
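
The dynamic shape inference above boils down to replacing the two leading -1 entries of the target shape with the actual batch and time dimensions before reshaping. Below is a minimal NumPy sketch of the fold/unfold idea, assuming batch-major inputs (the helper names are illustrative, not rlgraph's API):

import numpy as np

def fold_time_rank(x):
    # Merge batch and time ranks: (B, T, ...) -> (B*T, ...).
    return np.reshape(x, (-1,) + x.shape[2:])

def unfold_time_rank(x_folded, input_before_folding):
    # Restore batch and time ranks using the pre-folding input's shape.
    b, t = input_before_folding.shape[0], input_before_folding.shape[1]
    return np.reshape(x_folded, (b, t) + x_folded.shape[1:])

original = np.zeros((4, 5, 3))                  # (batch=4, time=5, feature=3)
folded = fold_time_rank(original)               # -> (20, 3)
restored = unfold_time_rank(folded, original)   # -> (4, 5, 3)
assert restored.shape == (4, 5, 3)
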
Example #3
    def _graph_fn_loss_per_item(self,
                                key,
                                td_targets,
                                q_values_s,
                                actions,
                                importance_weights=None):
        """
        Args:
            td_targets (SingleDataOp): The already calculated TD-target terms: r + gamma * max_a' Qt(s',a'),
                or for double-Q: r + gamma * Qt(s', argmax_a' Q(s',a')).

            q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted returns
                when in s and taking different actions a.

            actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).

            importance_weights (Optional[SingleDataOp]): If 'self.importance_weights' is True: The batch of weights to
                apply to the losses.

        Returns:
            SingleDataOp: The loss values vector (one single value for each batch item).
        """
        # Numpy backend primarily for testing purposes.
        if self.backend == "python" or get_backend() == "python":
            from rlgraph.utils.numpy import one_hot

            actions_one_hot = one_hot(
                actions, depth=self.flat_action_space[key].num_categories)
            q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)

            td_delta = td_targets - q_s_a_values

            if td_delta.ndim > 1:
                if self.importance_weights:
                    td_delta = np.mean(td_delta * importance_weights,
                                       axis=list(
                                           range(1, self.ranks_to_reduce + 1)))

                else:
                    td_delta = np.mean(td_delta,
                                       axis=list(
                                           range(1, self.ranks_to_reduce + 1)))

        elif get_backend() == "tf":
            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = tf.one_hot(
                indices=actions,
                depth=self.flat_action_space[key].num_categories)
            q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot),
                                         axis=-1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = td_targets - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = tf.reduce_mean(input_tensor=td_delta,
                                          axis=list(
                                              range(1,
                                                    self.ranks_to_reduce + 1)))

        elif get_backend() == "pytorch":
            # Add batch dim in case of single sample.
            if q_values_s.dim() == 1:
                q_values_s = q_values_s.unsqueeze(-1)
                actions = actions.unsqueeze(-1)
                if self.importance_weights:
                    importance_weights = importance_weights.unsqueeze(-1)

            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = pytorch_one_hot(
                actions, depth=self.flat_action_space[key].num_categories)
            q_s_a_values = torch.sum((q_values_s * one_hot), -1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = td_targets - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = torch.mean(td_delta,
                                      tuple(range(1,
                                                  self.ranks_to_reduce + 1)),
                                      keepdim=False)

        # Apply importance-weights from a prioritized replay to the loss.
        if self.importance_weights:
            return importance_weights * td_delta
        else:
            return td_delta
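
As a sanity check, the per-item loss above is simply the TD target minus Q(s,a), where Q(s,a) is selected from the Q-value batch via a one-hot of the taken actions. A small NumPy example with two batch items and two actions (numbers chosen purely for illustration):

import numpy as np

q_values_s = np.array([[1.0, 2.0],   # two states, two actions each
                       [0.5, 0.0]])
actions = np.array([1, 0])           # actions actually taken
td_targets = np.array([2.5, 1.0])

actions_one_hot = np.eye(2)[actions]                     # [[0, 1], [1, 0]]
q_s_a_values = np.sum(q_values_s * actions_one_hot, -1)  # [2.0, 0.5]
td_delta = td_targets - q_s_a_values                     # [0.5, 0.5]
print(td_delta)
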
Example #4
    def _graph_fn_get_td_targets(self,
                                 key,
                                 rewards,
                                 terminals,
                                 qt_values_sp,
                                 q_values_sp=None):
        """
        Args:
            rewards (SingleDataOp): The batch of rewards that we received after having taken a in s (from a memory).
            terminals (SingleDataOp): The batch of terminal signals that we received after having taken a in s
                (from a memory).
            qt_values_sp (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
                returns (estimated by the target net) when in s' and taking different actions a'.
            q_values_sp (Optional[SingleDataOp]): If `self.double_q` is True: The batch of Q-values representing the
                expected accumulated discounted returns (estimated by the (main) policy net) when in s' and taking
                different actions a'.

        Returns:
            SingleDataOp: The target values vector.
        """
        qt_sp_ap_values = None

        # Numpy backend primarily for testing purposes.
        if self.backend == "python" or get_backend() == "python":
            from rlgraph.utils.numpy import one_hot
            if self.double_q:
                a_primes = np.argmax(q_values_sp, axis=-1)
                a_primes_one_hot = one_hot(
                    a_primes, depth=self.flat_action_space[key].num_categories)
                qt_sp_ap_values = np.sum(qt_values_sp * a_primes_one_hot,
                                         axis=-1)
            else:
                qt_sp_ap_values = np.max(qt_values_sp, axis=-1)

            for _ in range(qt_sp_ap_values.ndim - 1):
                rewards = np.expand_dims(rewards, axis=1)

            qt_sp_ap_values = np.where(terminals,
                                       np.zeros_like(qt_sp_ap_values),
                                       qt_sp_ap_values)

        elif get_backend() == "tf":
            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = tf.stop_gradient(qt_values_sp)

            if self.double_q:
                # For double-Q, we no longer use the max(a')Qt(s'a') value.
                # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
                a_primes = tf.argmax(input=q_values_sp, axis=-1)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = tf.one_hot(
                    indices=a_primes,
                    depth=self.flat_action_space[key].num_categories)
                qt_sp_ap_values = tf.reduce_sum(input_tensor=(qt_values_sp *
                                                              one_hot),
                                                axis=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                qt_sp_ap_values = tf.reduce_max(input_tensor=qt_values_sp,
                                                axis=-1)

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = tf.expand_dims(rewards, axis=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
            qt_sp_ap_values = tf.where(condition=terminals,
                                       x=tf.zeros_like(qt_sp_ap_values),
                                       y=qt_sp_ap_values)

        elif get_backend() == "pytorch":
            if not isinstance(terminals, torch.ByteTensor):
                terminals = terminals.byte()
            # Add batch dim in case of single sample.
            if qt_values_sp.dim() == 1:
                rewards = rewards.unsqueeze(-1)
                terminals = terminals.unsqueeze(-1)
                q_values_sp = q_values_sp.unsqueeze(-1)
                qt_values_sp = qt_values_sp.unsqueeze(-1)

            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = qt_values_sp.detach()
            if self.double_q:
                # For double-Q, we no longer use the max(a')Qt(s'a') value.
                # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
                a_primes = torch.argmax(q_values_sp, dim=-1, keepdim=True)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = pytorch_one_hot(
                    a_primes, depth=self.flat_action_space[key].num_categories)
                qt_sp_ap_values = torch.sum(qt_values_sp * one_hot.squeeze(),
                                            dim=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                qt_sp_ap_values = torch.max(qt_values_sp, -1)[0]

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = torch.unsqueeze(rewards, dim=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
            # torch.where cannot broadcast here, so tile and reshape to same shape.
            if qt_sp_ap_values.dim() > 1:
                num_tiles = np.prod(qt_sp_ap_values.shape[1:])
                terminals = pytorch_tile(terminals, num_tiles,
                                         -1).reshape(qt_sp_ap_values.shape)
            qt_sp_ap_values = torch.where(terminals,
                                          torch.zeros_like(qt_sp_ap_values),
                                          qt_sp_ap_values)
        td_targets = (rewards + (self.discount**self.n_step) * qt_sp_ap_values)
        return td_targets
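
For intuition, the TD target returned at the end is r + gamma^n * Qt(s',a'), with the bootstrap term zeroed out on terminal transitions and a' chosen by the target net (plain Q-learning) or by the main net (double-Q). A small NumPy check of the double-Q branch, using made-up numbers:

import numpy as np

rewards = np.array([1.0, 0.0])
terminals = np.array([False, True])
q_values_sp = np.array([[0.2, 0.9],    # main net: picks a' via argmax
                        [0.7, 0.1]])
qt_values_sp = np.array([[0.5, 0.4],   # target net: evaluates that a'
                         [0.3, 0.8]])
discount, n_step = 0.99, 1

a_primes = np.argmax(q_values_sp, axis=-1)                   # [1, 0]
qt_sp_ap_values = qt_values_sp[np.arange(2), a_primes]       # [0.4, 0.3]
qt_sp_ap_values = np.where(terminals, 0.0, qt_sp_ap_values)  # [0.4, 0.0]
td_targets = rewards + (discount ** n_step) * qt_sp_ap_values
print(td_targets)  # [1.396, 0.0]
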
Example #5
    def _graph_fn_apply(self,
                        key,
                        preprocessing_inputs,
                        input_before_time_rank_folding=None):
        """
        Reshapes the input to the specified new shape.

        Args:
            preprocessing_inputs (SingleDataOp): The input to reshape.
            input_before_time_rank_folding (Optional[SingleDataOp]): The original input from before the time rank
                was folded (the folding was done in a different ReShape component). Used when `self.unfold_time_rank`
                is True to figure out the exact time-rank dimension to unfold.

        Returns:
            SingleDataOp: The reshaped input.
        """
        assert self.unfold_time_rank is False or input_before_time_rank_folding is not None

        if self.backend == "python" or get_backend() == "python":
            # Create a one-hot axis for the categories at the end?
            num_categories = self.get_num_categories(
                key, get_space_from_op(preprocessing_inputs))
            if num_categories and num_categories > 1:
                preprocessing_inputs = one_hot(preprocessing_inputs,
                                               depth=num_categories)

            if self.unfold_time_rank:
                new_shape = (-1, -1) + preprocessing_inputs.shape[1:]
            elif self.fold_time_rank:
                new_shape = (-1, ) + preprocessing_inputs.shape[2:]
            else:
                new_shape = self.get_preprocessed_space(
                    get_space_from_op(preprocessing_inputs)).get_shape(
                        with_batch_rank=-1, with_time_rank=-1)

            # Dynamic new shape inference:
            # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
            # batch+time 0th rank, get these two dynamically.
            if len(preprocessing_inputs.shape) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
                # Time rank unfolding. Get the time rank from original input.
                if self.unfold_time_rank is True:
                    original_shape = input_before_time_rank_folding.shape
                    new_shape = (original_shape[0],
                                 original_shape[1]) + new_shape[2:]
                # No time-rank unfolding, but we do have both batch- and time-rank.
                else:
                    input_shape = preprocessing_inputs.shape
                    # Batch and time rank stay as is.
                    new_shape = (input_shape[0],
                                 input_shape[1]) + new_shape[2:]

            return np.reshape(preprocessing_inputs, newshape=new_shape)

        elif get_backend() == "pytorch":
            # Create a one-hot axis for the categories at the end?
            num_categories = self.get_num_categories(
                key, get_space_from_op(preprocessing_inputs))
            if num_categories and num_categories > 1:
                preprocessing_inputs = pytorch_one_hot(preprocessing_inputs,
                                                       depth=num_categories)

            if self.unfold_time_rank:
                new_shape = (-1, -1) + preprocessing_inputs.shape[1:]
            elif self.fold_time_rank:
                new_shape = (-1, ) + preprocessing_inputs.shape[2:]
            else:
                new_shape = self.get_preprocessed_space(
                    get_space_from_op(preprocessing_inputs)).get_shape(
                        with_batch_rank=-1, with_time_rank=-1)

            # Dynamic new shape inference:
            # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
            # batch+time 0th rank, get these two dynamically.
            if len(new_shape) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
                # Time rank unfolding. Get the time rank from original input.
                if self.unfold_time_rank is True:
                    original_shape = input_before_time_rank_folding.shape
                    new_shape = (original_shape[0],
                                 original_shape[1]) + new_shape[2:]
                # No time-rank unfolding, but we do have both batch- and time-rank.
                else:
                    input_shape = preprocessing_inputs.shape
                    # Batch and time rank stay as is.
                    new_shape = (input_shape[0],
                                 input_shape[1]) + new_shape[2:]

            # print("Reshaping input of shape {} to new shape {} (flatten = {})".format(preprocessing_inputs.shape,
            #                                                                           new_shape, self.flatten))

            old_size = np.prod(list(preprocessing_inputs.shape))
            new_size = np.prod(new_shape)

            # The problem here is the following: the input has a shape like [4, 256, 1, 1].
            # If shape inference in the Spaces failed, the output dim is not correct and the
            # reshape would attempt something like reshaping to [256].
            if self.flatten and preprocessing_inputs.dim() > 1:
                flattened_shape_without_batchrank = np.prod(
                    preprocessing_inputs.shape[1:])
                flattened_shape = (preprocessing_inputs.shape[0], ) + (
                    flattened_shape_without_batchrank, )
                return torch.reshape(preprocessing_inputs, flattened_shape)
            # If new shape does not fit into old shape, batch inference failed -> try to restore:
            # Equal except batch rank -> return as is:
            elif old_size != new_size:
                if tuple(preprocessing_inputs.shape[1:]) == new_shape:
                    return preprocessing_inputs
                else:
                    # Attempt to rescue reshape by combining new shape with batch dim.
                    full_new_shape = (
                        preprocessing_inputs.shape[0], ) + new_shape
                    return torch.reshape(preprocessing_inputs, full_new_shape)
            else:
                return torch.reshape(preprocessing_inputs, new_shape)

        elif get_backend() == "tf":
            # Create a one-hot axis for the categories at the end?
            space = get_space_from_op(preprocessing_inputs)
            num_categories = self.get_num_categories(key, space)
            if num_categories and num_categories > 1:
                preprocessing_inputs_ = tf.one_hot(preprocessing_inputs,
                                                   depth=num_categories,
                                                   axis=-1,
                                                   dtype="float32")
                if hasattr(preprocessing_inputs, "_batch_rank"):
                    preprocessing_inputs_._batch_rank = preprocessing_inputs._batch_rank
                if hasattr(preprocessing_inputs, "_time_rank"):
                    preprocessing_inputs_._time_rank = preprocessing_inputs._time_rank
                preprocessing_inputs = preprocessing_inputs_

            if self.unfold_time_rank:
                list_shape = preprocessing_inputs.shape.as_list()
                assert len(list_shape) == 1 or list_shape[1] is not None,\
                    "ERROR: Cannot unfold. `preprocessing_inputs` (with shape {}) " \
                    "already seems to be unfolded!".format(list_shape)
                new_shape = (-1, -1) + tuple(list_shape[1:])
            elif self.fold_time_rank:
                new_shape = (-1, ) + tuple(
                    preprocessing_inputs.shape.as_list()[2:])
            else:
                new_shape = self.get_preprocessed_space(
                    get_space_from_op(preprocessing_inputs)).get_shape(
                        with_batch_rank=-1, with_time_rank=-1)

            # Dynamic new shape inference:
            # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
            # batch+time 0th rank, get these two dynamically.
            if len(new_shape) >= 2 and new_shape[0] == -1 and new_shape[1] == -1:
                # Time rank unfolding. Get the time rank from original input.
                if self.unfold_time_rank is True:
                    original_shape = tf.shape(input_before_time_rank_folding)
                    new_shape = (original_shape[0],
                                 original_shape[1]) + new_shape[2:]
                # No time-rank unfolding, but we do have both batch- and time-rank.
                else:
                    input_shape = tf.shape(preprocessing_inputs)
                    # Batch and time rank stay as is.
                    new_shape = (input_shape[0],
                                 input_shape[1]) + new_shape[2:]

            reshaped = tf.reshape(tensor=preprocessing_inputs,
                                  shape=new_shape,
                                  name="reshaped")

            # Have to place the time rank back in as unknown (for the auto Space inference).
            if type(self.unfold_time_rank) == int:
                # TODO: replace placeholder with default value by _batch_rank/_time_rank properties.
                return tf.placeholder_with_default(reshaped,
                                                   shape=(None, None) +
                                                   new_shape[2:])
            else:
                # TODO: add other cases of reshaping and fix batch/time rank hints.
                if self.fold_time_rank:
                    reshaped._batch_rank = 0
                elif self.unfold_time_rank:
                    reshaped._batch_rank = 1 if self.time_major is True else 0
                    reshaped._time_rank = 0 if self.time_major is True else 1
                else:
                    if space.has_batch_rank is True:
                        if space.time_major is False:
                            reshaped._batch_rank = 0
                        else:
                            reshaped._time_rank = 0
                            reshaped._batch_rank = 1
                    if space.has_time_rank is True:
                        reshaped._time_rank = 0 if space.time_major is True else 1

                return reshaped
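
The PyTorch `flatten` branch above keeps the batch rank and collapses everything behind it into a single dimension. A minimal standalone sketch of that behavior (not rlgraph's actual code path):

import numpy as np
import torch

x = torch.zeros(4, 256, 1, 1)                        # e.g. a conv output
flat_size_without_batch = int(np.prod(x.shape[1:]))  # 256
flat = torch.reshape(x, (x.shape[0], flat_size_without_batch))
assert flat.shape == (4, 256)
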
Example #6
    def _graph_fn_loss_per_item(self,
                                q_values_s,
                                actions,
                                rewards,
                                terminals,
                                qt_values_sp,
                                q_values_sp=None,
                                importance_weights=None):
        """
        Args:
            q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted returns
                when in s and taking different actions a.
            actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).
            rewards (SingleDataOp): The batch of rewards that we received after having taken a in s (from a memory).
            terminals (SingleDataOp): The batch of terminal signals that we received after having taken a in s
                (from a memory).
            qt_values_sp (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
                returns (estimated by the target net) when in s' and taking different actions a'.
            q_values_sp (Optional[SingleDataOp]): If `self.double_q` is True: The batch of Q-values representing the
                expected accumulated discounted returns (estimated by the (main) policy net) when in s' and taking
                different actions a'.
            importance_weights (Optional[SingleDataOp]): If 'self.importance_weights' is True: The batch of weights to
                apply to the losses.

        Returns:
            SingleDataOp: The loss values vector (one single value for each batch item).
        """
        # Numpy backend primarily for testing purposes.
        if self.backend == "python" or get_backend() == "python":
            from rlgraph.utils.numpy import one_hot
            if self.double_q:
                a_primes = np.argmax(q_values_sp, axis=-1)
                a_primes_one_hot = one_hot(
                    a_primes, depth=self.action_space.num_categories)
                qt_sp_ap_values = np.sum(qt_values_sp * a_primes_one_hot,
                                         axis=-1)
            else:
                qt_sp_ap_values = np.max(qt_values_sp, axis=-1)

            for _ in range(qt_sp_ap_values.ndim - 1):
                rewards = np.expand_dims(rewards, axis=1)

            qt_sp_ap_values = np.where(terminals,
                                       np.zeros_like(qt_sp_ap_values),
                                       qt_sp_ap_values)

            actions_one_hot = one_hot(actions,
                                      depth=self.action_space.num_categories)
            q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)

            td_delta = (
                rewards +
                (self.discount**self.n_step) * qt_sp_ap_values) - q_s_a_values

            if td_delta.ndim > 1:
                if self.importance_weights:
                    td_delta = np.mean(td_delta * importance_weights,
                                       axis=list(
                                           range(1, self.ranks_to_reduce + 1)))

                else:
                    td_delta = np.mean(td_delta,
                                       axis=list(
                                           range(1, self.ranks_to_reduce + 1)))

            return self._apply_huber_loss_if_necessary(td_delta)
        elif get_backend() == "tf":
            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = tf.stop_gradient(qt_values_sp)

            if self.double_q:
                # For double-Q, we no longer use the max(a')Qt(s'a') value.
                # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
                a_primes = tf.argmax(input=q_values_sp, axis=-1)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = tf.one_hot(indices=a_primes,
                                     depth=self.action_space.num_categories)
                qt_sp_ap_values = tf.reduce_sum(input_tensor=(qt_values_sp *
                                                              one_hot),
                                                axis=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                qt_sp_ap_values = tf.reduce_max(input_tensor=qt_values_sp,
                                                axis=-1)

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = tf.expand_dims(rewards, axis=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
            qt_sp_ap_values = tf.where(condition=terminals,
                                       x=tf.zeros_like(qt_sp_ap_values),
                                       y=qt_sp_ap_values)

            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = tf.one_hot(indices=actions,
                                 depth=self.action_space.num_categories)
            q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot),
                                         axis=-1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = (
                rewards +
                (self.discount**self.n_step) * qt_sp_ap_values) - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = tf.reduce_mean(input_tensor=td_delta,
                                          axis=list(
                                              range(1,
                                                    self.ranks_to_reduce + 1)))

            # Apply importance-weights from a prioritized replay to the loss.
            if self.importance_weights:
                return importance_weights * self._apply_huber_loss_if_necessary(
                    td_delta)
            else:
                return self._apply_huber_loss_if_necessary(td_delta)
        elif get_backend() == "pytorch":
            if not isinstance(terminals, torch.ByteTensor):
                terminals = terminals.byte()
            # Add batch dim in case of single sample.
            if q_values_s.dim() == 1:
                q_values_s = q_values_s.unsqueeze(-1)
                actions = actions.unsqueeze(-1)
                rewards = rewards.unsqueeze(-1)
                terminals = terminals.unsqueeze(-1)
                q_values_sp = q_values_sp.unsqueeze(-1)
                qt_values_sp = qt_values_sp.unsqueeze(-1)
                if self.importance_weights:
                    importance_weights = importance_weights.unsqueeze(-1)

            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = qt_values_sp.detach()
            if self.double_q:
                # For double-Q, we no longer use the max(a')Qt(s'a') value.
                # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
                a_primes = torch.argmax(q_values_sp, dim=-1, keepdim=True)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = pytorch_one_hot(
                    a_primes, depth=self.action_space.num_categories)
                qt_sp_ap_values = torch.sum(qt_values_sp * one_hot, dim=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                # torch.max() without a dim returns the global max scalar; reduce over the action axis instead.
                qt_sp_ap_values = torch.max(qt_values_sp, -1)[0]

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = torch.unsqueeze(rewards, dim=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
            qt_sp_ap_values = torch.where(terminals,
                                          torch.zeros_like(qt_sp_ap_values),
                                          qt_sp_ap_values)
            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = pytorch_one_hot(actions,
                                      depth=self.action_space.num_categories)
            q_s_a_values = torch.sum((q_values_s * one_hot), -1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = (
                rewards +
                (self.discount**self.n_step) * qt_sp_ap_values) - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = pytorch_reduce_mean(
                    td_delta,
                    list(range(1, self.ranks_to_reduce + 1)),
                    keepdims=False)

            # Apply importance-weights from a prioritized replay to the loss.
            if self.importance_weights:
                return importance_weights * self._apply_huber_loss_if_necessary(
                    td_delta)
            else:
                return self._apply_huber_loss_if_necessary(td_delta)
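
Putting the pieces together, the per-item loss is a (possibly Huber-transformed) TD error, td_delta = (r + gamma^n * Qt(s',a')) - Q(s,a), optionally scaled by importance weights from a prioritized replay. A compact NumPy walkthrough of the non-double-Q path for a two-item, two-action batch (illustrative numbers only, Huber step omitted):

import numpy as np

q_values_s = np.array([[1.0, 2.0], [0.5, 0.0]])
actions = np.array([1, 0])
rewards = np.array([1.0, 0.0])
terminals = np.array([False, True])
qt_values_sp = np.array([[0.5, 0.4], [0.3, 0.8]])
importance_weights = np.array([1.0, 0.5])
discount, n_step = 0.99, 1

qt_sp_ap_values = np.where(terminals, 0.0, np.max(qt_values_sp, axis=-1))   # [0.5, 0.0]
q_s_a_values = np.sum(q_values_s * np.eye(2)[actions], axis=-1)             # [2.0, 0.5]
td_delta = rewards + (discount ** n_step) * qt_sp_ap_values - q_s_a_values  # [-0.505, -0.5]
loss_per_item = importance_weights * td_delta                               # [-0.505, -0.25]
print(loss_per_item)
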