Example #1
    def _graph_fn_loss_per_item(self,
                                key,
                                td_targets,
                                q_values_s,
                                actions,
                                importance_weights=None):
        """
        Args:
            key (str): The flat key of the action space for which to compute the loss.
            td_targets (SingleDataOp): The already calculated TD-target terms (r + gamma * max_a'(Qt(s',a'))
                OR, for double Q: r + gamma * Qt(s', argmax_a'(Q(s',a')))).

            q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted returns
                when in s and taking different actions a.

            actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).

            importance_weights (Optional[SingleDataOp]): If 'self.importance_weights' is True: The batch of weights to
                apply to the losses.

        Returns:
            SingleDataOp: The loss values vector (one single value for each batch item).
        """
        # Numpy backend primarily for testing purposes.
        if self.backend == "python" or get_backend() == "python":
            from rlgraph.utils.numpy import one_hot

            actions_one_hot = one_hot(
                actions, depth=self.flat_action_space[key].num_categories)
            q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)

            td_delta = td_targets - q_s_a_values

            if td_delta.ndim > 1:
                if self.importance_weights:
                    td_delta = np.mean(td_delta * importance_weights,
                                       axis=list(
                                           range(1, self.ranks_to_reduce + 1)))

                else:
                    td_delta = np.mean(td_delta,
                                       axis=list(
                                           range(1, self.ranks_to_reduce + 1)))

        elif get_backend() == "tf":
            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = tf.one_hot(
                indices=actions,
                depth=self.flat_action_space[key].num_categories)
            q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot),
                                         axis=-1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = td_targets - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = tf.reduce_mean(input_tensor=td_delta,
                                          axis=list(
                                              range(1,
                                                    self.ranks_to_reduce + 1)))

        elif get_backend() == "pytorch":
            # Add batch dim in case of single sample.
            if q_values_s.dim() == 1:
                q_values_s = q_values_s.unsqueeze(-1)
                actions = actions.unsqueeze(-1)
                if self.importance_weights:
                    importance_weights = importance_weights.unsqueeze(-1)

            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = pytorch_one_hot(
                actions, depth=self.flat_action_space[key].num_categories)
            q_s_a_values = torch.sum((q_values_s * one_hot), -1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = td_targets - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = torch.mean(td_delta,
                                      tuple(range(1,
                                                  self.ranks_to_reduce + 1)),
                                      keepdim=False)

        # Apply importance-weights from a prioritized replay to the loss.
        if self.importance_weights:
            return importance_weights * td_delta
        else:
            return td_delta
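
The core of the computation above can be reproduced outside of RLGraph's graph-fn machinery with plain NumPy. The helper below is only an illustrative sketch (its name and the toy values are made up, not part of RLGraph): it gathers Q(s,a) via a one-hot of the taken actions and returns the per-item TD-delta, i.e. the quantity this loss function reduces and (optionally) importance-weights.

import numpy as np

def td_delta_per_item(td_targets, q_values_s, actions, num_actions):
    # One-hot encode the taken actions and pick Q(s,a) along the last axis.
    actions_one_hot = np.eye(num_actions)[actions]
    q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)
    # TD-delta per batch item: target minus current estimate.
    return td_targets - q_s_a_values

# Toy batch: 2 items, 3 discrete actions.
q_values_s = np.array([[1.0, 2.0, 3.0], [0.5, 0.0, -1.0]])
actions = np.array([2, 0])
td_targets = np.array([2.5, 1.0])
print(td_delta_per_item(td_targets, q_values_s, actions, num_actions=3))  # [-0.5  0.5]
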
Example #2
    def _graph_fn_get_td_targets(self,
                                 key,
                                 rewards,
                                 terminals,
                                 qt_values_sp,
                                 q_values_sp=None):
        """
        Args:
            key (str): The flat key of the action space for which to compute the TD-targets.
            rewards (SingleDataOp): The batch of rewards that we received after having taken a in s (from a memory).
            terminals (SingleDataOp): The batch of terminal signals that we received after having taken a in s
                (from a memory).
            qt_values_sp (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
                returns (estimated by the target net) when in s' and taking different actions a'.
            q_values_sp (Optional[SingleDataOp]): If `self.double_q` is True: The batch of Q-values representing the
                expected accumulated discounted returns (estimated by the (main) policy net) when in s' and taking
                different actions a'.

        Returns:
            SingleDataOp: The target values vector.
        """
        qt_sp_ap_values = None

        # Numpy backend primarily for testing purposes.
        if self.backend == "python" or get_backend() == "python":
            from rlgraph.utils.numpy import one_hot
            if self.double_q:
                a_primes = np.argmax(q_values_sp, axis=-1)
                a_primes_one_hot = one_hot(
                    a_primes, depth=self.flat_action_space[key].num_categories)
                qt_sp_ap_values = np.sum(qt_values_sp * a_primes_one_hot,
                                         axis=-1)
            else:
                qt_sp_ap_values = np.max(qt_values_sp, axis=-1)

            for _ in range(qt_sp_ap_values.ndim - 1):
                rewards = np.expand_dims(rewards, axis=1)

            qt_sp_ap_values = np.where(terminals,
                                       np.zeros_like(qt_sp_ap_values),
                                       qt_sp_ap_values)

        elif get_backend() == "tf":
            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = tf.stop_gradient(qt_values_sp)

            if self.double_q:
                # For double-Q, we no longer use the max(a')Qt(s'a') value.
                # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
                a_primes = tf.argmax(input=q_values_sp, axis=-1)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = tf.one_hot(
                    indices=a_primes,
                    depth=self.flat_action_space[key].num_categories)
                qt_sp_ap_values = tf.reduce_sum(input_tensor=(qt_values_sp *
                                                              one_hot),
                                                axis=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                qt_sp_ap_values = tf.reduce_max(input_tensor=qt_values_sp,
                                                axis=-1)

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = tf.expand_dims(rewards, axis=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
            qt_sp_ap_values = tf.where(condition=terminals,
                                       x=tf.zeros_like(qt_sp_ap_values),
                                       y=qt_sp_ap_values)

        elif get_backend() == "pytorch":
            if not isinstance(terminals, torch.ByteTensor):
                terminals = terminals.byte()
            # Add batch dim in case of single sample.
            if qt_values_sp.dim() == 1:
                rewards = rewards.unsqueeze(-1)
                terminals = terminals.unsqueeze(-1)
                q_values_sp = q_values_sp.unsqueeze(-1)
                qt_values_sp = qt_values_sp.unsqueeze(-1)

            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = qt_values_sp.detach()
            if self.double_q:
                # For double-Q, we no longer use the max(a')Qt(s'a') value.
                # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
                a_primes = torch.argmax(q_values_sp, dim=-1, keepdim=True)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = pytorch_one_hot(
                    a_primes, depth=self.flat_action_space[key].num_categories)
                qt_sp_ap_values = torch.sum(qt_values_sp * one_hot.squeeze(),
                                            dim=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                qt_sp_ap_values = torch.max(qt_values_sp, -1)[0]

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = torch.unsqueeze(rewards, dim=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
            # torch.where cannot broadcast here, so tile and reshape to same shape.
            if qt_sp_ap_values.dim() > 1:
                num_tiles = np.prod(qt_sp_ap_values.shape[1:])
                terminals = pytorch_tile(terminals, num_tiles,
                                         -1).reshape(qt_sp_ap_values.shape)
            qt_sp_ap_values = torch.where(terminals,
                                          torch.zeros_like(qt_sp_ap_values),
                                          qt_sp_ap_values)
        td_targets = (rewards + (self.discount**self.n_step) * qt_sp_ap_values)
        return td_targets
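
For the double-Q case, what `_graph_fn_get_td_targets` computes can be sketched in plain NumPy as below. This is an illustration only (helper name and toy numbers are invented): a' is selected with the online net's Q(s',.), evaluated with the target net's Qt(s',.), zeroed for terminal s', and then discounted into the n-step TD-target.

import numpy as np

def double_q_td_targets(rewards, terminals, qt_values_sp, q_values_sp,
                        discount=0.99, n_step=1):
    # a' is chosen by the online net, but evaluated with the target net.
    a_primes = np.argmax(q_values_sp, axis=-1)
    qt_sp_ap_values = np.take_along_axis(
        qt_values_sp, a_primes[..., None], axis=-1).squeeze(-1)
    # Terminal s': no bootstrap value.
    qt_sp_ap_values = np.where(terminals, 0.0, qt_sp_ap_values)
    return rewards + (discount ** n_step) * qt_sp_ap_values

rewards = np.array([1.0, 0.0])
terminals = np.array([False, True])
q_values_sp = np.array([[0.1, 0.9], [0.4, 0.2]])   # online (policy) net
qt_values_sp = np.array([[0.3, 0.5], [0.7, 0.6]])  # target net
print(double_q_td_targets(rewards, terminals, qt_values_sp, q_values_sp))  # [1.495 0.   ]
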
Example #3
    def _graph_fn_loss_per_item(self, key, td_targets, q_values_s, actions, expert_margins,
                                importance_weights=None, apply_demo_loss=False):
        """
        Args:
            key (str): The flat key of the action space for which to compute the loss.
            td_targets (SingleDataOp): The already calculated TD-target terms (r + gamma * max_a'(Qt(s',a'))
                OR, for double Q: r + gamma * Qt(s', argmax_a'(Q(s',a')))).
            q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted returns
                when in s and taking different actions a.
            actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).
            expert_margins (SingleDataOp): The expert margins enforce a distance in Q-values between the expert action
                and all other actions.
            importance_weights (Optional[SingleDataOp]): If `self.importance_weights` is True: The batch of weights to
                apply to the losses.
            apply_demo_loss (Optional[SingleDataOp]): If True, the large-margin (demo) loss is applied. Should be set
                to True when updating from demo data and False when updating from online data.
        Returns:
            SingleDataOp: The loss values vector (one single value for each batch item).
        """
        if get_backend() == "tf":
            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = tf.one_hot(indices=actions, depth=self.flat_action_space[key].num_categories)
            q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot), axis=-1)

            # Calculate the TD-delta (targets - current estimate).
            td_delta = td_targets - q_s_a_values

            # Calculate the demo-loss (large-margin supervised loss):
            #  J_E(Q) = max_a[Q(s, a) + l(a_expert, a)] - Q(s, a_expert)
            #  (on demo data, a_taken is the expert's action, so q_s_a_values below is Q(s, a_expert)).
            mask = tf.ones_like(tensor=one_hot, dtype=tf.float32)
            action_mask = mask - one_hot

            # Margin mask: allow custom per-sample expert margins -> requires creating a margin matrix.
            # Instead of applying the same margin to all samples, users can pass a margin vector.
            # Broadcast to the one-hot shape.
            expert_margins = tf.expand_dims(expert_margins, -1)
            expert_margins = tf.broadcast_to(input=expert_margins, shape=tf.shape(one_hot))
            margin_mask = expert_margins - one_hot

            # margin_mask = tf.Print(margin_mask, [margin_mask], summarize=100, message="margin mask =")
            margin_val = action_mask * margin_mask
            loss_input = q_values_s + margin_val

            # Apply margin.
            def map_margins(x):
                element_margin = x[0]
                element_loss = x[1]
                # Positive margins: apply max.
                # Negative margins: apply min.
                return tf.cond(
                    pred=tf.reduce_sum(element_margin) > 0,
                    true_fn=lambda: tf.reduce_max(element_loss),
                    false_fn=lambda: tf.reduce_min(element_loss),
                )
            supervised_loss = tf.map_fn(map_margins, (margin_val, loss_input), dtype=tf.float32)

            # Subtract Q-values of action actually taken.
            supervised_delta = supervised_loss - q_s_a_values
            td_delta = tf.cond(
                pred=apply_demo_loss,
                true_fn=lambda: td_delta + self.supervised_weight * supervised_delta,
                false_fn=lambda: td_delta
            )

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = tf.reduce_mean(input_tensor=td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

            # Apply importance-weights from a prioritized replay to the loss.
            if self.importance_weights:
                return importance_weights * td_delta
            else:
                return td_delta
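
The tf.map_fn/tf.cond construction above supports per-sample, possibly negative, expert margins. For the common case of a single positive margin, the supervised term it computes reduces to the DQfD-style expression from the comment; a minimal NumPy sketch (hypothetical helper, positive margins only):

import numpy as np

def large_margin_term(q_values_s, expert_actions, expert_margin=0.8):
    # l(a_expert, a): a fixed margin for every non-expert action, 0 for the expert action.
    num_actions = q_values_s.shape[-1]
    expert_one_hot = np.eye(num_actions)[expert_actions]
    margins = expert_margin * (1.0 - expert_one_hot)
    # max_a[Q(s,a) + l(a_expert, a)] - Q(s, a_expert)
    max_term = np.max(q_values_s + margins, axis=-1)
    q_s_a_expert = np.sum(q_values_s * expert_one_hot, axis=-1)
    return max_term - q_s_a_expert

q_values_s = np.array([[1.0, 1.5, 0.2]])
print(large_margin_term(q_values_s, expert_actions=np.array([1])))  # [0.3]

The term is zero whenever the expert action already dominates every other action by at least the margin, and positive otherwise, which is what nudges the Q-function toward the demonstrated actions.
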
Example #4
    def _graph_fn_apply(self, preprocessing_inputs):
        """
        Gray-scales images of arbitrary rank.
        Normally, an image has rank 3 (width/height/colors), but the input may also be batched
        (batch/width/height/colors) or have any other rank. However, the last rank must have size len(self.weights).

        Args:
            preprocessing_inputs (tensor): Single image or a batch of images to be gray-scaled (last rank=n colors, where
                n=len(self.weights)).

        Returns:
            DataOp: The op for processing the images.
        """
        # The reshaped weights used for the grayscale operation.
        if isinstance(preprocessing_inputs, list):
            preprocessing_inputs = np.asarray(preprocessing_inputs)
        images_shape = get_shape(preprocessing_inputs)
        assert images_shape[-1] == self.last_rank,\
            "ERROR: Given image's shape ({}) does not match number of weights (last rank must be {})!".\
            format(images_shape, self.last_rank)
        if self.backend == "python" or get_backend() == "python":
            if preprocessing_inputs.ndim == 4:
                grayscaled = []
                for i in range_(len(preprocessing_inputs)):
                    scaled = cv2.cvtColor(preprocessing_inputs[i], cv2.COLOR_RGB2GRAY)
                    grayscaled.append(scaled)
                scaled_images = np.asarray(grayscaled)

                # Keep last dim.
                if self.keep_rank:
                    scaled_images = scaled_images[:, :, :, np.newaxis]
            else:
                # Sample by sample.
                scaled_images = cv2.cvtColor(preprocessing_inputs, cv2.COLOR_RGB2GRAY)

            return scaled_images
        elif get_backend() == "pytorch":
            if len(preprocessing_inputs.shape) == 4:
                grayscaled = []
                for i in range_(len(preprocessing_inputs)):
                    scaled = cv2.cvtColor(preprocessing_inputs[i].numpy(), cv2.COLOR_RGB2GRAY)
                    grayscaled.append(scaled)
                scaled_images = np.asarray(grayscaled)
                # Keep last dim.
                if self.keep_rank:
                    scaled_images = scaled_images[:, :, :, np.newaxis]
            else:
                # Sample by sample.
                scaled_images = cv2.cvtColor(preprocessing_inputs.numpy(), cv2.COLOR_RGB2GRAY)
            return torch.tensor(scaled_images)
        elif get_backend() == "tf":
            weights_reshaped = np.reshape(
                self.weights, newshape=tuple([1] * (get_rank(preprocessing_inputs) - 1)) + (self.last_rank,)
            )

            # Do we need to convert?
            # Careful: Multiplying an int tensor (image) with float weights would result in an
            # all-0 tensor.
            if "int" in str(dtype_(preprocessing_inputs.dtype)):
                weighted = weights_reshaped * tf.cast(preprocessing_inputs, dtype=dtype_("float"))
            else:
                weighted = weights_reshaped * preprocessing_inputs

            reduced = tf.reduce_sum(weighted, axis=-1, keepdims=self.keep_rank)

            # Cast back to original dtype.
            if "int" in str(dtype_(preprocessing_inputs.dtype)):
                reduced = tf.cast(reduced, dtype=preprocessing_inputs.dtype)

            return reduced
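
The tf branch above implements gray-scaling as a weighted sum over the color rank instead of calling into OpenCV. The NumPy sketch below mirrors that path; the (0.299, 0.587, 0.114) luminance weights are an assumed default for illustration, while the real component reads them from self.weights.

import numpy as np

def grayscale(images, weights=(0.299, 0.587, 0.114), keep_rank=False):
    # Broadcast the color weights against the last (color) rank and reduce over it.
    w = np.reshape(weights, (1,) * (images.ndim - 1) + (len(weights),))
    out = np.sum(images.astype(np.float32) * w, axis=-1, keepdims=keep_rank)
    # Cast back to the input dtype (as the tf branch does for int images).
    return out.astype(images.dtype)

batch = np.random.randint(0, 256, size=(2, 4, 4, 3), dtype=np.uint8)
print(grayscale(batch, keep_rank=True).shape)  # (2, 4, 4, 1)
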
Example #5
    def _graph_fn_loss_per_item(self,
                                q_values_s,
                                actions,
                                rewards,
                                terminals,
                                qt_values_sp,
                                q_values_sp=None,
                                importance_weights=None):
        """
        Args:
            q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted returns
                when in s and taking different actions a.
            actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).
            rewards (SingleDataOp): The batch of rewards that we received after having taken a in s (from a memory).
            terminals (SingleDataOp): The batch of terminal signals that we received after having taken a in s
                (from a memory).
            qt_values_sp (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
                returns (estimated by the target net) when in s' and taking different actions a'.
            q_values_sp (Optional[SingleDataOp]): If `self.double_q` is True: The batch of Q-values representing the
                expected accumulated discounted returns (estimated by the (main) policy net) when in s' and taking
                different actions a'.
            importance_weights (Optional[SingleDataOp]): If 'self.importance_weights' is True: The batch of weights to
                apply to the losses.

        Returns:
            SingleDataOp: The loss values vector (one single value for each batch item).
        """
        # Numpy backend primarily for testing purposes.
        if self.backend == "python" or get_backend() == "python":
            from rlgraph.utils.numpy import one_hot
            if self.double_q:
                a_primes = np.argmax(q_values_sp, axis=-1)
                a_primes_one_hot = one_hot(
                    a_primes, depth=self.action_space.num_categories)
                qt_sp_ap_values = np.sum(qt_values_sp * a_primes_one_hot,
                                         axis=-1)
            else:
                qt_sp_ap_values = np.max(qt_values_sp, axis=-1)

            for _ in range(qt_sp_ap_values.ndim - 1):
                rewards = np.expand_dims(rewards, axis=1)

            qt_sp_ap_values = np.where(terminals,
                                       np.zeros_like(qt_sp_ap_values),
                                       qt_sp_ap_values)

            actions_one_hot = one_hot(actions,
                                      depth=self.action_space.num_categories)
            q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)

            td_delta = (
                rewards +
                (self.discount**self.n_step) * qt_sp_ap_values) - q_s_a_values

            if td_delta.ndim > 1:
                if self.importance_weights:
                    td_delta = np.mean(td_delta * importance_weights,
                                       axis=list(
                                           range(1, self.ranks_to_reduce + 1)))

                else:
                    td_delta = np.mean(td_delta,
                                       axis=list(
                                           range(1, self.ranks_to_reduce + 1)))

            return self._apply_huber_loss_if_necessary(td_delta)
        elif get_backend() == "tf":
            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = tf.stop_gradient(qt_values_sp)

            if self.double_q:
                # For double-Q, we no longer use the max(a')Qt(s'a') value.
                # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
                a_primes = tf.argmax(input=q_values_sp, axis=-1)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = tf.one_hot(indices=a_primes,
                                     depth=self.action_space.num_categories)
                qt_sp_ap_values = tf.reduce_sum(input_tensor=(qt_values_sp *
                                                              one_hot),
                                                axis=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                qt_sp_ap_values = tf.reduce_max(input_tensor=qt_values_sp,
                                                axis=-1)

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = tf.expand_dims(rewards, axis=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
            qt_sp_ap_values = tf.where(condition=terminals,
                                       x=tf.zeros_like(qt_sp_ap_values),
                                       y=qt_sp_ap_values)

            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = tf.one_hot(indices=actions,
                                 depth=self.action_space.num_categories)
            q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot),
                                         axis=-1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = (
                rewards +
                (self.discount**self.n_step) * qt_sp_ap_values) - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = tf.reduce_mean(input_tensor=td_delta,
                                          axis=list(
                                              range(1,
                                                    self.ranks_to_reduce + 1)))

            # Apply importance-weights from a prioritized replay to the loss.
            if self.importance_weights:
                return importance_weights * self._apply_huber_loss_if_necessary(
                    td_delta)
            else:
                return self._apply_huber_loss_if_necessary(td_delta)
        elif get_backend() == "pytorch":
            if not isinstance(terminals, torch.ByteTensor):
                terminals = terminals.byte()
            # Add batch dim in case of single sample.
            if q_values_s.dim() == 1:
                q_values_s = q_values_s.unsqueeze(-1)
                actions = actions.unsqueeze(-1)
                rewards = rewards.unsqueeze(-1)
                terminals = terminals.unsqueeze(-1)
                q_values_sp = q_values_sp.unsqueeze(-1)
                qt_values_sp = qt_values_sp.unsqueeze(-1)
                if self.importance_weights:
                    importance_weights = importance_weights.unsqueeze(-1)

            # Make sure the target policy's outputs are treated as constant when calculating gradients.
            qt_values_sp = qt_values_sp.detach()
            if self.double_q:
                # For double-Q, we no longer use the max(a')Qt(s'a') value.
                # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
                a_primes = torch.argmax(q_values_sp, dim=-1, keepdim=True)

                # Now lookup Q(s'a') with the calculated a'.
                one_hot = pytorch_one_hot(
                    a_primes, depth=self.action_space.num_categories)
                qt_sp_ap_values = torch.sum(qt_values_sp * one_hot, dim=-1)
            else:
                # Qt(s',a') -> Use the max(a') value (from the target network).
                qt_sp_ap_values = torch.max(qt_values_sp, -1)[0]

            # Make sure the rewards vector (batch) is broadcast correctly.
            for _ in range(get_rank(qt_sp_ap_values) - 1):
                rewards = torch.unsqueeze(rewards, dim=1)

            # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
            # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
            # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
            qt_sp_ap_values = torch.where(terminals,
                                          torch.zeros_like(qt_sp_ap_values),
                                          qt_sp_ap_values)
            # Q(s,a) -> Use the Q-value of the action actually taken before.
            one_hot = pytorch_one_hot(actions,
                                      depth=self.action_space.num_categories)
            q_s_a_values = torch.sum((q_values_s * one_hot), -1)

            # Calculate the TD-delta (target - current estimate).
            td_delta = (
                rewards +
                (self.discount**self.n_step) * qt_sp_ap_values) - q_s_a_values

            # Reduce over the composite actions, if any.
            if get_rank(td_delta) > 1:
                td_delta = pytorch_reduce_mean(
                    td_delta,
                    list(range(1, self.ranks_to_reduce + 1)),
                    keepdims=False)

            # Apply importance-weights from a prioritized replay to the loss.
            if self.importance_weights:
                return importance_weights * self._apply_huber_loss_if_necessary(
                    td_delta)
            else:
                return self._apply_huber_loss_if_necessary(td_delta)
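
`_apply_huber_loss_if_necessary` itself is not shown in these examples; when Huber clipping is enabled it presumably maps the raw TD-delta through the standard Huber function before the per-item losses are averaged. A self-contained sketch of that function (not RLGraph's actual helper):

import numpy as np

def huber_loss(td_delta, delta=1.0):
    # Quadratic inside [-delta, delta], linear outside: bounds the gradient for large TD-errors.
    abs_delta = np.abs(td_delta)
    quadratic = np.minimum(abs_delta, delta)
    return 0.5 * quadratic ** 2 + delta * (abs_delta - quadratic)

print(huber_loss(np.array([-2.0, 0.5, 3.0])))  # [1.5   0.125 2.5  ]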