def _graph_fn_loss_per_item(self, key, td_targets, q_values_s, actions, importance_weights=None):
    """
    Args:
        td_targets (SingleDataOp): The already calculated TD-target terms (r + gamma maxa'Qt(s',a')
            OR for double Q: r + gamma Qt(s',argmaxa'(Q(s',a')))).
        q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
            returns when in s and taking different actions a.
        actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).
        importance_weights (Optional[SingleDataOp]): If 'self.importance_weights' is True: The batch of
            weights to apply to the losses.

    Returns:
        SingleDataOp: The loss values vector (one single value for each batch item).
    """
    # Numpy backend primarily for testing purposes.
    if self.backend == "python" or get_backend() == "python":
        from rlgraph.utils.numpy import one_hot
        actions_one_hot = one_hot(actions, depth=self.flat_action_space[key].num_categories)
        q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)
        td_delta = td_targets - q_s_a_values

        if td_delta.ndim > 1:
            if self.importance_weights:
                td_delta = np.mean(
                    td_delta * importance_weights, axis=list(range(1, self.ranks_to_reduce + 1))
                )
            else:
                td_delta = np.mean(td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

    elif get_backend() == "tf":
        # Q(s,a) -> Use the Q-value of the action actually taken before.
        one_hot = tf.one_hot(indices=actions, depth=self.flat_action_space[key].num_categories)
        q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot), axis=-1)

        # Calculate the TD-delta (target - current estimate).
        td_delta = td_targets - q_s_a_values

        # Reduce over the composite actions, if any.
        if get_rank(td_delta) > 1:
            td_delta = tf.reduce_mean(input_tensor=td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

    elif get_backend() == "pytorch":
        # Add batch dim in case of single sample.
        if q_values_s.dim() == 1:
            q_values_s = q_values_s.unsqueeze(-1)
            actions = actions.unsqueeze(-1)
            if self.importance_weights:
                importance_weights = importance_weights.unsqueeze(-1)

        # Q(s,a) -> Use the Q-value of the action actually taken before.
        one_hot = pytorch_one_hot(actions, depth=self.flat_action_space[key].num_categories)
        q_s_a_values = torch.sum((q_values_s * one_hot), -1)

        # Calculate the TD-delta (target - current estimate).
        td_delta = td_targets - q_s_a_values

        # Reduce over the composite actions, if any.
        if get_rank(td_delta) > 1:
            td_delta = torch.mean(td_delta, tuple(range(1, self.ranks_to_reduce + 1)), keepdim=False)

    # Apply importance-weights from a prioritized replay to the loss.
    if self.importance_weights:
        return importance_weights * td_delta
    else:
        return td_delta
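
# Hedged usage sketch (not part of the component above): a standalone numpy illustration of the
# per-item loss math, i.e. picking Q(s,a) for the taken actions via a one-hot mask, subtracting it
# from the pre-computed TD-targets, and applying importance weights. The function name and all
# values below are hypothetical, for illustration only.
def _demo_loss_per_item_sketch():
    import numpy as np

    num_actions = 3
    q_values_s = np.array([[1.0, 2.0, 0.5],
                           [0.0, -1.0, 4.0]])                     # Q(s, .) for a batch of 2.
    actions = np.array([1, 2])                                    # Actions actually taken.
    td_targets = np.array([2.5, 3.0])                             # Pre-computed r + gamma^n * Qt(s', a').
    importance_weights = np.array([0.8, 1.2])                     # From a prioritized replay (optional).

    actions_one_hot = np.eye(num_actions)[actions]                # Shape (2, 3).
    q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)  # -> [2.0, 4.0]
    td_delta = td_targets - q_s_a_values                          # -> [0.5, -1.0]
    return importance_weights * td_delta                          # -> [0.4, -1.2]
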
def _graph_fn_get_td_targets(self, key, rewards, terminals, qt_values_sp, q_values_sp=None):
    """
    Args:
        rewards (SingleDataOp): The batch of rewards that we received after having taken a in s (from a memory).
        terminals (SingleDataOp): The batch of terminal signals that we received after having taken a in s
            (from a memory).
        qt_values_sp (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
            returns (estimated by the target net) when in s' and taking different actions a'.
        q_values_sp (Optional[SingleDataOp]): If `self.double_q` is True: The batch of Q-values representing
            the expected accumulated discounted returns (estimated by the (main) policy net) when in s' and
            taking different actions a'.

    Returns:
        SingleDataOp: The target values vector.
    """
    qt_sp_ap_values = None

    # Numpy backend primarily for testing purposes.
    if self.backend == "python" or get_backend() == "python":
        from rlgraph.utils.numpy import one_hot
        if self.double_q:
            a_primes = np.argmax(q_values_sp, axis=-1)
            a_primes_one_hot = one_hot(a_primes, depth=self.flat_action_space[key].num_categories)
            qt_sp_ap_values = np.sum(qt_values_sp * a_primes_one_hot, axis=-1)
        else:
            qt_sp_ap_values = np.max(qt_values_sp, axis=-1)

        for _ in range(qt_sp_ap_values.ndim - 1):
            rewards = np.expand_dims(rewards, axis=1)

        qt_sp_ap_values = np.where(terminals, np.zeros_like(qt_sp_ap_values), qt_sp_ap_values)

    elif get_backend() == "tf":
        # Make sure the target policy's outputs are treated as constant when calculating gradients.
        qt_values_sp = tf.stop_gradient(qt_values_sp)

        if self.double_q:
            # For double-Q, we no longer use the max(a')Qt(s'a') value.
            # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
            a_primes = tf.argmax(input=q_values_sp, axis=-1)

            # Now lookup Qt(s'a') with the calculated a'.
            one_hot = tf.one_hot(indices=a_primes, depth=self.flat_action_space[key].num_categories)
            qt_sp_ap_values = tf.reduce_sum(input_tensor=(qt_values_sp * one_hot), axis=-1)
        else:
            # Qt(s',a') -> Use the max(a') value (from the target network).
            qt_sp_ap_values = tf.reduce_max(input_tensor=qt_values_sp, axis=-1)

        # Make sure the rewards vector (batch) is broadcast correctly.
        for _ in range(get_rank(qt_sp_ap_values) - 1):
            rewards = tf.expand_dims(rewards, axis=1)

        # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
        # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
        # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
        qt_sp_ap_values = tf.where(condition=terminals, x=tf.zeros_like(qt_sp_ap_values), y=qt_sp_ap_values)

    elif get_backend() == "pytorch":
        if not isinstance(terminals, torch.ByteTensor):
            terminals = terminals.byte()
        # Add batch dim in case of single sample.
        if qt_values_sp.dim() == 1:
            rewards = rewards.unsqueeze(-1)
            terminals = terminals.unsqueeze(-1)
            q_values_sp = q_values_sp.unsqueeze(-1)
            qt_values_sp = qt_values_sp.unsqueeze(-1)

        # Make sure the target policy's outputs are treated as constant when calculating gradients.
        qt_values_sp = qt_values_sp.detach()

        if self.double_q:
            # For double-Q, we no longer use the max(a')Qt(s'a') value.
            # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
            a_primes = torch.argmax(q_values_sp, dim=-1, keepdim=True)

            # Now lookup Qt(s'a') with the calculated a'.
            one_hot = pytorch_one_hot(a_primes, depth=self.flat_action_space[key].num_categories)
            qt_sp_ap_values = torch.sum(qt_values_sp * one_hot.squeeze(), dim=-1)
        else:
            # Qt(s',a') -> Use the max(a') value (from the target network).
            qt_sp_ap_values = torch.max(qt_values_sp, -1)[0]

        # Make sure the rewards vector (batch) is broadcast correctly.
        for _ in range(get_rank(qt_sp_ap_values) - 1):
            rewards = torch.unsqueeze(rewards, dim=1)

        # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
        # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
        # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
        # torch.where cannot broadcast here, so tile and reshape to the same shape.
        if qt_sp_ap_values.dim() > 1:
            num_tiles = np.prod(qt_sp_ap_values.shape[1:])
            terminals = pytorch_tile(terminals, num_tiles, -1).reshape(qt_sp_ap_values.shape)
        qt_sp_ap_values = torch.where(terminals, torch.zeros_like(qt_sp_ap_values), qt_sp_ap_values)

    td_targets = (rewards + (self.discount ** self.n_step) * qt_sp_ap_values)
    return td_targets
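
# Hedged numpy sketch of the TD-target math above, outside the RLGraph graph-fn machinery. It shows
# the double-Q target r + gamma^n * Qt(s', argmax_a' Q(s', a')) with terminal states bootstrapped
# with 0.0. The function name and all values are hypothetical, for illustration only.
def _demo_double_q_td_targets_sketch():
    import numpy as np

    discount, n_step = 0.99, 1
    rewards = np.array([1.0, 0.0])
    terminals = np.array([False, True])
    q_values_sp = np.array([[0.2, 1.5, 0.1],              # Main-net Q(s', .): selects a'.
                            [2.0, 0.3, 0.4]])
    qt_values_sp = np.array([[0.1, 1.0, 3.0],             # Target-net Qt(s', .): evaluates a'.
                             [1.8, 0.2, 0.5]])

    a_primes = np.argmax(q_values_sp, axis=-1)            # -> [1, 0]
    qt_sp_ap_values = np.take_along_axis(qt_values_sp, a_primes[:, None], axis=-1).squeeze(-1)  # -> [1.0, 1.8]
    qt_sp_ap_values = np.where(terminals, 0.0, qt_sp_ap_values)   # Terminal s' -> 0.0.
    return rewards + (discount ** n_step) * qt_sp_ap_values       # -> [1.99, 0.0]
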
def _graph_fn_loss_per_item(self, key, td_targets, q_values_s, actions, expert_margins,
                            importance_weights=None, apply_demo_loss=False):
    """
    Args:
        td_targets (SingleDataOp): The already calculated TD-target terms (r + gamma maxa'Qt(s',a')
            OR for double Q: r + gamma Qt(s',argmaxa'(Q(s',a')))).
        q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
            returns when in s and taking different actions a.
        actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).
        expert_margins (SingleDataOp): The expert margin enforces a distance in Q-values between the expert
            action and all other actions.
        importance_weights (Optional[SingleDataOp]): If 'self.importance_weights' is True: The batch of
            weights to apply to the losses.
        apply_demo_loss (Optional[SingleDataOp]): If True, the large-margin loss is applied. Should be set
            to True when updating from demo data, False when updating from online data.

    Returns:
        SingleDataOp: The loss values vector (one single value for each batch item).
    """
    if get_backend() == "tf":
        # Q(s,a) -> Use the Q-value of the action actually taken before.
        one_hot = tf.one_hot(indices=actions, depth=self.flat_action_space[key].num_categories)
        q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot), axis=-1)

        # Calculate the TD-delta (targets - current estimate).
        td_delta = td_targets - q_s_a_values

        # Calculate the demo-loss:
        # J_E(Q) = max_a[Q(s, a) + l(a_expert, a)] - Q(s, a_expert)
        mask = tf.ones_like(tensor=one_hot, dtype=tf.float32)
        action_mask = mask - one_hot

        # Margin mask: allow custom per-sample expert margins -> requires creating a margin matrix.
        # Instead of applying the same margin to all samples, users can pass a margin vector.
        # Broadcast to the one-hot shape.
        expert_margins = tf.expand_dims(expert_margins, -1)
        expert_margins = tf.broadcast_to(input=expert_margins, shape=tf.shape(one_hot))
        margin_mask = expert_margins - one_hot
        margin_val = action_mask * margin_mask
        loss_input = q_values_s + margin_val

        # Apply margin.
        def map_margins(x):
            element_margin = x[0]
            element_loss = x[1]
            # Positive margins: apply max.
            # Negative margins: apply min.
            return tf.cond(
                pred=tf.reduce_sum(element_margin) > 0,
                true_fn=lambda: tf.reduce_max(element_loss),
                false_fn=lambda: tf.reduce_min(element_loss),
            )
        supervised_loss = tf.map_fn(map_margins, (margin_val, loss_input), dtype=tf.float32)

        # Subtract Q-values of the action actually taken.
        supervised_delta = supervised_loss - q_s_a_values
        td_delta = tf.cond(
            pred=apply_demo_loss,
            true_fn=lambda: td_delta + self.supervised_weight * supervised_delta,
            false_fn=lambda: td_delta
        )

        # Reduce over the composite actions, if any.
        if get_rank(td_delta) > 1:
            td_delta = tf.reduce_mean(input_tensor=td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

        # Apply importance-weights from a prioritized replay to the loss.
        if self.importance_weights:
            return importance_weights * td_delta
        else:
            return td_delta
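
# Hedged numpy sketch of the large-margin (demo) term used above:
#     J_E(Q) = max_a[Q(s, a) + l(a_expert, a)] - Q(s, a_expert),
# where l(a_expert, a) equals the expert margin for a != a_expert and 0 for a == a_expert.
# The function name and all values are hypothetical, for illustration only.
def _demo_large_margin_loss_sketch():
    import numpy as np

    q_values_s = np.array([[1.0, 2.0, 0.5]])               # Q(s, .) for one sample.
    expert_action = np.array([0])                           # a_expert.
    expert_margin = 0.8

    one_hot = np.eye(q_values_s.shape[-1])[expert_action]
    margins = expert_margin * (1.0 - one_hot)               # 0 at the expert action, margin elsewhere.
    q_s_a_expert = np.sum(q_values_s * one_hot, axis=-1)    # -> [1.0]
    return np.max(q_values_s + margins, axis=-1) - q_s_a_expert  # -> [1.8]
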
def _graph_fn_apply(self, preprocessing_inputs):
    """
    Gray-scales images of arbitrary rank.
    Normally, the images' rank is 3 (width/height/colors), but it can also be batch/width/height/colors,
    or any other. However, the last rank must be of size len(self.weights).

    Args:
        preprocessing_inputs (tensor): Single image or a batch of images to be gray-scaled
            (last rank=n colors, where n=len(self.weights)).

    Returns:
        DataOp: The op for processing the images.
    """
    if isinstance(preprocessing_inputs, list):
        preprocessing_inputs = np.asarray(preprocessing_inputs)
    images_shape = get_shape(preprocessing_inputs)
    assert images_shape[-1] == self.last_rank, \
        "ERROR: Given image's shape ({}) does not match number of weights (last rank must be {})!".format(
            images_shape, self.last_rank)

    if self.backend == "python" or get_backend() == "python":
        if preprocessing_inputs.ndim == 4:
            grayscaled = []
            for i in range_(len(preprocessing_inputs)):
                scaled = cv2.cvtColor(preprocessing_inputs[i], cv2.COLOR_RGB2GRAY)
                grayscaled.append(scaled)
            scaled_images = np.asarray(grayscaled)
            # Keep last dim.
            if self.keep_rank:
                scaled_images = scaled_images[:, :, :, np.newaxis]
        else:
            # Sample by sample.
            scaled_images = cv2.cvtColor(preprocessing_inputs, cv2.COLOR_RGB2GRAY)

        return scaled_images

    elif get_backend() == "pytorch":
        if len(preprocessing_inputs.shape) == 4:
            grayscaled = []
            for i in range_(len(preprocessing_inputs)):
                scaled = cv2.cvtColor(preprocessing_inputs[i].numpy(), cv2.COLOR_RGB2GRAY)
                grayscaled.append(scaled)
            scaled_images = np.asarray(grayscaled)
            # Keep last dim.
            if self.keep_rank:
                scaled_images = scaled_images[:, :, :, np.newaxis]
        else:
            # Sample by sample.
            scaled_images = cv2.cvtColor(preprocessing_inputs.numpy(), cv2.COLOR_RGB2GRAY)

        return torch.tensor(scaled_images)

    elif get_backend() == "tf":
        # The reshaped weights used for the grayscale operation.
        weights_reshaped = np.reshape(
            self.weights, newshape=tuple([1] * (get_rank(preprocessing_inputs) - 1)) + (self.last_rank,)
        )

        # Do we need to convert?
        # The dangerous thing is that multiplying an int tensor (image) with float weights results in an
        # all-0 tensor.
        if "int" in str(dtype_(preprocessing_inputs.dtype)):
            weighted = weights_reshaped * tf.cast(preprocessing_inputs, dtype=dtype_("float"))
        else:
            weighted = weights_reshaped * preprocessing_inputs

        reduced = tf.reduce_sum(weighted, axis=-1, keepdims=self.keep_rank)

        # Cast back to original dtype.
        if "int" in str(dtype_(preprocessing_inputs.dtype)):
            reduced = tf.cast(reduced, dtype=preprocessing_inputs.dtype)

        return reduced
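
# Hedged numpy sketch of the weighted-sum grayscaling done in the tf branch above. The luminance
# weights below are the common Rec.601 coefficients and are only an assumption; the actual weights
# come from `self.weights`. The function name and values are hypothetical, for illustration only.
def _demo_grayscale_sketch():
    import numpy as np

    weights = np.array([0.299, 0.587, 0.114])               # Assumed RGB luminance weights.
    images = np.random.randint(0, 256, size=(2, 4, 4, 3), dtype=np.uint8)  # batch/height/width/colors.

    # Cast to float before weighting (multiplying int images by float weights loses information),
    # reduce over the color rank, then cast back to the original dtype.
    weighted = weights.reshape((1, 1, 1, 3)) * images.astype(np.float32)
    grayscaled = np.sum(weighted, axis=-1, keepdims=True)    # keepdims ~ self.keep_rank=True.
    return grayscaled.astype(images.dtype)
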
def _graph_fn_loss_per_item(self, q_values_s, actions, rewards, terminals, qt_values_sp, q_values_sp=None,
                            importance_weights=None):
    """
    Args:
        q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
            returns when in s and taking different actions a.
        actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).
        rewards (SingleDataOp): The batch of rewards that we received after having taken a in s (from a memory).
        terminals (SingleDataOp): The batch of terminal signals that we received after having taken a in s
            (from a memory).
        qt_values_sp (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
            returns (estimated by the target net) when in s' and taking different actions a'.
        q_values_sp (Optional[SingleDataOp]): If `self.double_q` is True: The batch of Q-values representing
            the expected accumulated discounted returns (estimated by the (main) policy net) when in s' and
            taking different actions a'.
        importance_weights (Optional[SingleDataOp]): If 'self.importance_weights' is True: The batch of
            weights to apply to the losses.

    Returns:
        SingleDataOp: The loss values vector (one single value for each batch item).
    """
    # Numpy backend primarily for testing purposes.
    if self.backend == "python" or get_backend() == "python":
        from rlgraph.utils.numpy import one_hot
        if self.double_q:
            a_primes = np.argmax(q_values_sp, axis=-1)
            a_primes_one_hot = one_hot(a_primes, depth=self.action_space.num_categories)
            qt_sp_ap_values = np.sum(qt_values_sp * a_primes_one_hot, axis=-1)
        else:
            qt_sp_ap_values = np.max(qt_values_sp, axis=-1)

        for _ in range(qt_sp_ap_values.ndim - 1):
            rewards = np.expand_dims(rewards, axis=1)

        qt_sp_ap_values = np.where(terminals, np.zeros_like(qt_sp_ap_values), qt_sp_ap_values)

        actions_one_hot = one_hot(actions, depth=self.action_space.num_categories)
        q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)

        td_delta = (rewards + (self.discount ** self.n_step) * qt_sp_ap_values) - q_s_a_values

        if td_delta.ndim > 1:
            if self.importance_weights:
                td_delta = np.mean(
                    td_delta * importance_weights, axis=list(range(1, self.ranks_to_reduce + 1))
                )
            else:
                td_delta = np.mean(td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

        return self._apply_huber_loss_if_necessary(td_delta)

    elif get_backend() == "tf":
        # Make sure the target policy's outputs are treated as constant when calculating gradients.
        qt_values_sp = tf.stop_gradient(qt_values_sp)

        if self.double_q:
            # For double-Q, we no longer use the max(a')Qt(s'a') value.
            # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
            a_primes = tf.argmax(input=q_values_sp, axis=-1)

            # Now lookup Qt(s'a') with the calculated a'.
            one_hot = tf.one_hot(indices=a_primes, depth=self.action_space.num_categories)
            qt_sp_ap_values = tf.reduce_sum(input_tensor=(qt_values_sp * one_hot), axis=-1)
        else:
            # Qt(s',a') -> Use the max(a') value (from the target network).
            qt_sp_ap_values = tf.reduce_max(input_tensor=qt_values_sp, axis=-1)

        # Make sure the rewards vector (batch) is broadcast correctly.
        for _ in range(get_rank(qt_sp_ap_values) - 1):
            rewards = tf.expand_dims(rewards, axis=1)

        # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
        # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
        # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
        qt_sp_ap_values = tf.where(condition=terminals, x=tf.zeros_like(qt_sp_ap_values), y=qt_sp_ap_values)

        # Q(s,a) -> Use the Q-value of the action actually taken before.
        one_hot = tf.one_hot(indices=actions, depth=self.action_space.num_categories)
        q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot), axis=-1)

        # Calculate the TD-delta (target - current estimate).
        td_delta = (rewards + (self.discount ** self.n_step) * qt_sp_ap_values) - q_s_a_values

        # Reduce over the composite actions, if any.
        if get_rank(td_delta) > 1:
            td_delta = tf.reduce_mean(input_tensor=td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

        # Apply importance-weights from a prioritized replay to the loss.
        if self.importance_weights:
            return importance_weights * self._apply_huber_loss_if_necessary(td_delta)
        else:
            return self._apply_huber_loss_if_necessary(td_delta)

    elif get_backend() == "pytorch":
        if not isinstance(terminals, torch.ByteTensor):
            terminals = terminals.byte()
        # Add batch dim in case of single sample.
        if q_values_s.dim() == 1:
            q_values_s = q_values_s.unsqueeze(-1)
            actions = actions.unsqueeze(-1)
            rewards = rewards.unsqueeze(-1)
            terminals = terminals.unsqueeze(-1)
            q_values_sp = q_values_sp.unsqueeze(-1)
            qt_values_sp = qt_values_sp.unsqueeze(-1)
            if self.importance_weights:
                importance_weights = importance_weights.unsqueeze(-1)

        # Make sure the target policy's outputs are treated as constant when calculating gradients.
        qt_values_sp = qt_values_sp.detach()

        if self.double_q:
            # For double-Q, we no longer use the max(a')Qt(s'a') value.
            # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
            a_primes = torch.argmax(q_values_sp, dim=-1, keepdim=True)

            # Now lookup Qt(s'a') with the calculated a'.
            one_hot = pytorch_one_hot(a_primes, depth=self.action_space.num_categories)
            qt_sp_ap_values = torch.sum(qt_values_sp * one_hot, dim=-1)
        else:
            # Qt(s',a') -> Use the max(a') value (from the target network) along the action axis.
            qt_sp_ap_values = torch.max(qt_values_sp, -1)[0]

        # Make sure the rewards vector (batch) is broadcast correctly.
        for _ in range(get_rank(qt_sp_ap_values) - 1):
            rewards = torch.unsqueeze(rewards, dim=1)

        # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
        # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
        # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
        qt_sp_ap_values = torch.where(terminals, torch.zeros_like(qt_sp_ap_values), qt_sp_ap_values)

        # Q(s,a) -> Use the Q-value of the action actually taken before.
        one_hot = pytorch_one_hot(actions, depth=self.action_space.num_categories)
        q_s_a_values = torch.sum((q_values_s * one_hot), -1)

        # Calculate the TD-delta (target - current estimate).
        td_delta = (rewards + (self.discount ** self.n_step) * qt_sp_ap_values) - q_s_a_values

        # Reduce over the composite actions, if any.
        if get_rank(td_delta) > 1:
            td_delta = pytorch_reduce_mean(td_delta, list(range(1, self.ranks_to_reduce + 1)), keepdims=False)

        # Apply importance-weights from a prioritized replay to the loss.
        if self.importance_weights:
            return importance_weights * self._apply_huber_loss_if_necessary(td_delta)
        else:
            return self._apply_huber_loss_if_necessary(td_delta)
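
# Hedged numpy sketch of a Huber loss like the one `_apply_huber_loss_if_necessary` presumably
# applies to the TD-delta (quadratic near zero, linear beyond `delta`). The exact behavior of that
# helper is not shown here, so treat this as an assumption; the function name, default `delta`, and
# values are illustrative only.
def _demo_huber_loss_sketch(td_delta=None, delta=1.0):
    import numpy as np

    if td_delta is None:
        td_delta = np.array([0.3, -2.0, 1.5])
    abs_delta = np.abs(td_delta)
    quadratic = 0.5 * np.square(td_delta)                   # Used where |td_delta| <= delta.
    linear = delta * (abs_delta - 0.5 * delta)              # Used where |td_delta| > delta.
    return np.where(abs_delta <= delta, quadratic, linear)  # -> [0.045, 1.5, 1.0]
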