def test_one_hot(self):
    """
    Tests the torch one-hot function.
    """
    if get_backend() == "pytorch":
        # Flat action array.
        inputs = torch.tensor([0, 1], dtype=torch.int32)
        one_hot = pytorch_one_hot(inputs, depth=2)
        expected = torch.tensor([[1., 0.], [0., 1.]])
        recursive_assert_almost_equal(one_hot, expected)

        # Container space.
        inputs = torch.tensor([[0, 3, 2], [1, 2, 0]], dtype=torch.int32)
        one_hot = pytorch_one_hot(inputs, depth=4)
        expected = torch.tensor(
            [[[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
             [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]],
            dtype=torch.int32
        )
        recursive_assert_almost_equal(one_hot, expected)
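# For reference, a minimal sketch of a one-hot helper with the semantics exercised by the test above
# (a new trailing axis of size `depth`). This is an illustrative, hypothetical helper, not the actual
# `pytorch_one_hot` from rlgraph.utils; the real implementation may differ.
import torch

def one_hot_sketch(tensor, depth):
    """Appends a `depth`-sized one-hot axis at the end of `tensor` (illustrative sketch only)."""
    # scatter_ writes 1.0 at the index positions along the new last axis.
    out = torch.zeros(*tensor.shape, depth)
    return out.scatter_(-1, tensor.long().unsqueeze(-1), 1.0)

# one_hot_sketch(torch.tensor([0, 1]), depth=2)  ->  [[1., 0.], [0., 1.]]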
def _graph_fn_apply(self, key, preprocessing_inputs, input_before_time_rank_folding=None):
    """
    Reshapes the input to the specified new shape.

    Args:
        preprocessing_inputs (SingleDataOp): The input to reshape.
        input_before_time_rank_folding (Optional[SingleDataOp]): The original input (before!) the time-rank had
            been folded (this was done in a different ReShape Component). Serves if `self.unfold_time_rank` is True
            to figure out the exact time-rank dimension to unfold.

    Returns:
        SingleDataOp: The reshaped input.
    """
    assert self.unfold_time_rank is False or input_before_time_rank_folding is not None

    #preprocessing_inputs = tf.Print(preprocessing_inputs, [tf.shape(preprocessing_inputs)], summarize=1000,
    #                                message="input shape for {} (key={}): {}".format(preprocessing_inputs.name, key, self.scope))

    if self.backend == "python" or get_backend() == "python":
        # Create a one-hot axis for the categories at the end?
        if self.num_categories.get(key, 0) > 1:
            preprocessing_inputs = one_hot(preprocessing_inputs, depth=self.num_categories[key])

        new_shape = self.output_spaces[key].get_shape(
            with_batch_rank=-1, with_time_rank=-1, time_major=self.time_major
        )
        # Dynamic new shape inference:
        # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
        # batch+time 0th rank, get these two dynamically.
        # Note: We may still flip the two, if input space has a different `time_major` than output space.
        if len(new_shape) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
            # Time rank unfolding. Get the time rank from original input.
            if self.unfold_time_rank is True:
                original_shape = input_before_time_rank_folding.shape
                new_shape = (original_shape[0], original_shape[1]) + new_shape[2:]
            # No time-rank unfolding, but we do have both batch- and time-rank.
            else:
                input_shape = preprocessing_inputs.shape
                # Batch and time rank stay as is.
                if self.time_major is None or self.time_major is self.in_space_time_majors[key]:
                    new_shape = (input_shape[0], input_shape[1]) + new_shape[2:]
                # Batch and time rank need to be flipped around: Do a transpose.
                else:
                    # Swap the first two axes, keep all remaining axes in place.
                    preprocessing_inputs = np.transpose(
                        preprocessing_inputs, axes=(1, 0) + tuple(range(2, len(input_shape)))
                    )
                    new_shape = (input_shape[1], input_shape[0]) + new_shape[2:]

        return np.reshape(preprocessing_inputs, newshape=new_shape)

    elif get_backend() == "pytorch":
        # Create a one-hot axis for the categories at the end?
        if self.num_categories.get(key, 0) > 1:
            preprocessing_inputs = pytorch_one_hot(preprocessing_inputs, depth=self.num_categories[key])

        new_shape = self.output_spaces[key].get_shape(
            with_batch_rank=-1, with_time_rank=-1, time_major=self.time_major
        )
        # Dynamic new shape inference:
        # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
        # batch+time 0th rank, get these two dynamically.
        # Note: We may still flip the two, if input space has a different `time_major` than output space.
        if len(new_shape) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
            # Time rank unfolding. Get the time rank from original input.
            if self.unfold_time_rank is True:
                original_shape = input_before_time_rank_folding.shape
                new_shape = (original_shape[0], original_shape[1]) + new_shape[2:]
            # No time-rank unfolding, but we do have both batch- and time-rank.
            else:
                input_shape = preprocessing_inputs.shape
                # Batch and time rank stay as is.
                if self.time_major is None or self.time_major is self.in_space_time_majors[key]:
                    new_shape = (input_shape[0], input_shape[1]) + new_shape[2:]
                # Batch and time rank need to be flipped around: Do a transpose (permute swaps the
                # first two axes; torch.transpose only accepts two dim arguments, not a permutation).
                else:
                    preprocessing_inputs = preprocessing_inputs.permute(
                        (1, 0) + tuple(range(2, len(input_shape)))
                    )
                    new_shape = (input_shape[1], input_shape[0]) + new_shape[2:]

        # print("Reshaping input of shape {} to new shape {} ".format(preprocessing_inputs.shape, new_shape))

        # The problem here is the following: Input has dim e.g. [4, 256, 1, 1]
        # -> If shape inference in spaces failed, output dim is not correct -> reshape will attempt
        # something like reshaping to [256].
        if self.flatten or (preprocessing_inputs.size(0) > 1 and preprocessing_inputs.dim() > 1):
            return preprocessing_inputs.squeeze()
        else:
            return torch.reshape(preprocessing_inputs, new_shape)

    elif get_backend() == "tf":
        # Create a one-hot axis for the categories at the end?
        if self.num_categories.get(key, 0) > 1:
            preprocessing_inputs = tf.one_hot(
                preprocessing_inputs, depth=self.num_categories[key], axis=-1, dtype="float32"
            )

        new_shape = self.output_spaces[key].get_shape(
            with_batch_rank=-1, with_time_rank=-1, time_major=self.time_major
        )

        # Dynamic new shape inference:
        # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
        # batch+time 0th rank, get these two dynamically.
        # Note: We may still flip the two, if input space has a different `time_major` than output space.
        flip_after_reshape = False
        if len(new_shape) >= 2 and new_shape[0] == -1 and new_shape[1] == -1:
            # Time rank unfolding. Get the time rank from original input (and maybe flip).
            if self.unfold_time_rank is True:
                original_shape = tf.shape(input_before_time_rank_folding)
                new_shape = (original_shape[0], original_shape[1]) + new_shape[2:]
                flip_after_reshape = self.flip_batch_and_time_rank
            # No time-rank unfolding, but we do have both batch- and time-rank.
            else:
                input_shape = tf.shape(preprocessing_inputs)
                # Batch and time rank stay as is.
                if self.time_major is None or self.time_major is self.in_space_time_majors[key]:
                    new_shape = (input_shape[0], input_shape[1]) + new_shape[2:]
                # Batch and time rank need to be flipped around: Do a transpose.
                else:
                    assert self.flip_batch_and_time_rank is True
                    preprocessing_inputs = tf.transpose(
                        preprocessing_inputs,
                        perm=(1, 0) + tuple(i for i in range(2, input_shape.shape.as_list()[0])),
                        name="transpose-flip-batch-time-ranks"
                    )
                    new_shape = (input_shape[1], input_shape[0]) + new_shape[2:]

        reshaped = tf.reshape(tensor=preprocessing_inputs, shape=new_shape, name="reshaped")

        if flip_after_reshape and self.flip_batch_and_time_rank:
            reshaped = tf.transpose(reshaped, (1, 0) + tuple(i for i in range(2, len(new_shape))),
                                    name="transpose-flip-batch-time-ranks-after-reshape")

        #reshaped = tf.Print(reshaped, [tf.shape(reshaped)], summarize=1000,
        #                    message="output shape for {} (key={}): {}".format(reshaped, key, self.scope))

        # Have to place the time rank back in as unknown (for the auto Space inference).
        if type(self.unfold_time_rank) == int:
            # TODO: replace placeholder with default value by _batch_rank/_time_rank properties.
            return tf.placeholder_with_default(reshaped, shape=(None, None) + new_shape[2:])
        else:
            # TODO: add other cases of reshaping and fix batch/time rank hints.
            if self.fold_time_rank:
                reshaped._batch_rank = 0
            elif self.unfold_time_rank or self.flip_batch_and_time_rank:
                reshaped._batch_rank = 0 if self.time_major is False else 1
                reshaped._time_rank = 0 if self.time_major is True else 1

            return reshaped
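# The comments above talk about folding/unfolding the time rank and flipping batch/time ranks.
# A small, self-contained numpy sketch of those two reshapes and the transpose (illustrative names,
# not rlgraph API); the pre-fold input is what `input_before_time_rank_folding` provides above.
import numpy as np

# Toy input: batch=2, time=3, feature=4 (batch-major).
x = np.zeros((2, 3, 4))

# Fold: merge batch and time into one 0th rank -> (batch*time, feature).
folded = np.reshape(x, (-1,) + x.shape[2:])                         # shape (6, 4)

# Unfold: restore (batch, time) from the shape of the original, pre-fold input.
unfolded = np.reshape(folded, x.shape[:2] + folded.shape[1:])       # shape (2, 3, 4)

# Flip batch and time ranks (batch-major <-> time-major) via a transpose of the first two axes.
time_major = np.transpose(unfolded, axes=(1, 0) + tuple(range(2, unfolded.ndim)))   # shape (3, 2, 4)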
def _graph_fn_loss_per_item(self, key, td_targets, q_values_s, actions, importance_weights=None):
    """
    Args:
        td_targets (SingleDataOp): The already calculated TD-target terms
            (r + gamma * max_a' Qt(s',a'), or for double-Q: r + gamma * Qt(s', argmax_a' Q(s',a'))).
        q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
            returns when in s and taking different actions a.
        actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).
        importance_weights (Optional[SingleDataOp]): If `self.importance_weights` is True: The batch of weights
            to apply to the losses.

    Returns:
        SingleDataOp: The loss values vector (one single value for each batch item).
    """
    # Numpy backend primarily for testing purposes.
    if self.backend == "python" or get_backend() == "python":
        from rlgraph.utils.numpy import one_hot

        actions_one_hot = one_hot(actions, depth=self.flat_action_space[key].num_categories)
        q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)
        td_delta = td_targets - q_s_a_values

        if td_delta.ndim > 1:
            if self.importance_weights:
                td_delta = np.mean(td_delta * importance_weights, axis=list(range(1, self.ranks_to_reduce + 1)))
            else:
                td_delta = np.mean(td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

    elif get_backend() == "tf":
        # Q(s,a) -> Use the Q-value of the action actually taken before.
        one_hot = tf.one_hot(indices=actions, depth=self.flat_action_space[key].num_categories)
        q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot), axis=-1)

        # Calculate the TD-delta (target - current estimate).
        td_delta = td_targets - q_s_a_values

        # Reduce over the composite actions, if any.
        if get_rank(td_delta) > 1:
            td_delta = tf.reduce_mean(input_tensor=td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

    elif get_backend() == "pytorch":
        # Add batch dim in case of single sample.
        if q_values_s.dim() == 1:
            q_values_s = q_values_s.unsqueeze(-1)
            actions = actions.unsqueeze(-1)
            if self.importance_weights:
                importance_weights = importance_weights.unsqueeze(-1)

        # Q(s,a) -> Use the Q-value of the action actually taken before.
        one_hot = pytorch_one_hot(actions, depth=self.flat_action_space[key].num_categories)
        q_s_a_values = torch.sum((q_values_s * one_hot), -1)

        # Calculate the TD-delta (target - current estimate).
        td_delta = td_targets - q_s_a_values

        # Reduce over the composite actions, if any.
        if get_rank(td_delta) > 1:
            td_delta = torch.mean(td_delta, tuple(range(1, self.ranks_to_reduce + 1)), keepdim=False)

    # Apply importance-weights from a prioritized replay to the loss.
    if self.importance_weights:
        return importance_weights * td_delta
    else:
        return td_delta
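# Sanity check of the loss term above (illustrative only, not part of the component): for a flat action
# space, the per-item TD-delta reduces to td_target - Q(s, a_taken), with the taken action selected via
# a one-hot mask exactly as in the python backend.
import numpy as np

q_values_s = np.array([[1.0, 2.0, 0.5],      # Q(s, .) for two batch items, 3 discrete actions
                       [0.0, 1.0, 3.0]])
actions = np.array([1, 2])                   # actions actually taken
td_targets = np.array([2.5, 2.0])

actions_one_hot = np.eye(3)[actions]
q_s_a = np.sum(q_values_s * actions_one_hot, axis=-1)    # -> [2.0, 3.0]
td_delta = td_targets - q_s_a                             # -> [0.5, -1.0]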
def _graph_fn_get_td_targets(self, key, rewards, terminals, qt_values_sp, q_values_sp=None):
    """
    Args:
        rewards (SingleDataOp): The batch of rewards that we received after having taken a in s (from a memory).
        terminals (SingleDataOp): The batch of terminal signals that we received after having taken a in s
            (from a memory).
        qt_values_sp (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
            returns (estimated by the target net) when in s' and taking different actions a'.
        q_values_sp (Optional[SingleDataOp]): If `self.double_q` is True: The batch of Q-values representing the
            expected accumulated discounted returns (estimated by the (main) policy net) when in s' and taking
            different actions a'.

    Returns:
        SingleDataOp: The target values vector.
    """
    qt_sp_ap_values = None

    # Numpy backend primarily for testing purposes.
    if self.backend == "python" or get_backend() == "python":
        from rlgraph.utils.numpy import one_hot

        if self.double_q:
            a_primes = np.argmax(q_values_sp, axis=-1)
            a_primes_one_hot = one_hot(a_primes, depth=self.flat_action_space[key].num_categories)
            qt_sp_ap_values = np.sum(qt_values_sp * a_primes_one_hot, axis=-1)
        else:
            qt_sp_ap_values = np.max(qt_values_sp, axis=-1)

        for _ in range(qt_sp_ap_values.ndim - 1):
            rewards = np.expand_dims(rewards, axis=1)

        qt_sp_ap_values = np.where(terminals, np.zeros_like(qt_sp_ap_values), qt_sp_ap_values)

    elif get_backend() == "tf":
        # Make sure the target policy's outputs are treated as constant when calculating gradients.
        qt_values_sp = tf.stop_gradient(qt_values_sp)

        if self.double_q:
            # For double-Q, we no longer use the max(a')Qt(s'a') value.
            # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
            a_primes = tf.argmax(input=q_values_sp, axis=-1)

            # Now lookup Q(s'a') with the calculated a'.
            one_hot = tf.one_hot(indices=a_primes, depth=self.flat_action_space[key].num_categories)
            qt_sp_ap_values = tf.reduce_sum(input_tensor=(qt_values_sp * one_hot), axis=-1)
        else:
            # Qt(s',a') -> Use the max(a') value (from the target network).
            qt_sp_ap_values = tf.reduce_max(input_tensor=qt_values_sp, axis=-1)

        # Make sure the rewards vector (batch) is broadcast correctly.
        for _ in range(get_rank(qt_sp_ap_values) - 1):
            rewards = tf.expand_dims(rewards, axis=1)

        # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
        # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
        # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
        qt_sp_ap_values = tf.where(condition=terminals, x=tf.zeros_like(qt_sp_ap_values), y=qt_sp_ap_values)

    elif get_backend() == "pytorch":
        if not isinstance(terminals, torch.ByteTensor):
            terminals = terminals.byte()

        # Add batch dim in case of single sample.
        if qt_values_sp.dim() == 1:
            rewards = rewards.unsqueeze(-1)
            terminals = terminals.unsqueeze(-1)
            q_values_sp = q_values_sp.unsqueeze(-1)
            qt_values_sp = qt_values_sp.unsqueeze(-1)

        # Make sure the target policy's outputs are treated as constant when calculating gradients.
        qt_values_sp = qt_values_sp.detach()

        if self.double_q:
            # For double-Q, we no longer use the max(a')Qt(s'a') value.
            # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
            a_primes = torch.argmax(q_values_sp, dim=-1, keepdim=True)

            # Now lookup Q(s'a') with the calculated a'.
            one_hot = pytorch_one_hot(a_primes, depth=self.flat_action_space[key].num_categories)
            qt_sp_ap_values = torch.sum(qt_values_sp * one_hot.squeeze(), dim=-1)
        else:
            # Qt(s',a') -> Use the max(a') value (from the target network).
            qt_sp_ap_values = torch.max(qt_values_sp, -1)[0]

        # Make sure the rewards vector (batch) is broadcast correctly.
        for _ in range(get_rank(qt_sp_ap_values) - 1):
            rewards = torch.unsqueeze(rewards, dim=1)

        # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
        # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
        # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
        # torch.where cannot broadcast here, so tile and reshape to same shape.
        if qt_sp_ap_values.dim() > 1:
            num_tiles = np.prod(qt_sp_ap_values.shape[1:])
            terminals = pytorch_tile(terminals, num_tiles, -1).reshape(qt_sp_ap_values.shape)

        qt_sp_ap_values = torch.where(terminals, torch.zeros_like(qt_sp_ap_values), qt_sp_ap_values)

    td_targets = (rewards + (self.discount ** self.n_step) * qt_sp_ap_values)
    return td_targets
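# A tiny worked example of the TD-target computed above, r + gamma^n * max_a' Qt(s',a'), with the target
# value zeroed out on terminal states. Illustrative numpy only; numbers are made up, not rlgraph code.
import numpy as np

rewards = np.array([1.0, 0.5])
terminals = np.array([False, True])
qt_values_sp = np.array([[0.2, 1.0],          # target-net Q(s', .) for two batch items
                         [3.0, 2.0]])
discount, n_step = 0.99, 1

qt_sp_ap = np.where(terminals, 0.0, np.max(qt_values_sp, axis=-1))    # -> [1.0, 0.0]
td_targets = rewards + (discount ** n_step) * qt_sp_ap                # -> [1.99, 0.5]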
def _graph_fn_apply(self, key, preprocessing_inputs, input_before_time_rank_folding=None):
    """
    Reshapes the input to the specified new shape.

    Args:
        preprocessing_inputs (SingleDataOp): The input to reshape.
        input_before_time_rank_folding (Optional[SingleDataOp]): The original input (before!) the time-rank had
            been folded (this was done in a different ReShape Component). Serves if `self.unfold_time_rank` is True
            to figure out the exact time-rank dimension to unfold.

    Returns:
        SingleDataOp: The reshaped input.
    """
    assert self.unfold_time_rank is False or input_before_time_rank_folding is not None

    if self.backend == "python" or get_backend() == "python":
        # Create a one-hot axis for the categories at the end?
        num_categories = self.get_num_categories(key, get_space_from_op(preprocessing_inputs))
        if num_categories and num_categories > 1:
            preprocessing_inputs = one_hot(preprocessing_inputs, depth=num_categories)

        if self.unfold_time_rank:
            new_shape = (-1, -1) + preprocessing_inputs.shape[1:]
        elif self.fold_time_rank:
            new_shape = (-1,) + preprocessing_inputs.shape[2:]
        else:
            new_shape = self.get_preprocessed_space(get_space_from_op(preprocessing_inputs)).get_shape(
                with_batch_rank=-1, with_time_rank=-1
            )

        # Dynamic new shape inference:
        # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
        # batch+time 0th rank, get these two dynamically.
        if len(preprocessing_inputs.shape) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
            # Time rank unfolding. Get the time rank from original input.
            if self.unfold_time_rank is True:
                original_shape = input_before_time_rank_folding.shape
                new_shape = (original_shape[0], original_shape[1]) + new_shape[2:]
            # No time-rank unfolding, but we do have both batch- and time-rank.
            else:
                input_shape = preprocessing_inputs.shape
                # Batch and time rank stay as is.
                new_shape = (input_shape[0], input_shape[1]) + new_shape[2:]

        return np.reshape(preprocessing_inputs, newshape=new_shape)

    elif get_backend() == "pytorch":
        # Create a one-hot axis for the categories at the end?
        num_categories = self.get_num_categories(key, get_space_from_op(preprocessing_inputs))
        if num_categories and num_categories > 1:
            preprocessing_inputs = pytorch_one_hot(preprocessing_inputs, depth=num_categories)

        if self.unfold_time_rank:
            new_shape = (-1, -1) + preprocessing_inputs.shape[1:]
        elif self.fold_time_rank:
            new_shape = (-1,) + preprocessing_inputs.shape[2:]
        else:
            new_shape = self.get_preprocessed_space(get_space_from_op(preprocessing_inputs)).get_shape(
                with_batch_rank=-1, with_time_rank=-1
            )

        # Dynamic new shape inference:
        # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
        # batch+time 0th rank, get these two dynamically.
        if len(new_shape) > 2 and new_shape[0] == -1 and new_shape[1] == -1:
            # Time rank unfolding. Get the time rank from original input.
            if self.unfold_time_rank is True:
                original_shape = input_before_time_rank_folding.shape
                new_shape = (original_shape[0], original_shape[1]) + new_shape[2:]
            # No time-rank unfolding, but we do have both batch- and time-rank.
            else:
                input_shape = preprocessing_inputs.shape
                # Batch and time rank stay as is.
                new_shape = (input_shape[0], input_shape[1]) + new_shape[2:]

        # print("Reshaping input of shape {} to new shape {} (flatten = {})".format(preprocessing_inputs.shape,
        #                                                                           new_shape, self.flatten))

        old_size = np.prod(list(preprocessing_inputs.shape))
        new_size = np.prod(new_shape)

        # The problem here is the following: Input has dim e.g. [4, 256, 1, 1]
        # -> If shape inference in spaces failed, output dim is not correct -> reshape will attempt
        # something like reshaping to [256].
        if self.flatten and preprocessing_inputs.dim() > 1:
            flattened_shape_without_batchrank = np.prod(preprocessing_inputs.shape[1:])
            flattened_shape = (preprocessing_inputs.shape[0],) + (flattened_shape_without_batchrank,)
            return torch.reshape(preprocessing_inputs, flattened_shape)
        # If new shape does not fit into old shape, batch inference failed -> try to restore:
        # Equal except batch rank -> return as is:
        elif old_size != new_size:
            if tuple(preprocessing_inputs.shape[1:]) == new_shape:
                return preprocessing_inputs
            else:
                # Attempt to rescue reshape by combining new shape with batch dim.
                full_new_shape = (preprocessing_inputs.shape[0],) + new_shape
                return torch.reshape(preprocessing_inputs, full_new_shape)
        else:
            return torch.reshape(preprocessing_inputs, new_shape)

    elif get_backend() == "tf":
        # Create a one-hot axis for the categories at the end?
        space = get_space_from_op(preprocessing_inputs)
        num_categories = self.get_num_categories(key, space)
        if num_categories and num_categories > 1:
            preprocessing_inputs_ = tf.one_hot(
                preprocessing_inputs, depth=num_categories, axis=-1, dtype="float32"
            )
            if hasattr(preprocessing_inputs, "_batch_rank"):
                preprocessing_inputs_._batch_rank = preprocessing_inputs._batch_rank
            if hasattr(preprocessing_inputs, "_time_rank"):
                preprocessing_inputs_._time_rank = preprocessing_inputs._time_rank
            preprocessing_inputs = preprocessing_inputs_

        if self.unfold_time_rank:
            list_shape = preprocessing_inputs.shape.as_list()
            assert len(list_shape) == 1 or list_shape[1] is not None, \
                "ERROR: Cannot unfold. `preprocessing_inputs` (with shape {}) " \
                "already seems to be unfolded!".format(list_shape)
            new_shape = (-1, -1) + tuple(list_shape[1:])
        elif self.fold_time_rank:
            new_shape = (-1,) + tuple(preprocessing_inputs.shape.as_list()[2:])
        else:
            new_shape = self.get_preprocessed_space(get_space_from_op(preprocessing_inputs)).get_shape(
                with_batch_rank=-1, with_time_rank=-1
            )

        # Dynamic new shape inference:
        # If both batch and time rank must be left alone OR the time rank must be unfolded from a currently common
        # batch+time 0th rank, get these two dynamically.
        if len(new_shape) >= 2 and new_shape[0] == -1 and new_shape[1] == -1:
            # Time rank unfolding. Get the time rank from original input.
            if self.unfold_time_rank is True:
                original_shape = tf.shape(input_before_time_rank_folding)
                new_shape = (original_shape[0], original_shape[1]) + new_shape[2:]
            # No time-rank unfolding, but we do have both batch- and time-rank.
            else:
                input_shape = tf.shape(preprocessing_inputs)
                # Batch and time rank stay as is.
                new_shape = (input_shape[0], input_shape[1]) + new_shape[2:]

        reshaped = tf.reshape(tensor=preprocessing_inputs, shape=new_shape, name="reshaped")

        # Have to place the time rank back in as unknown (for the auto Space inference).
        if type(self.unfold_time_rank) == int:
            # TODO: replace placeholder with default value by _batch_rank/_time_rank properties.
            return tf.placeholder_with_default(reshaped, shape=(None, None) + new_shape[2:])
        else:
            # TODO: add other cases of reshaping and fix batch/time rank hints.
            if self.fold_time_rank:
                reshaped._batch_rank = 0
            elif self.unfold_time_rank:
                reshaped._batch_rank = 1 if self.time_major is True else 0
                reshaped._time_rank = 0 if self.time_major is True else 1
            else:
                if space.has_batch_rank is True:
                    if space.time_major is False:
                        reshaped._batch_rank = 0
                    else:
                        reshaped._time_rank = 0
                        reshaped._batch_rank = 1
                if space.has_time_rank is True:
                    reshaped._time_rank = 0 if space.time_major is True else 1

            return reshaped
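# Illustration of the "rescue" path in the pytorch branch above: if space inference dropped the batch
# rank, the inferred new shape has fewer elements than the input, and re-attaching the batch dim makes
# the reshape valid again. Standalone sketch with made-up shapes, not rlgraph code.
import torch
import numpy as np

x = torch.zeros(4, 256, 1, 1)           # e.g. a conv output with a batch rank of 4 (1024 elements)

new_shape = (256,)                       # inferred shape without the batch rank (256 elements)
if x.numel() != int(np.prod(new_shape)):
    # Re-attach the batch dimension, as the `old_size != new_size` branch above does.
    new_shape = (x.shape[0],) + new_shape

print(torch.reshape(x, new_shape).shape)   # torch.Size([4, 256])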
def _graph_fn_loss_per_item(self, q_values_s, actions, rewards, terminals, qt_values_sp, q_values_sp=None,
                            importance_weights=None):
    """
    Args:
        q_values_s (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
            returns when in s and taking different actions a.
        actions (SingleDataOp): The batch of actions that were actually taken in states s (from a memory).
        rewards (SingleDataOp): The batch of rewards that we received after having taken a in s (from a memory).
        terminals (SingleDataOp): The batch of terminal signals that we received after having taken a in s
            (from a memory).
        qt_values_sp (SingleDataOp): The batch of Q-values representing the expected accumulated discounted
            returns (estimated by the target net) when in s' and taking different actions a'.
        q_values_sp (Optional[SingleDataOp]): If `self.double_q` is True: The batch of Q-values representing the
            expected accumulated discounted returns (estimated by the (main) policy net) when in s' and taking
            different actions a'.
        importance_weights (Optional[SingleDataOp]): If `self.importance_weights` is True: The batch of weights
            to apply to the losses.

    Returns:
        SingleDataOp: The loss values vector (one single value for each batch item).
    """
    # Numpy backend primarily for testing purposes.
    if self.backend == "python" or get_backend() == "python":
        from rlgraph.utils.numpy import one_hot

        if self.double_q:
            a_primes = np.argmax(q_values_sp, axis=-1)
            a_primes_one_hot = one_hot(a_primes, depth=self.action_space.num_categories)
            qt_sp_ap_values = np.sum(qt_values_sp * a_primes_one_hot, axis=-1)
        else:
            qt_sp_ap_values = np.max(qt_values_sp, axis=-1)

        for _ in range(qt_sp_ap_values.ndim - 1):
            rewards = np.expand_dims(rewards, axis=1)

        qt_sp_ap_values = np.where(terminals, np.zeros_like(qt_sp_ap_values), qt_sp_ap_values)

        actions_one_hot = one_hot(actions, depth=self.action_space.num_categories)
        q_s_a_values = np.sum(q_values_s * actions_one_hot, axis=-1)

        td_delta = (rewards + (self.discount ** self.n_step) * qt_sp_ap_values) - q_s_a_values

        if td_delta.ndim > 1:
            if self.importance_weights:
                td_delta = np.mean(td_delta * importance_weights, axis=list(range(1, self.ranks_to_reduce + 1)))
            else:
                td_delta = np.mean(td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

        return self._apply_huber_loss_if_necessary(td_delta)

    elif get_backend() == "tf":
        # Make sure the target policy's outputs are treated as constant when calculating gradients.
        qt_values_sp = tf.stop_gradient(qt_values_sp)

        if self.double_q:
            # For double-Q, we no longer use the max(a')Qt(s'a') value.
            # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
            a_primes = tf.argmax(input=q_values_sp, axis=-1)

            # Now lookup Q(s'a') with the calculated a'.
            one_hot = tf.one_hot(indices=a_primes, depth=self.action_space.num_categories)
            qt_sp_ap_values = tf.reduce_sum(input_tensor=(qt_values_sp * one_hot), axis=-1)
        else:
            # Qt(s',a') -> Use the max(a') value (from the target network).
            qt_sp_ap_values = tf.reduce_max(input_tensor=qt_values_sp, axis=-1)

        # Make sure the rewards vector (batch) is broadcast correctly.
        for _ in range(get_rank(qt_sp_ap_values) - 1):
            rewards = tf.expand_dims(rewards, axis=1)

        # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
        # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
        # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
        qt_sp_ap_values = tf.where(condition=terminals, x=tf.zeros_like(qt_sp_ap_values), y=qt_sp_ap_values)

        # Q(s,a) -> Use the Q-value of the action actually taken before.
        one_hot = tf.one_hot(indices=actions, depth=self.action_space.num_categories)
        q_s_a_values = tf.reduce_sum(input_tensor=(q_values_s * one_hot), axis=-1)

        # Calculate the TD-delta (target - current estimate).
        td_delta = (rewards + (self.discount ** self.n_step) * qt_sp_ap_values) - q_s_a_values

        # Reduce over the composite actions, if any.
        if get_rank(td_delta) > 1:
            td_delta = tf.reduce_mean(input_tensor=td_delta, axis=list(range(1, self.ranks_to_reduce + 1)))

        # Apply importance-weights from a prioritized replay to the loss.
        if self.importance_weights:
            return importance_weights * self._apply_huber_loss_if_necessary(td_delta)
        else:
            return self._apply_huber_loss_if_necessary(td_delta)

    elif get_backend() == "pytorch":
        if not isinstance(terminals, torch.ByteTensor):
            terminals = terminals.byte()

        # Add batch dim in case of single sample.
        if q_values_s.dim() == 1:
            q_values_s = q_values_s.unsqueeze(-1)
            actions = actions.unsqueeze(-1)
            rewards = rewards.unsqueeze(-1)
            terminals = terminals.unsqueeze(-1)
            q_values_sp = q_values_sp.unsqueeze(-1)
            qt_values_sp = qt_values_sp.unsqueeze(-1)
            if self.importance_weights:
                importance_weights = importance_weights.unsqueeze(-1)

        # Make sure the target policy's outputs are treated as constant when calculating gradients.
        qt_values_sp = qt_values_sp.detach()

        if self.double_q:
            # For double-Q, we no longer use the max(a')Qt(s'a') value.
            # Instead, the a' used to get the Qt(s'a') is given by argmax(a') Q(s',a') <- Q=q-net, not target net!
            a_primes = torch.argmax(q_values_sp, dim=-1, keepdim=True)

            # Now lookup Q(s'a') with the calculated a'.
            one_hot = pytorch_one_hot(a_primes, depth=self.action_space.num_categories)
            qt_sp_ap_values = torch.sum(qt_values_sp * one_hot, dim=-1)
        else:
            # Qt(s',a') -> Use the max(a') value (from the target network), reduced over the action axis.
            qt_sp_ap_values = torch.max(qt_values_sp, dim=-1)[0]

        # Make sure the rewards vector (batch) is broadcast correctly.
        for _ in range(get_rank(qt_sp_ap_values) - 1):
            rewards = torch.unsqueeze(rewards, dim=1)

        # Ignore Q(s'a') values if s' is a terminal state. Instead use 0.0 as the state-action value for s'a'.
        # Note that in that case, the next_state (s') is not the correct next state and should be disregarded.
        # See Chapter 3.4 in "RL - An Introduction" (2017 draft) by A. Barto and R. Sutton for a detailed analysis.
        qt_sp_ap_values = torch.where(terminals, torch.zeros_like(qt_sp_ap_values), qt_sp_ap_values)

        # Q(s,a) -> Use the Q-value of the action actually taken before.
        one_hot = pytorch_one_hot(actions, depth=self.action_space.num_categories)
        q_s_a_values = torch.sum((q_values_s * one_hot), -1)

        # Calculate the TD-delta (target - current estimate).
        td_delta = (rewards + (self.discount ** self.n_step) * qt_sp_ap_values) - q_s_a_values

        # Reduce over the composite actions, if any.
        if get_rank(td_delta) > 1:
            td_delta = pytorch_reduce_mean(td_delta, list(range(1, self.ranks_to_reduce + 1)), keepdims=False)

        # Apply importance-weights from a prioritized replay to the loss.
        if self.importance_weights:
            return importance_weights * self._apply_huber_loss_if_necessary(td_delta)
        else:
            return self._apply_huber_loss_if_necessary(td_delta)
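# `_apply_huber_loss_if_necessary` is referenced above but not shown here. A typical element-wise
# Huber treatment of the TD-delta looks roughly like the sketch below (illustrative only, under the
# assumption of a fixed `delta` threshold; not the actual rlgraph helper).
import torch

def huber_loss_sketch(td_delta, delta=1.0):
    """Element-wise Huber loss: quadratic for |x| <= delta, linear beyond (illustrative sketch)."""
    abs_delta = torch.abs(td_delta)
    quadratic = torch.clamp(abs_delta, max=delta)
    # 0.5 * x^2 inside the threshold, linear continuation outside.
    return 0.5 * quadratic ** 2 + delta * (abs_delta - quadratic)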