def distribution(
    self, time_step: ts.TimeStep, policy_state: types.NestedTensor = ()
) -> policy_step.PolicyStep:
  """Generates the distribution over next actions given the time_step.

  Args:
    time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
    policy_state: A Tensor, or a nested dict, list or tuple of Tensors
      representing the previous policy_state.

  Returns:
    A `PolicyStep` named tuple containing:
      `action`: A tf.distribution capturing the distribution of next actions.
      `state`: A policy state tensor for the next call to distribution.
      `info`: Optional side information such as action log probabilities.

  Raises:
    ValueError or TypeError: If `validate_args is True` and inputs or
      outputs do not match `time_step_spec`, `policy_state_spec`,
      or `policy_step_spec`.
  """
  if self._validate_args:
    time_step = nest_utils.prune_extra_keys(self._time_step_spec, time_step)
    policy_state = nest_utils.prune_extra_keys(self._policy_state_spec,
                                               policy_state)
    nest_utils.assert_same_structure(
        time_step,
        self._time_step_spec,
        message='time_step and time_step_spec structures do not match')
    nest_utils.assert_same_structure(
        policy_state,
        self._policy_state_spec,
        message='policy_state and policy_state_spec structures do not match')
  if self._automatic_state_reset:
    policy_state = self._maybe_reset_state(time_step, policy_state)
  step = self._distribution(time_step=time_step, policy_state=policy_state)
  if self.emit_log_probability:
    # This here is set only for compatibility with info_spec in constructor.
    info = policy_step.set_log_probability(
        step.info,
        tf.nest.map_structure(
            lambda _: tf.constant(0., dtype=tf.float32),
            policy_step.get_log_probability(self._info_spec)))
    step = step._replace(info=info)
  if self._validate_args:
    nest_utils.assert_same_structure(
        step,
        self._policy_step_spec,
        message=('distribution output and policy_step_spec structures '
                 'do not match'))
  return step
def _assert_nested_variable_updated(
    self,
    variables: types.NestedVariable,
    check_nest_seq_types: bool = True) -> None:
  # Prepare the expected content of the variables.
  expected_values = (tf.constant(0, dtype=tf.int64, shape=()), {
      'var1': (tf.constant([1, 1], dtype=tf.float64, shape=(2,)),),
      'var2': tf.constant([[2], [3]], dtype=tf.int32, shape=(2, 1))
  })
  flat_expected_values = tf.nest.flatten(expected_values)

  # Assert that the variables have the same content as the expected values,
  # meaning that the two nested structures have to be the same.
  self.assertIsNone(
      nest_utils.assert_same_structure(
          variables, expected_values, check_types=check_nest_seq_types))
  # And the values in `variables` have to be equal to (or close to, depending
  # on the component type) the expected ones.
  flat_variables = tf.nest.flatten(variables)
  self.assertAllEqual(flat_variables[0], flat_expected_values[0])
  self.assertAllClose(flat_variables[1], flat_expected_values[1])
  self.assertAllEqual(flat_variables[2], flat_expected_values[2])
def critic_no_entropy_loss(self,
                           time_steps,
                           actions,
                           next_time_steps,
                           td_errors_loss_fn,
                           gamma=1.0,
                           reward_scale_factor=1.0,
                           weights=None,
                           training=False):
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_no_entropy_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, _ = self._actions_and_log_probs(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_no_entropy_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_no_entropy_2(
        target_input, next_time_steps.step_type, training=False)
    # Entropy has been removed from the target critic function.
    target_q_values = tf.minimum(target_q_values1, target_q_values2)

    td_targets = tf.stop_gradient(
        reward_scale_factor * next_time_steps.reward +
        gamma * next_time_steps.discount * target_q_values)

    pred_input = (time_steps.observation, actions)
    pred_td_targets1, _ = self._critic_network_no_entropy_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_no_entropy_2(
        pred_input, time_steps.step_type, training=training)
    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(input_tensor=critic_loss, axis=1)

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_no_entropy_1.losses +
                             self._critic_network_no_entropy_2.losses))
    critic_no_entropy_loss = agg_loss.total_loss

    self._critic_no_entropy_loss_debug_summaries(td_targets, pred_td_targets1,
                                                 pred_td_targets2)

    return critic_no_entropy_loss
def action(self, time_step, policy_state=(), seed=None):
  """Generates next action given the time_step and policy_state.

  Args:
    time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
    policy_state: A Tensor, or a nested dict, list or tuple of Tensors
      representing the previous policy_state.
    seed: Seed to use if action performs sampling (optional).

  Returns:
    A `PolicyStep` named tuple containing:
      `action`: An action Tensor matching the `action_spec`.
      `state`: A policy state tensor to be fed into the next call to action.
      `info`: Optional side information such as action log probabilities.

  Raises:
    RuntimeError: If subclass __init__ didn't call super().__init__.
    ValueError or TypeError: If `validate_args is True` and inputs or
      outputs do not match `time_step_spec`, `policy_state_spec`,
      or `policy_step_spec`.
  """
  if self._enable_functions and getattr(self, '_action_fn', None) is None:
    raise RuntimeError(
        'Cannot find _action_fn. Did %s.__init__ call super?' %
        type(self).__name__)
  if self._enable_functions:
    action_fn = self._action_fn
  else:
    action_fn = self._action

  if self._validate_args:
    time_step = nest_utils.prune_extra_keys(self._time_step_spec, time_step)
    policy_state = nest_utils.prune_extra_keys(self._policy_state_spec,
                                               policy_state)
    nest_utils.assert_same_structure(
        time_step,
        self._time_step_spec,
        message='time_step and time_step_spec structures do not match')
    nest_utils.assert_same_structure(
        policy_state,
        self._policy_state_spec,
        message='policy_state and policy_state_spec structures do not match')

  if self._automatic_state_reset:
    policy_state = self._maybe_reset_state(time_step, policy_state)
  step = action_fn(time_step=time_step, policy_state=policy_state, seed=seed)

  def clip_action(action, action_spec):
    if isinstance(action_spec, tensor_spec.BoundedTensorSpec):
      return common.clip_to_spec(action, action_spec)
    return action

  if self._validate_args:
    nest_utils.assert_same_structure(
        step.action,
        self._action_spec,
        message='action and action_spec structures do not match')

  if self._clip:
    clipped_actions = tf.nest.map_structure(clip_action, step.action,
                                            self._action_spec)
    step = step._replace(action=clipped_actions)

  if self._validate_args:
    nest_utils.assert_same_structure(
        step,
        self._policy_step_spec,
        message='action output and policy_step_spec structures do not match')

    def compare_to_spec(value, spec):
      return value.dtype.is_compatible_with(spec.dtype)

    compatibility = [
        compare_to_spec(v, s) for (v, s) in zip(
            tf.nest.flatten(step.action), tf.nest.flatten(self.action_spec))
    ]

    if not all(compatibility):
      get_dtype = lambda x: x.dtype
      action_dtypes = tf.nest.map_structure(get_dtype, step.action)
      spec_dtypes = tf.nest.map_structure(get_dtype, self.action_spec)

      raise TypeError('Policy produced an action with a dtype that doesn\'t '
                      'match its action_spec. Got action:\n  %s\n with '
                      'action_spec:\n  %s' % (action_dtypes, spec_dtypes))

  return step
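
# Hedged usage sketch of the `action()` entry point above. The policy, specs,
# and batch size below are illustrative placeholders, not taken from the
# surrounding source.
import tensorflow as tf
from tf_agents.policies import random_tf_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tensor_spec.TensorSpec([4], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec([], tf.int32, minimum=0, maximum=1)
policy = random_tf_policy.RandomTFPolicy(
    ts.time_step_spec(observation_spec), action_spec)

time_step = ts.restart(tf.zeros([1, 4]), batch_size=1)
policy_step = policy.action(time_step)  # PolicyStep(action, state, info)
next_policy_step = policy.action(time_step, policy_step.state)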
def as_dict(outputs, output_spec):
  nest_utils.assert_same_structure(outputs, output_spec)
  flat_outputs = tf.nest.flatten(outputs)
  flat_names = [s.name for s in tf.nest.flatten(output_spec)]
  return dict(zip(flat_names, flat_outputs))
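
# Hedged usage sketch for `as_dict` above; the spec names and tensor values
# here are made up for illustration.
import tensorflow as tf

output_spec = {'logits': tf.TensorSpec([2], tf.float32, name='logits'),
               'value': tf.TensorSpec([], tf.float32, name='value')}
outputs = {'logits': tf.constant([0.1, 0.9]), 'value': tf.constant(1.5)}
named = as_dict(outputs, output_spec)  # {'logits': <tf.Tensor>, 'value': <tf.Tensor>}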
def _loss(self,
          experience,
          td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes critic loss for CategoricalDQN training.

  See Algorithm 1 and the discussion immediately preceding it on page 6 of
  "A Distributional Perspective on Reinforcement Learning"
    Bellemare et al., 2017
    https://arxiv.org/abs/1707.06887

  Args:
    experience: A batch of experience data in the form of a `Trajectory`. The
      structure of `experience` must match that of `self.policy.step_spec`.
      All tensors in `experience` must be shaped `[batch, time, ...]` where
      `time` must be equal to `self.required_experience_time_steps` if that
      property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional weights used for importance sampling.
    training: Whether the loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  squeeze_time_dim = not self._q_network.state_spec
  if self._n_step_update == 1:
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, policy_steps, _ = (
        trajectory.experience_to_transitions(first_two_steps,
                                             squeeze_time_dim))
    actions = policy_steps.action
    _, _, next_time_steps = (
        trajectory.experience_to_transitions(last_two_steps,
                                             squeeze_time_dim))

  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    rank = nest_utils.get_outer_rank(time_steps.observation,
                                     self._time_step_spec.observation)

    # If inputs have a time dimension and the q_network is stateful,
    # combine the batch and time dimension.
    batch_squash = (None
                    if rank <= 1 or self._q_network.state_spec in ((), None)
                    else network_utils.BatchSquash(rank))

    network_observation = time_steps.observation

    if self._observation_and_action_constraint_splitter is not None:
      network_observation, _ = (
          self._observation_and_action_constraint_splitter(
              network_observation))

    # q_logits contains the Q-value logits for all actions.
    q_logits, _ = self._q_network(network_observation,
                                  time_steps.step_type,
                                  training=training)

    if batch_squash is not None:
      # Squash outer dimensions to a single dimension to facilitate computing
      # the loss below. Required for supporting temporal inputs, for example.
      q_logits = batch_squash.flatten(q_logits)
      actions = batch_squash.flatten(actions)
      next_time_steps = tf.nest.map_structure(batch_squash.flatten,
                                              next_time_steps)

    next_q_distribution = self._next_q_distribution(next_time_steps)

    if actions.shape.rank > 1:
      actions = tf.squeeze(actions, list(range(1, actions.shape.rank)))

    # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
    # support of Z_{\theta} (see Figure 1 in paper).
    batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self._num_atoms])

    if self._n_step_update == 1:
      discount = next_time_steps.discount
      if discount.shape.rank == 1:
        # We expect discount to have a shape of [batch_size], while
        # tiled_support will have a shape of [batch_size, num_atoms]. To
        # multiply these, we add a second dimension of 1 to the discount.
        discount = tf.expand_dims(discount, -1)
      next_value_term = tf.multiply(discount,
                                    tiled_support,
                                    name='next_value_term')

      reward = next_time_steps.reward
      if reward.shape.rank == 1:
        # See the explanation above.
        reward = tf.expand_dims(reward, -1)
      reward_term = tf.multiply(reward_scale_factor,
                                reward,
                                name='reward_term')

      target_support = tf.add(reward_term,
                              gamma * next_value_term,
                              name='target_support')
    else:
      # When computing discounted return, we need to throw out the last time
      # index of both reward and discount, which are filled with dummy values
      # to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include episode
      # boundaries with nonzero discount.

      discounted_returns = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=tf.zeros([batch_size], dtype=discounts.dtype),
          time_major=False,
          provide_all_returns=False)

      # Convert discounted_returns from [batch_size] to [batch_size, 1].
      discounted_returns = tf.expand_dims(discounted_returns, -1)

      final_value_discount = tf.reduce_prod(discounts, axis=1)
      final_value_discount = tf.expand_dims(final_value_discount, -1)

      # Save the values of discounted_returns and final_value_discount in
      # order to check them in unit tests.
      self._discounted_returns = discounted_returns
      self._final_value_discount = final_value_discount

      target_support = tf.add(discounted_returns,
                              final_value_discount * tiled_support,
                              name='target_support')

    target_distribution = tf.stop_gradient(
        project_distribution(target_support, next_q_distribution,
                             self._support))

    # Obtain the current Q-value logits for the selected actions.
    indices = tf.range(batch_size)
    indices = tf.cast(indices, actions.dtype)
    reshaped_actions = tf.stack([indices, actions], axis=-1)
    chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

    # Compute the cross-entropy loss between the logits. If inputs have
    # a time dimension, compute the sum over the time dimension before
    # computing the mean over the batch dimension.
    if batch_squash is not None:
      target_distribution = batch_squash.unflatten(target_distribution)
      chosen_action_logits = batch_squash.unflatten(chosen_action_logits)
      critic_loss = tf.reduce_sum(
          tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
              labels=target_distribution, logits=chosen_action_logits),
          axis=1)
    else:
      critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
          labels=target_distribution, logits=chosen_action_logits)

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    dict_losses = {
        'critic_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }

    common.summarize_scalar_dict(dict_losses,
                                 step=self.train_step_counter,
                                 name_scope='Losses/')

    if self._debug_summaries:
      distribution_errors = target_distribution - chosen_action_logits
      with tf.name_scope('distribution_errors'):
        common.generate_tensor_summaries(
            'distribution_errors', distribution_errors,
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean',
            tf.reduce_mean(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean_abs',
            tf.reduce_mean(tf.abs(distribution_errors)),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'max',
            tf.reduce_max(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'min',
            tf.reduce_min(distribution_errors),
            step=self.train_step_counter)
      with tf.name_scope('target_distribution'):
        common.generate_tensor_summaries(
            'target_distribution', target_distribution,
            step=self.train_step_counter)

    # TODO(b/127318640): Give appropriate values for td_loss and td_error for
    # prioritized replay.
    return tf_agent.LossInfo(
        total_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
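
# Hedged standalone sketch of the `tf.gather_nd` pattern used above to select
# the logits of the chosen action for each batch entry; shapes are assumed.
import tensorflow as tf

q_logits = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4])  # [B, num_actions, num_atoms]
actions = tf.constant([2, 0], dtype=tf.int32)                     # [B]
indices = tf.range(tf.shape(q_logits)[0])                         # [0, 1]
reshaped_actions = tf.stack([indices, actions], axis=-1)          # [[0, 2], [1, 0]]
chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)   # shape [2, 4]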
def critic_loss(self,
                time_steps,
                actions,
                next_time_steps,
                augmented_obs,
                augmented_next_obs,
                td_errors_loss_fn,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None,
                training=False):
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    augmented_obs: List of augmented observations.
    augmented_next_obs: List of augmented next observations.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    td_targets = self._compute_td_targets(next_time_steps,
                                          reward_scale_factor, gamma)

    # Compute td_targets with augmentations.
    for i in range(self._num_augmentations - 1):
      augmented_next_time_steps = next_time_steps._replace(
          observation=augmented_next_obs[i])
      augmented_td_targets = self._compute_td_targets(
          augmented_next_time_steps, reward_scale_factor, gamma)
      td_targets = td_targets + augmented_td_targets

    # Average the td_target estimate over augmentations.
    if self._num_augmentations > 1:
      td_targets = td_targets / self._num_augmentations

    pred_td_targets1, pred_td_targets2, critic_loss = (
        self._compute_prediction_critic_loss(
            (time_steps.observation, actions), td_targets, time_steps,
            training, td_errors_loss_fn))

    # Add Q augmentations to the critic loss.
    for i in range(self._num_augmentations - 1):
      augmented_time_steps = time_steps._replace(observation=augmented_obs[i])
      _, _, loss = self._compute_prediction_critic_loss(
          (augmented_time_steps.observation, actions), td_targets,
          augmented_time_steps, training, td_errors_loss_fn)
      critic_loss = critic_loss + loss

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2)

    return critic_loss
def critic_loss(self,
                time_steps,
                expert_experience,
                actions,
                next_time_steps,
                future_time_steps,
                td_errors_loss_fn,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None,
                training=False,
                loss_name='c',
                use_done=False,
                q_combinator='min'):
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    expert_experience: An array of success examples.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    future_time_steps: A batch of future timesteps, used for n-step returns.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.
    loss_name: Which loss function to use. Use 'c' for RCE and 'q' for SQIL.
    use_done: Whether to use the terminal flag from the environment in the
      Bellman backup. We found that omitting it led to better results.
    q_combinator: Whether to combine the two Q-functions by taking the 'min'
      (as in TD3) or the 'max'.

  Returns:
    critic_loss: A scalar critic loss.
  """
  assert weights is None
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, _ = self._actions_and_log_probs(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_2(
        target_input, next_time_steps.step_type, training=False)
    if self._n_step is not None:
      future_actions, _ = self._actions_and_log_probs(future_time_steps)
      future_input = (future_time_steps.observation, future_actions)
      future_q_values1, _ = self._target_critic_network_1(
          future_input, future_time_steps.step_type, training=False)
      future_q_values2, _ = self._target_critic_network_2(
          future_input, future_time_steps.step_type, training=False)

      gamma_n = gamma**self._n_step  # Discount for n-step returns
      target_q_values1 = (target_q_values1 + gamma_n * future_q_values1) / 2.0
      target_q_values2 = (target_q_values2 + gamma_n * future_q_values2) / 2.0

    if q_combinator == 'min':
      target_q_values = tf.minimum(target_q_values1, target_q_values2)
    else:
      assert q_combinator == 'max'
      target_q_values = tf.maximum(target_q_values1, target_q_values2)

    batch_size = time_steps.observation.shape[0]
    if loss_name == 'q':
      if use_done:
        td_targets = gamma * next_time_steps.discount * target_q_values
      else:
        td_targets = gamma * target_q_values
    else:
      assert loss_name == 'c'
      w = target_q_values / (1 - target_q_values)
      td_targets = gamma * w / (gamma * w + 1)
      if use_done:
        td_targets = next_time_steps.discount * td_targets
      weights = tf.concat([1 + gamma * w, (1 - gamma) * tf.ones(batch_size)],
                          axis=0)

    td_targets = tf.stop_gradient(td_targets)
    td_targets = tf.concat([td_targets, tf.ones(batch_size)], axis=0)

    # Note that the actions only depend on the observations. We create the
    # expert_time_steps object simply to make this look like a time step
    # object.
    expert_time_steps = time_steps._replace(observation=expert_experience)
    if self._use_behavior_policy:
      policy_state = self._train_policy.get_initial_state(batch_size)
      action_distribution = self._behavior_policy.distribution(
          time_steps, policy_state=policy_state).action
      # Sample actions and log_pis from transformed distribution.
      expert_actions = tf.nest.map_structure(lambda d: d.sample(),
                                             action_distribution)
    else:
      expert_actions, _ = self._actions_and_log_probs(expert_time_steps)

    observation = time_steps.observation
    pred_input = (tf.concat([observation, expert_experience], axis=0),
                  tf.concat([actions, expert_actions], axis=0))
    pred_td_targets1, _ = self._critic_network_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_2(
        pred_input, time_steps.step_type, training=training)

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2)

    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if critic_loss.shape.rank > 1:
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(
          critic_loss, axis=range(1, critic_loss.shape.rank))

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2)

    return critic_loss
def run(self, trajectory, policy_state=None):
  """Apply the policy to trajectory steps and store actions/info.

  If `self.time_major == True`, the tensors in `trajectory` are assumed to
  have shape `[time, batch, ...]`. Otherwise they are assumed to have shape
  `[batch, time, ...]`.

  Args:
    trajectory: The `Trajectory` to run against. If the replay class was
      created with `time_major=True`, then the tensors in trajectory must be
      shaped `[time, batch, ...]`. Otherwise they must be shaped
      `[batch, time, ...]`.
    policy_state: (optional) A nest Tensor with initial step policy state.

  Returns:
    output_actions: A nest of the actions that the policy took.
      If the replay class was created with `time_major=True`, then the
      tensors here will be shaped `[time, batch, ...]`. Otherwise they'll be
      shaped `[batch, time, ...]`.
    output_policy_info: A nest of the policy info that the policy emitted.
      If the replay class was created with `time_major=True`, then the
      tensors here will be shaped `[time, batch, ...]`. Otherwise they'll be
      shaped `[batch, time, ...]`.
    policy_state: A nest Tensor with final step policy state.

  Raises:
    TypeError: If `policy_state` structure doesn't match
      `self.policy.policy_state_spec`, or `trajectory` structure doesn't
      match `self.policy.trajectory_spec`.
    ValueError: If `policy_state` doesn't match
      `self.policy.policy_state_spec`, or `trajectory` structure doesn't
      match `self.policy.trajectory_spec`.
    ValueError: If `trajectory` lacks two outer dims.
  """
  trajectory_spec = self._policy.trajectory_spec
  outer_dims = nest_utils.get_outer_shape(trajectory, trajectory_spec)

  if tf.compat.dimension_value(outer_dims.shape[0]) != 2:
    raise ValueError(
        "Expected two outer dimensions, but saw '{}' dimensions.\n"
        "Trajectory:\n{}.\nTrajectory spec from policy:\n{}.".format(
            tf.compat.dimension_value(outer_dims.shape[0]), trajectory,
            trajectory_spec))
  if self._time_major:
    sequence_length = outer_dims[0]
    batch_size = outer_dims[1]
    static_batch_size = tf.compat.dimension_value(
        trajectory.discount.shape[1])
  else:
    batch_size = outer_dims[0]
    sequence_length = outer_dims[1]
    static_batch_size = tf.compat.dimension_value(
        trajectory.discount.shape[0])

  if policy_state is None:
    policy_state = self._policy.get_initial_state(batch_size)
  else:
    nest_utils.assert_same_structure(policy_state,
                                     self._policy.policy_state_spec)

  if not self._time_major:
    # Make trajectory time-major.
    trajectory = tf.nest.map_structure(common.transpose_batch_time,
                                       trajectory)

  trajectory_tas = tf.nest.map_structure(
      lambda t: tf.TensorArray(t.dtype, size=sequence_length).unstack(t),
      trajectory)

  def create_output_ta(spec):
    return tf.TensorArray(
        spec.dtype,
        size=sequence_length,
        element_shape=(
            tf.TensorShape([static_batch_size]).concatenate(spec.shape)))

  output_action_tas = tf.nest.map_structure(create_output_ta,
                                            trajectory_spec.action)
  output_policy_info_tas = tf.nest.map_structure(create_output_ta,
                                                 trajectory_spec.policy_info)

  read0 = lambda ta: ta.read(0)
  zeros_like0 = lambda t: tf.zeros_like(t[0])
  ones_like0 = lambda t: tf.ones_like(t[0])
  time_step = ts.TimeStep(
      step_type=read0(trajectory_tas.step_type),
      reward=tf.nest.map_structure(zeros_like0, trajectory.reward),
      discount=ones_like0(trajectory.discount),
      observation=tf.nest.map_structure(read0, trajectory_tas.observation))

  def process_step(time, time_step, policy_state, output_action_tas,
                   output_policy_info_tas):
    """Take an action on the given step, and update output TensorArrays.

    Args:
      time: Step time. Describes which row to read from the trajectory
        TensorArrays and which location to write into in the output
        TensorArrays.
      time_step: Previous step's `TimeStep`.
      policy_state: Policy state tensor or nested structure of tensors.
      output_action_tas: Nest of `tf.TensorArray` containing new actions.
      output_policy_info_tas: Nest of `tf.TensorArray` containing new
        policy info.

    Returns:
      policy_state: The next policy state.
      next_output_action_tas: Updated `output_action_tas`.
      next_output_policy_info_tas: Updated `output_policy_info_tas`.
    """
    action_step = self._policy.action(time_step, policy_state)
    policy_state = action_step.state
    write_ta = lambda ta, t: ta.write(time - 1, t)
    next_output_action_tas = tf.nest.map_structure(write_ta,
                                                   output_action_tas,
                                                   action_step.action)
    next_output_policy_info_tas = tf.nest.map_structure(
        write_ta, output_policy_info_tas, action_step.info)

    return (action_step.state, next_output_action_tas,
            next_output_policy_info_tas)

  def loop_body(time, time_step, policy_state, output_action_tas,
                output_policy_info_tas):
    """Runs a step in the environment; the while loop calls this repeatedly.

    Args:
      time: Step time.
      time_step: Previous step's `TimeStep`.
      policy_state: Policy state tensor or nested structure of tensors.
      output_action_tas: Updated nest of `tf.TensorArray`, the new actions.
      output_policy_info_tas: Updated nest of `tf.TensorArray`, the new
        policy info.

    Returns:
      loop_vars for next iteration of tf.while_loop.
    """
    policy_state, next_output_action_tas, next_output_policy_info_tas = (
        process_step(time, time_step, policy_state, output_action_tas,
                     output_policy_info_tas))

    ta_read = lambda ta: ta.read(time)
    ta_read_prev = lambda ta: ta.read(time - 1)
    time_step = ts.TimeStep(
        step_type=ta_read(trajectory_tas.step_type),
        observation=tf.nest.map_structure(ta_read,
                                          trajectory_tas.observation),
        reward=tf.nest.map_structure(ta_read_prev, trajectory_tas.reward),
        discount=ta_read_prev(trajectory_tas.discount))

    return (time + 1, time_step, policy_state, next_output_action_tas,
            next_output_policy_info_tas)

  time = tf.constant(1)
  time, time_step, policy_state, output_action_tas, output_policy_info_tas = (
      tf.while_loop(
          cond=lambda time, *_: time < sequence_length,
          body=loop_body,
          loop_vars=[
              time, time_step, policy_state, output_action_tas,
              output_policy_info_tas
          ],
          back_prop=False,
          name="trajectory_replay_loop"))

  # Run the last time step.
  last_policy_state, output_action_tas, output_policy_info_tas = (
      process_step(time, time_step, policy_state, output_action_tas,
                   output_policy_info_tas))

  def stack_ta(ta):
    t = ta.stack()
    if not self._time_major:
      t = common.transpose_batch_time(t)
    return t

  stacked_output_actions = tf.nest.map_structure(stack_ta, output_action_tas)
  stacked_output_policy_info = tf.nest.map_structure(stack_ta,
                                                     output_policy_info_tas)

  return (stacked_output_actions, stacked_output_policy_info,
          last_policy_state)
def critic_loss(self,
                time_steps: ts.TimeStep,
                actions: types.Tensor,
                next_time_steps: ts.TimeStep,
                td_errors_loss_fn: types.LossFn,
                gamma: types.Float = 1.0,
                reward_scale_factor: types.Float = 1.0,
                weights: Optional[types.Tensor] = None,
                training: bool = False) -> types.Tensor:
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, next_log_pis = self._actions_and_log_probs(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_2(
        target_input, next_time_steps.step_type, training=False)
    target_q_values = (
        tf.minimum(target_q_values1, target_q_values2) -
        tf.exp(self._log_alpha) * next_log_pis)

    td_targets = tf.stop_gradient(
        reward_scale_factor * next_time_steps.reward +
        gamma * next_time_steps.discount * target_q_values)

    pred_input = (time_steps.observation, actions)
    pred_td_targets1, _ = self._critic_network_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_2(
        pred_input, time_steps.step_type, training=training)
    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if critic_loss.shape.rank > 1:
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(
          critic_loss, axis=range(1, critic_loss.shape.rank))

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2)

    return critic_loss
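
# Minimal standalone sketch of the TD target built in `critic_loss` above,
# assuming already-computed target Q-values and next-action log-probs:
#   td_target = scale * r + gamma * d * (min(Q1', Q2') - alpha * log_pi')
import tensorflow as tf

def sac_td_target(reward, discount, target_q1, target_q2, next_log_pi,
                  log_alpha, gamma=0.99, reward_scale_factor=1.0):
  target_q = tf.minimum(target_q1, target_q2) - tf.exp(log_alpha) * next_log_pi
  return tf.stop_gradient(
      reward_scale_factor * reward + gamma * discount * target_q)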
def total_loss(self,
               experience: traj.Trajectory,
               returns: types.Tensor,
               weights: types.Tensor,
               training: bool = False) -> tf_agent.LossInfo:
  # Ensure we see at least one full episode.
  time_steps = ts.TimeStep(experience.step_type,
                           tf.zeros_like(experience.reward),
                           tf.zeros_like(experience.discount),
                           experience.observation)
  is_last = experience.is_last()
  num_episodes = tf.reduce_sum(tf.cast(is_last, tf.float32))
  tf.debugging.assert_greater(
      num_episodes,
      0.0,
      message='No complete episode found. REINFORCE requires full episodes '
      'to compute losses.')

  # Mask out partial episodes at the end of each batch of time_steps.
  # NOTE: We use is_last rather than is_boundary because the last transition
  # is the transition with the last valid reward. In other words, the rewards
  # on the boundary transitions are not valid. Since REINFORCE is calculating
  # a loss w.r.t. the returns (and not bootstrapping), keeping the boundary
  # transitions is irrelevant.
  valid_mask = tf.cast(experience.is_last(), dtype=tf.float32)
  valid_mask = tf.math.cumsum(valid_mask, axis=1, reverse=True)
  valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
  if weights is not None:
    weights *= valid_mask
  else:
    weights = valid_mask

  advantages = returns
  value_preds = None

  if self._baseline:
    value_preds, _ = self._value_network(time_steps.observation,
                                         time_steps.step_type,
                                         training=True)
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='value_preds', data=value_preds, step=self.train_step_counter)

    advantages = self._advantage_fn(returns, value_preds)
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='advantages', data=advantages, step=self.train_step_counter)

  # TODO(b/126592060): replace with tensor normalizer.
  if self._normalize_returns:
    advantages = _standard_normalize(advantages, axes=(0, 1))
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='normalized_%s' %
          ('advantages' if self._baseline else 'returns'),
          data=advantages,
          step=self.train_step_counter)

  nest_utils.assert_same_structure(time_steps, self.time_step_spec)
  policy_state = _get_initial_policy_state(self.collect_policy, time_steps)
  actions_distribution = self.collect_policy.distribution(
      time_steps, policy_state=policy_state).action

  policy_gradient_loss = self.policy_gradient_loss(
      actions_distribution,
      experience.action,
      experience.is_boundary(),
      advantages,
      num_episodes,
      weights,
  )
  entropy_regularization_loss = self.entropy_regularization_loss(
      actions_distribution, weights)
  network_regularization_loss = tf.nn.scale_regularization_loss(
      self._actor_network.losses)

  total_loss = (policy_gradient_loss + network_regularization_loss +
                entropy_regularization_loss)

  losses_dict = {
      'policy_gradient_loss': policy_gradient_loss,
      'policy_network_regularization_loss': network_regularization_loss,
      'entropy_regularization_loss': entropy_regularization_loss,
      'value_estimation_loss': 0.0,
      'value_network_regularization_loss': 0.0,
  }

  value_estimation_loss = None
  if self._baseline:
    value_estimation_loss = self.value_estimation_loss(
        value_preds, returns, num_episodes, weights)
    value_network_regularization_loss = tf.nn.scale_regularization_loss(
        self._value_network.losses)
    total_loss += value_estimation_loss + value_network_regularization_loss
    losses_dict['value_estimation_loss'] = value_estimation_loss
    losses_dict['value_network_regularization_loss'] = (
        value_network_regularization_loss)

  loss_info_extra = ReinforceAgentLossInfo(**losses_dict)

  losses_dict['total_loss'] = total_loss  # Total loss not in loss_info_extra.

  common.summarize_scalar_dict(losses_dict,
                               self.train_step_counter,
                               name_scope='Losses/')

  return tf_agent.LossInfo(total_loss, loss_info_extra)
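
# Hedged standalone sketch of the partial-episode mask built above, with an
# assumed [batch=2, time=4] layout. The second row contains no completed
# episode, so all of its steps end up masked out.
import tensorflow as tf

is_last = tf.constant([[0., 0., 1., 0.],
                       [0., 0., 0., 0.]])
valid_mask = tf.math.cumsum(is_last, axis=1, reverse=True)
valid_mask = tf.cast(valid_mask > 0, tf.float32)
# valid_mask == [[1., 1., 1., 0.],
#                [0., 0., 0., 0.]]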
def critic_loss(self,
                time_steps,
                actions,
                next_time_steps,
                td_errors_loss_fn,
                gamma=1.0,
                weights=None,
                training=False,
                w_clipping=None,
                self_normalized=False,
                lambda_fix=False):
  """Computes the critic loss for C-learning training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.
    w_clipping: Maximum value used for clipping the weights. Use -1 to do no
      clipping; use None to use the recommended value of 1 / (1 - gamma).
    self_normalized: Whether to normalize the weights so that the average is
      1. Empirically this usually hurts performance.
    lambda_fix: Whether to include the adjustment when using future
      positives. Empirically this has little effect.

  Returns:
    critic_loss: A scalar critic loss.
  """
  del weights
  if w_clipping is None:
    w_clipping = 1 / (1 - gamma)
  rfp = gin.query_parameter('goal_fn.relabel_future_prob')
  rnp = gin.query_parameter('goal_fn.relabel_next_prob')
  assert rfp + rnp == 0.5
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, _ = self._actions_and_log_probs(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_2(
        target_input, next_time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)

    w = tf.stop_gradient(target_q_values / (1 - target_q_values))
    if w_clipping >= 0:
      w = tf.clip_by_value(w, 0, w_clipping)
    tf.debugging.assert_all_finite(w, 'Not all elements of w are finite')
    if self_normalized:
      w = w / tf.reduce_mean(w)

    batch_size = nest_utils.get_outer_shape(time_steps,
                                            self._time_step_spec)[0]
    half_batch = batch_size // 2
    float_batch_size = tf.cast(batch_size, float)
    num_next = tf.cast(tf.round(float_batch_size * rnp), tf.int32)
    num_future = tf.cast(tf.round(float_batch_size * rfp), tf.int32)
    if lambda_fix:
      lambda_coef = 2 * rnp
      weights = tf.concat([
          tf.fill((num_next,), (1 - gamma)),
          tf.fill((num_future,), 1.0),
          (1 + lambda_coef * gamma * w)[half_batch:]
      ],
                          axis=0)
    else:
      weights = tf.concat([
          tf.fill((num_next,), (1 - gamma)),
          tf.fill((num_future,), 1.0),
          (1 + gamma * w)[half_batch:]
      ],
                          axis=0)

    # Note that we assume that episodes never terminate. If they do, then
    # we need to include next_time_steps.discount in the (negative) TD
    # target. We exclude the termination here so that we can use termination
    # to indicate task success during evaluation. In the evaluation setting,
    # task success depends on the task, but we don't want the termination
    # here to depend on the task. Hence, we ignore it.
    if lambda_fix:
      lambda_coef = 2 * rnp
      y = lambda_coef * gamma * w / (1 + lambda_coef * gamma * w)
    else:
      y = gamma * w / (1 + gamma * w)
    td_targets = tf.stop_gradient(next_time_steps.reward +
                                  (1 - next_time_steps.reward) * y)
    if rfp > 0:
      td_targets = tf.concat([tf.ones(half_batch), td_targets[half_batch:]],
                             axis=0)

    observation = time_steps.observation
    pred_input = (observation, actions)
    pred_td_targets1, _ = self._critic_network_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_2(
        pred_input, time_steps.step_type, training=training)
    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if critic_loss.shape.rank > 1:
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(
          critic_loss, axis=range(1, critic_loss.shape.rank))

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2, weights)

    return critic_loss
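
# Hedged restatement of the classifier-to-TD-target mapping used above for the
# `lambda_fix=False` branch; the function name and arguments are local to this
# sketch, not part of the agent's API.
def c_learning_td_target(c_next, reward, gamma):
  """Maps the target classifier output C(s', a') to the bootstrap target."""
  w = c_next / (1.0 - c_next)          # odds ratio of the target classifier
  y = gamma * w / (1.0 + gamma * w)    # discounted bootstrap term
  return reward + (1.0 - reward) * y   # matches td_targets above (pre-concat)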
def _critic_loss_with_optional_entropy_term(
    self,
    time_steps: ts.TimeStep,
    actions: types.Tensor,
    next_time_steps: ts.TimeStep,
    td_errors_loss_fn: types.LossFn,
    gamma: types.Float = 1.0,
    reward_scale_factor: types.Float = 1.0,
    weights: Optional[types.Tensor] = None,
    training: bool = False) -> types.Tensor:
  r"""Computes the critic loss for CQL-SAC training.

  The original SAC critic loss is:
  ```
  (q(s, a) - (r(s, a) + \gamma q(s', a') - \gamma \alpha \log \pi(a'|s')))^2
  ```

  The CQL-SAC critic loss makes the entropy term optional. CQL may value
  unseen actions higher since it lower-bounds the value of seen actions.
  This makes the policy entropy potentially redundant in the target term,
  since it will further enhance unseen actions' effects.

  If self._include_critic_entropy_term is False, this loss equation becomes:
  ```
  (q(s, a) - (r(s, a) + \gamma q(s', a')))^2
  ```

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    # We do not update actor or target networks in critic loss.
    next_actions, next_log_pis = self._actions_and_log_probs(
        next_time_steps, training=False)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_2(
        target_input, next_time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)

    if self._include_critic_entropy_term:
      target_q_values -= (tf.exp(self._log_alpha) * next_log_pis)

    reward = next_time_steps.reward
    if self._reward_noise_variance > 0:
      reward_noise = tf.random.normal(
          tf.shape(reward),
          0.0,
          self._reward_noise_variance,
          seed=self._reward_seed_stream())
      reward += reward_noise

    td_targets = tf.stop_gradient(
        reward_scale_factor * reward +
        gamma * next_time_steps.discount * target_q_values)

    pred_input = (time_steps.observation, actions)
    pred_td_targets1, _ = self._critic_network_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_2(
        pred_input, time_steps.step_type, training=training)
    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if critic_loss.shape.rank > 1:
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(
          critic_loss, axis=range(1, critic_loss.shape.rank))

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2)

    return critic_loss
def __init__(self,
             nested_layers: types.NestedLayer,
             input_spec: typing.Optional[types.NestedTensorSpec] = None,
             name: typing.Optional[typing.Text] = None):
  """Create a NestMap Network.

  Args:
    nested_layers: A nest of layers and/or networks. These will be used to
      process the inputs (input nest structure will have to match this
      structure). Any layers that are subclasses of
      `tf.keras.layers.{RNN,LSTM,GRU,...}` are wrapped in
      `tf_agents.keras_layers.RNNWrapper`.
    input_spec: (Optional.) A nest of `tf.TypeSpec` representing the
      input observations. The structure of `input_spec` must match
      that of `nested_layers`.
    name: (Optional.) Network name.

  Raises:
    TypeError: If any of the layers are not instances of keras `Layer`.
    ValueError: If `input_spec` is provided but its nest structure does
      not match that of `nested_layers`.
    RuntimeError: If not `tf.executing_eagerly()`; as this is required to
      be able to create deep copies of layers in `layers`.
  """
  if not tf.executing_eagerly():
    raise RuntimeError(
        'Not executing eagerly - cannot make deep copies of `nested_layers`.')

  flat_nested_layers = tf.nest.flatten(nested_layers)
  for layer in flat_nested_layers:
    if not isinstance(layer, tf.keras.layers.Layer):
      raise TypeError(
          'Expected all layers to be instances of keras Layer, but saw'
          ': \'{}\''.format(layer))

  if input_spec is not None:
    nest_utils.assert_same_structure(
        nested_layers,
        input_spec,
        message=('`nested_layers` and `input_spec` do not have matching '
                 'structures'))
    flat_input_spec = tf.nest.flatten(input_spec)
  else:
    flat_input_spec = [None] * len(flat_nested_layers)

  # Wrap in Sequential if necessary.
  flat_nested_layers = [
      sequential.Sequential([m], s) if not isinstance(m, network.Network)
      else m
      for (s, m) in zip(flat_input_spec, flat_nested_layers)
  ]

  flat_nested_layers_state_specs = [m.state_spec for m in flat_nested_layers]
  nested_layers = tf.nest.pack_sequence_as(nested_layers, flat_nested_layers)
  # We use flattened layers and states here instead of tf.nest.map_structure
  # for several reasons. One is that we perform several operations against
  # the layers and we want to avoid calling into tf.nest.map* multiple times.
  # But the main reason is that network states have a different *structure*
  # than the layers; e.g., `nested_layers` may just be tf.keras.layers.LSTM,
  # but the states would then have structure `[.,.]`. Passing these in
  # as args to tf.nest.map_structure causes it to fail. Instead we would
  # have to use nest.map_structure_up_to -- but that function is not part
  # of the public TF API. However, if we do everything in flatland and then
  # use pack_sequence_as, we bypass the more rigid structure tests.
  state_spec = tf.nest.pack_sequence_as(nested_layers,
                                        flat_nested_layers_state_specs)

  super(NestMap, self).__init__(
      input_tensor_spec=input_spec, state_spec=state_spec, name=name)
  self._nested_layers = nested_layers
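
# Hedged usage sketch for `NestMap`: each entry of a nested input is processed
# by the layer or network at the same position in `nested_layers`. The layer
# sizes and input shapes here are arbitrary examples.
import tensorflow as tf
from tf_agents.networks import nest_map

net = nest_map.NestMap({
    'image_features': tf.keras.layers.Dense(8),
    'proprio': tf.keras.layers.Dense(4),
})
inputs = {
    'image_features': tf.zeros([2, 16]),
    'proprio': tf.zeros([2, 3]),
}
outputs, _ = net(inputs)  # outputs keeps the {'image_features', 'proprio'} structure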
def actor_loss(self, time_steps, actions, next_time_steps, weights=None):
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    actor_loss: A scalar actor loss.
  """
  prev_time_steps, prev_actions, time_steps = time_steps, actions, next_time_steps  # pylint: disable=line-too-long
  with tf.name_scope('actor_loss'):
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)

    actions, log_pi = self._actions_and_log_probs(time_steps)
    target_input = (time_steps.observation, actions)
    target_q_values1, _ = self._critic_network_1(
        target_input, step_type=time_steps.step_type, training=False)
    target_q_values2, _ = self._critic_network_2(
        target_input, step_type=time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)
    actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values

    ### Flatten time dimension. We'll add it back when adding the loss.
    num_outer_dims = nest_utils.get_outer_rank(time_steps,
                                               self.time_step_spec)
    has_time_dim = (num_outer_dims == 2)
    if has_time_dim:
      batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
      obs = batch_squash.flatten(time_steps.observation)
      prev_obs = batch_squash.flatten(prev_time_steps.observation)
      prev_actions = batch_squash.flatten(prev_actions)
    else:
      obs = time_steps.observation
      prev_obs = prev_time_steps.observation

    z = self._actor_network._z_encoder(obs, training=True)  # pylint: disable=protected-access
    prior = self._actor_network._predictor(  # pylint: disable=protected-access
        (prev_obs, prev_actions), training=True)

    # kl is a vector of length batch_size, which has already been summed over
    # the latent dimension z.
    kl = tfp.distributions.kl_divergence(z, prior)
    if has_time_dim:
      kl = batch_squash.unflatten(kl)

    kl_coef = tf.stop_gradient(
        tf.exp(self._actor_network._log_kl_coefficient))  # pylint: disable=protected-access
    # The actor loss trains both the predictor and the encoder.
    actor_loss += kl_coef * kl

    if actor_loss.shape.rank > 1:
      # Sum over the time dimension.
      actor_loss = tf.reduce_sum(
          actor_loss, axis=range(1, actor_loss.shape.rank))

    reg_loss = self._actor_network.losses if self._actor_network else None
    agg_loss = common.aggregate_losses(
        per_example_loss=actor_loss,
        sample_weight=weights,
        regularization_loss=reg_loss)
    actor_loss = agg_loss.total_loss
    self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                     target_q_values, time_steps)
    tf.compat.v2.summary.scalar(
        name='encoder_kl',
        data=tf.reduce_mean(kl),
        step=self.train_step_counter)

    return actor_loss
def __call__(self, inputs, *args, **kwargs):
  """A wrapper around `Network.call`.

  A typical `call` method in a class subclassing `Network` will have a
  signature that accepts `inputs`, as well as other `*args` and `**kwargs`.
  `call` can optionally also accept `step_type` and `network_state`
  (if the state is not trivial, i.e. `state_spec != ()`).  e.g.:

  ```python
  def call(self, inputs, step_type=None, network_state=(), training=False):
      ...
      return outputs, new_network_state
  ```

  We will validate the first argument (`inputs`) against
  `self.input_tensor_spec` if one is available.

  If a `network_state` kwarg is given it is also validated against
  `self.state_spec`.

  Similarly, the return value of the `call` method is expected to be a
  tuple/list with 2 values: `(output, new_state)`.  We validate `new_state`
  against `self.state_spec`.

  If no `network_state` kwarg is given (or if an empty `network_state = ()`
  is given), it is up to `call` to assume a proper "empty" state, and to
  emit an appropriate `output_state`.

  Args:
    inputs: The input to `self.call`, matching `self.input_tensor_spec`.
    *args: Additional arguments to `self.call`.
    **kwargs: Additional keyword arguments to `self.call`.  These can include
      `network_state` and `step_type`.  `step_type` is required if the
      network's `call` requires it.  `network_state` is required if the
      underlying network's `call` requires it.

  Returns:
    A tuple `(outputs, new_network_state)`.
  """
  if self.input_tensor_spec is not None:
    nest_utils.assert_same_structure(
        inputs,
        self.input_tensor_spec,
        message="inputs and input_tensor_spec structures do not match")
  call_argspec = tf_inspect.getargspec(self.call)

  # Convert *args, **kwargs to a canonical kwarg representation.
  normalized_kwargs = tf_inspect.getcallargs(self.call, inputs, *args,
                                             **kwargs)
  # TODO(b/156315434): Rename network_state to just state.
  network_state = normalized_kwargs.get("network_state", None)
  normalized_kwargs.pop("self", None)

  if network_state not in (None, ()):
    nest_utils.assert_same_structure(
        network_state,
        self.state_spec,
        message="network_state and state_spec structures do not match")

  if "step_type" not in call_argspec.args and not call_argspec.keywords:
    normalized_kwargs.pop("step_type", None)

  if (network_state in (None, ()) and
      "network_state" not in call_argspec.args and
      not call_argspec.keywords):
    normalized_kwargs.pop("network_state", None)

  outputs, new_state = super(Network, self).__call__(**normalized_kwargs)

  nest_utils.assert_same_structure(
      new_state,
      self.state_spec,
      message="network output state and state_spec structures do not match")

  return outputs, new_state
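
# Hedged sketch of the `call` contract enforced above: a minimal stateless
# network whose `call` returns the `(output, new_state)` pair that `__call__`
# validates against `state_spec`. Class and spec names are illustrative.
import tensorflow as tf
from tf_agents.networks import network
from tf_agents.specs import tensor_spec


class DoubleNetwork(network.Network):

  def __init__(self, input_tensor_spec, name='DoubleNetwork'):
    super(DoubleNetwork, self).__init__(
        input_tensor_spec=input_tensor_spec, state_spec=(), name=name)

  def call(self, inputs, step_type=None, network_state=(), training=False):
    return 2.0 * inputs, network_state


net = DoubleNetwork(tensor_spec.TensorSpec([3], tf.float32))
outputs, state = net(tf.ones([1, 3]))  # __call__ validates inputs and state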