Example 1
def get_target_and_main_actions(experience, agents, nameDict, networkDict):
    """MADDPG - get the actions from the target actor network and main actor network of all the agents"""
    total_agents_target_actions = []
    total_agents_main_actions = []
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience,
                                             squeeze_time_dim=True))
    for i, flexAgent in enumerate(agents):
        for node in nameDict:
            target_action = None
            for type, names in nameDict[node].items():
                if flexAgent.id in names:
                    target_action, _ = networkDict[node][
                        type].agent._target_actor_network(
                            next_time_steps.observation[i],
                            next_time_steps.step_type,
                            training=False)
                    main_action, _ = networkDict[node][
                        type].agent._actor_network(time_steps.observation[i],
                                                   time_steps.step_type,
                                                   training=True)
                    break
            if target_action is not None:
                break
        total_agents_target_actions.append(target_action)
        total_agents_main_actions.append(main_action)
    return time_steps, policy_steps, next_time_steps, \
           tuple(total_agents_target_actions), tuple(total_agents_main_actions)
Example 2
 def train(self, experience, agents, nameDict, networkDict):
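     """IQL - train the per-hour networks of all the agents and return the average loss"""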
     time_steps, policy_steps, next_time_steps = (
         trajectory.experience_to_transitions(experience,
                                              squeeze_time_dim=True))
     loss_list = []
     for i, flexAgent in enumerate(agents):
         for node in nameDict:
             for type, names in nameDict[node].items():
                 if flexAgent.id in names:
                     for net in networkDict[node][type]:
                         action_index = -1
                         for t in range(24):
                             action_index += 1
                             actions = tf.gather(policy_steps.action[i],
                                                 indices=action_index,
                                                 axis=-1)
                             individual_iql_time_step = ts.get_individual_iql_time_step(
                                 time_steps, index=i, time=t)
                             individual_iql_next_time_step = ts.get_individual_iql_time_step(
                                 next_time_steps, index=i, time=t)
                             train_loss = self.train_single_net(
                                 net, individual_iql_time_step,
                                 individual_iql_next_time_step, time_steps,
                                 actions, next_time_steps, i, t).loss
                             loss_list.append(train_loss)
                 break
     self.train_step_counter.assign_add(1)
     # mean loss over all nets and hours; compute it before logging so the
     # return value is defined even when no summary writer is configured
     avg_loss = sum(loss_list) / len(loss_list)
     if self.summary_writer is not None:
         with self.summary_writer.as_default():
             tf.summary.scalar('loss',
                               avg_loss,
                               step=self.train_step_counter)
     return avg_loss
Example 3
 def train(self, experience, agents, nameDict, networkDict):
     """QMIX - get the Q values from the target network and main network of all the agents"""
     time_steps, policy_steps, next_time_steps = (
         trajectory.experience_to_transitions(experience,
                                              squeeze_time_dim=True))
     variables_to_train = getTrainableVariables(networkDict)
     variables_to_train.append(self.QMIXNet.trainable_weights)
     variables_to_train = tf.nest.flatten(variables_to_train)
     assert list(
         variables_to_train), "No variables in the agent's QMIX network."
     with tf.GradientTape(watch_accessed_variables=False) as tape:
         tape.watch(variables_to_train)
         loss_info = self._loss(time_steps,
                                policy_steps,
                                next_time_steps,
                                agents,
                                nameDict,
                                networkDict,
                                td_errors_loss_fn=self._td_errors_loss_fn,
                                gamma=self._gamma,
                                training=True)
     tf.debugging.check_numerics(loss_info.loss, 'Loss is inf or nan')
     grads = tape.gradient(loss_info.loss, variables_to_train)
     grads_and_vars = list(zip(grads, variables_to_train))
     self.train_step_counter = training_lib.apply_gradients(
         self._optimizer,
         grads_and_vars,
         global_step=self.train_step_counter)
     self._update_target()
     return loss_info
Example 4
 def _experience_to_transitions(self, experience):
   boundary_mask = tf.logical_not(experience.is_boundary()[:, 0])
   experience = nest_utils.fast_map_structure(lambda *x: tf.boolean_mask(*x, boundary_mask), experience)
   squeeze_time_dim = not self._critic_network_1.state_spec
   time_steps, policy_steps, next_time_steps = (
       trajectory.experience_to_transitions(experience, squeeze_time_dim))
   return time_steps, policy_steps.action, next_time_steps  #, policy_steps.info
Example 5
def experience_to_transitions(experience):
    boundary_mask = tf.logical_not(experience.is_boundary()[:, 0])
    experience = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, boundary_mask), experience)
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, True))
    actions = policy_steps.action
    return time_steps, actions, next_time_steps
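Example 5 is the thinnest wrapper in this list, so it is a convenient place to show what the call actually consumes and returns. The sketch below is not taken from any of the projects above: the shapes and values are assumptions chosen for illustration, and it presupposes a TF-Agents release that exposes trajectory.experience_to_transitions with the signature used throughout these examples.

import tensorflow as tf
from tf_agents.trajectories import trajectory

# A batched Trajectory shaped [batch=2, time=2, ...]; two consecutive steps per
# batch entry are enough to form one transition each.
experience = trajectory.Trajectory(
    step_type=tf.zeros([2, 2], dtype=tf.int32),      # StepType.FIRST
    observation=tf.random.uniform([2, 2, 4]),        # [batch, time, obs_dim]
    action=tf.random.uniform([2, 2, 1]),             # [batch, time, action_dim]
    policy_info=(),
    next_step_type=tf.ones([2, 2], dtype=tf.int32),  # StepType.MID
    reward=tf.random.uniform([2, 2]),
    discount=tf.ones([2, 2]))

time_steps, policy_steps, next_time_steps = (
    trajectory.experience_to_transitions(experience, squeeze_time_dim=True))

print(time_steps.observation.shape)        # (2, 4) -- time dimension squeezed away
print(policy_steps.action.shape)           # (2, 1)
print(next_time_steps.observation.shape)   # (2, 4)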
Example 6
    def _train(self, experience, weights=None):
        # TODO(b/120034503): Move the conversion to transitions to the base class.
        squeeze_time_dim = not self._actor_network.state_spec
        time_steps, policy_steps, next_time_steps = (
            trajectory.experience_to_transitions(experience, squeeze_time_dim))
        actions = policy_steps.action

        trainable_critic_variables = list(
            object_identity.ObjectIdentitySet(
                self._critic_network_1.trainable_variables +
                self._critic_network_2.trainable_variables))
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_critic_variables, (
                'No trainable critic variables to '
                'optimize.')
            tape.watch(trainable_critic_variables)
            critic_loss = self.critic_loss(time_steps,
                                           actions,
                                           next_time_steps,
                                           weights=weights,
                                           training=True)
        tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
        critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
        self._apply_gradients(critic_grads, trainable_critic_variables,
                              self._critic_optimizer)

        trainable_actor_variables = self._actor_network.trainable_variables
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_actor_variables, (
                'No trainable actor variables to '
                'optimize.')
            tape.watch(trainable_actor_variables)
            actor_loss = self.actor_loss(time_steps,
                                         weights=weights,
                                         training=True)
        tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')

        # We only optimize the actor every actor_update_period training steps.
        def optimize_actor():
            actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
            return self._apply_gradients(actor_grads,
                                         trainable_actor_variables,
                                         self._actor_optimizer)

        remainder = tf.math.mod(self.train_step_counter,
                                self._actor_update_period)
        tf.cond(pred=tf.equal(remainder, 0),
                true_fn=optimize_actor,
                false_fn=tf.no_op)

        self.train_step_counter.assign_add(1)
        self._update_target()

        # TODO(b/124382360): Compute per element TD loss and return in loss_info.
        total_loss = actor_loss + critic_loss

        return tf_agent.LossInfo(total_loss, Td3Info(actor_loss, critic_loss))
Example 7
 def _experience_to_sas(self, experience):
     squeeze_time_dim = not self._critic_network_1.state_spec
     (
         time_steps,
         policy_steps,
         next_time_steps,
     ) = trajectory.experience_to_transitions(experience, squeeze_time_dim)
     actions = policy_steps.action
     return tf.concat(
         [time_steps.observation, actions, next_time_steps.observation],
         axis=-1)
Example 8
    def _train(self, experience, weights=None):
        squeeze_time_dim = not self._actor_network.state_spec
        time_steps, policy_steps, next_time_steps = (
            trajectory.experience_to_transitions(experience, squeeze_time_dim))
        actions = policy_steps.action

        # TODO(b/124382524): Apply a loss mask or filter boundary transitions.
        trainable_critic_variables = self._critic_network.trainable_variables
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_critic_variables, (
                'No trainable critic variables to '
                'optimize.')
            tape.watch(trainable_critic_variables)
            critic_loss = self.critic_loss(time_steps,
                                           actions,
                                           next_time_steps,
                                           weights=weights,
                                           training=True)
        tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
        critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
        self._apply_gradients(critic_grads, trainable_critic_variables,
                              self._critic_optimizer)

        trainable_actor_variables = self._actor_network.trainable_variables
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_actor_variables, (
                'No trainable actor variables to '
                'optimize.')
            tape.watch(trainable_actor_variables)
            actor_loss = self.actor_loss(time_steps,
                                         weights=weights,
                                         training=True)
        tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
        actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
        self._apply_gradients(actor_grads, trainable_actor_variables,
                              self._actor_optimizer)

        self.train_step_counter.assign_add(1)
        self._update_target()

        # TODO(b/124382360): Compute per element TD loss and return in loss_info.
        total_loss = actor_loss + critic_loss
        return tf_agent.LossInfo(total_loss, DdpgInfo(actor_loss, critic_loss))
Example 9
def get_target_and_main_values(experience, agents, nameDict, networkDict):
    """QMIX - get the Q values from the target network and main network of all the agents"""
    total_agents_target = []
    total_agents_main = []
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience,
                                             squeeze_time_dim=True))
    for i, flexAgent in enumerate(agents):
        for node in nameDict:
            target = None
            for type, names in nameDict[node].items():
                if flexAgent.id in names:
                    target = []
                    main = []
                    for net in networkDict[node][type]:
                        action_index = -1
                        for t in range(24):
                            action_index += 1
                            actions = tf.gather(policy_steps.action[i],
                                                indices=action_index,
                                                axis=-1)
                            individual_target = net._compute_next_q_values(
                                next_time_steps, index=i, time=t)
                            individual_main = net._compute_q_values(
                                time_steps,
                                actions,
                                index=i,
                                time=t,
                                training=True)
                            target.append(
                                tf.reshape(individual_target, [-1, 1]))
                            main.append(tf.reshape(individual_main, [-1, 1]))
                    break
            if target is not None:
                break
        total_agents_target.append(tf.concat(target, -1))
        total_agents_main.append(tf.concat(main, -1))
    total_agents_target = tf.concat(total_agents_target, -1)
    total_agents_main = tf.concat(total_agents_main, -1)
    return time_steps, policy_steps, next_time_steps, total_agents_target, total_agents_main
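A small aside on the hour-by-hour indexing used above: tf.gather with a scalar index along axis=-1 drops that axis, which is why the per-hour target and main values are reshaped to [-1, 1] before being concatenated. The toy shapes below are assumptions for illustration only, not taken from the project.

import tensorflow as tf

# Hypothetical per-agent action tensor: [batch_size, 24] hourly actions.
agent_actions = tf.random.uniform([5, 24])
hour = 3
hourly_action = tf.gather(agent_actions, indices=hour, axis=-1)
print(hourly_action.shape)  # (5,) -- the hour axis is dropped by the scalar index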
Example 10
    def _train(self, experience, weights):
        """Returns a train op to update the agent's networks.

        This method trains with the provided batched experience.

        Args:
          experience: A time-stacked trajectory object.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.

        Returns:
          A train_op.

        Raises:
          ValueError: If optimizers are None and no default value was provided to
            the constructor.
        """
        squeeze_time_dim = not self._critic_network_1.state_spec
        time_steps, policy_steps, next_time_steps = (
            trajectory.experience_to_transitions(experience, squeeze_time_dim))
        actions = policy_steps.action

        trainable_critic_variables = list(
            object_identity.ObjectIdentitySet(
                self._critic_network_1.trainable_variables +
                self._critic_network_2.trainable_variables))

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_critic_variables, (
                'No trainable critic variables to '
                'optimize.')
            tape.watch(trainable_critic_variables)
            critic_loss = self._critic_loss_weight * self.critic_loss(
                time_steps,
                actions,
                next_time_steps,
                td_errors_loss_fn=self._td_errors_loss_fn,
                gamma=self._gamma,
                reward_scale_factor=self._reward_scale_factor,
                weights=weights,
                training=True)

        tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
        critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
        self._apply_gradients(critic_grads, trainable_critic_variables,
                              self._critic_optimizer)

        critic_no_entropy_loss = None
        if self._critic_network_no_entropy_1 is not None:
            trainable_critic_no_entropy_variables = list(
                object_identity.ObjectIdentitySet(
                    self._critic_network_no_entropy_1.trainable_variables +
                    self._critic_network_no_entropy_2.trainable_variables))
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                assert trainable_critic_no_entropy_variables, (
                    'No trainable critic_no_entropy variables to optimize.')
                tape.watch(trainable_critic_no_entropy_variables)
                critic_no_entropy_loss = self._critic_loss_weight * self.critic_no_entropy_loss(
                    time_steps,
                    actions,
                    next_time_steps,
                    td_errors_loss_fn=self._td_errors_loss_fn,
                    gamma=self._gamma,
                    reward_scale_factor=self._reward_scale_factor,
                    weights=weights,
                    training=True)

            tf.debugging.check_numerics(
                critic_no_entropy_loss,
                'Critic (without entropy) loss is inf or nan.')
            critic_no_entropy_grads = tape.gradient(
                critic_no_entropy_loss, trainable_critic_no_entropy_variables)
            self._apply_gradients(critic_no_entropy_grads,
                                  trainable_critic_no_entropy_variables,
                                  self._critic_no_entropy_optimizer)

        trainable_actor_variables = self._actor_network.trainable_variables
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert trainable_actor_variables, (
                'No trainable actor variables to '
                'optimize.')
            tape.watch(trainable_actor_variables)
            actor_loss = self._actor_loss_weight * self.actor_loss(
                time_steps, weights=weights)
        tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
        actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
        self._apply_gradients(actor_grads, trainable_actor_variables,
                              self._actor_optimizer)

        alpha_variable = [self._log_alpha]
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert alpha_variable, 'No alpha variable to optimize.'
            tape.watch(alpha_variable)
            alpha_loss = self._alpha_loss_weight * self.alpha_loss(
                time_steps, weights=weights)
        tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
        alpha_grads = tape.gradient(alpha_loss, alpha_variable)
        self._apply_gradients(alpha_grads, alpha_variable,
                              self._alpha_optimizer)

        with tf.name_scope('Losses'):
            tf.compat.v2.summary.scalar(name='critic_loss_' + self.name,
                                        data=critic_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='actor_loss_' + self.name,
                                        data=actor_loss,
                                        step=self.train_step_counter)
            tf.compat.v2.summary.scalar(name='alpha_loss_' + self.name,
                                        data=alpha_loss,
                                        step=self.train_step_counter)
            if critic_no_entropy_loss is not None:
                tf.compat.v2.summary.scalar(name='critic_no_entropy_loss_' +
                                            self.name,
                                            data=critic_no_entropy_loss,
                                            step=self.train_step_counter)

        self.train_step_counter.assign_add(1)
        self._update_target()

        total_loss = critic_loss + actor_loss + alpha_loss
        if critic_no_entropy_loss is not None:
            total_loss += critic_no_entropy_loss

        extra = SacLossInfo(critic_loss=critic_loss,
                            actor_loss=actor_loss,
                            alpha_loss=alpha_loss,
                            critic_no_entropy_loss=critic_no_entropy_loss)

        return tf_agent.LossInfo(loss=total_loss, extra=extra)
Example 11
    def _loss(self,
              experience,
              td_errors_loss_fn=common.element_wise_huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes loss for DQN training.

        Args:
          experience: A batch of experience data in the form of a `Trajectory`. The
            structure of `experience` must match that of `self.policy.step_spec`.
            All tensors in `experience` must be shaped `[batch, time, ...]` where
            `time` must be equal to `self.train_sequence_length` if that
            property is not `None`.
          td_errors_loss_fn: A function(td_targets, predictions) to compute the
            element wise loss.
          gamma: Discount for future rewards.
          reward_scale_factor: Multiplicative factor to scale rewards.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.  The output td_loss will be scaled by these weights, and
            the final scalar loss is the mean of these values.
          training: Whether this loss is being used for training.

        Returns:
          loss: An instance of `DqnLossInfo`.
        Raises:
          ValueError:
            if the number of actions is greater than 1.
        """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        squeeze_time_dim = not self._q_network.state_spec
        if self._n_step_update == 1:
            time_steps, policy_steps, next_time_steps = (
                trajectory.experience_to_transitions(experience,
                                                     squeeze_time_dim))
            actions = policy_steps.action
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, policy_steps, _ = (
                trajectory.experience_to_transitions(first_two_steps,
                                                     squeeze_time_dim))
            actions = policy_steps.action
            _, _, next_time_steps = (trajectory.experience_to_transitions(
                last_two_steps, squeeze_time_dim))

        with tf.name_scope('loss'):
            q_values = self._compute_q_values(time_steps,
                                              actions,
                                              training=training)

            next_q_values = self._compute_next_q_values(
                next_time_steps, policy_steps.info)

            if self._n_step_update == 1:
                # Special case for n = 1 to avoid a loss of performance.
                td_targets = compute_td_targets(
                    next_q_values,
                    rewards=reward_scale_factor * next_time_steps.reward,
                    discounts=gamma * next_time_steps.discount)
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                td_targets = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=next_q_values,
                    time_major=False,
                    provide_all_returns=False)

            valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
            td_error = valid_mask * (td_targets - q_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weight would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._q_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                diff_q_values = q_values - next_q_values
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_values', q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_q_values',
                                                 next_q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
Example 12
    def _loss(self,
              experience,
              td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes critic loss for CategoricalDQN training.

        See Algorithm 1 and the discussion immediately preceding it on page 6 of
        "A Distributional Perspective on Reinforcement Learning"
          Bellemare et al., 2017
          https://arxiv.org/abs/1707.06887

        Args:
          experience: A batch of experience data in the form of a `Trajectory`. The
            structure of `experience` must match that of `self.policy.step_spec`.
            All tensors in `experience` must be shaped `[batch, time, ...]` where
            `time` must be equal to `self.required_experience_time_steps` if that
            property is not `None`.
          td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
          gamma: Discount for future rewards.
          reward_scale_factor: Multiplicative factor to scale rewards.
          weights: Optional weights used for importance sampling.
          training: Whether the loss is being used for training.
        Returns:
          critic_loss: A scalar critic loss.
        Raises:
          ValueError:
            if the number of actions is greater than 1.
        """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        squeeze_time_dim = not self._q_network.state_spec
        if self._n_step_update == 1:
            time_steps, policy_steps, next_time_steps = (
                trajectory.experience_to_transitions(experience,
                                                     squeeze_time_dim))
            actions = policy_steps.action
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, policy_steps, _ = (
                trajectory.experience_to_transitions(first_two_steps,
                                                     squeeze_time_dim))
            actions = policy_steps.action
            _, _, next_time_steps = (trajectory.experience_to_transitions(
                last_two_steps, squeeze_time_dim))

        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            rank = nest_utils.get_outer_rank(time_steps.observation,
                                             self._time_step_spec.observation)

            # If inputs have a time dimension and the q_network is stateful,
            # combine the batch and time dimension.
            batch_squash = (None if rank <= 1 or self._q_network.state_spec
                            in ((), None) else network_utils.BatchSquash(rank))

            network_observation = time_steps.observation

            if self._observation_and_action_constraint_splitter is not None:
                network_observation, _ = (
                    self._observation_and_action_constraint_splitter(
                        network_observation))

            # q_logits contains the Q-value logits for all actions.
            q_logits, _ = self._q_network(network_observation,
                                          time_steps.step_type,
                                          training=training)

            if batch_squash is not None:
                # Squash outer dimensions into a single dimension to facilitate
                # computing the loss below. Required for supporting temporal
                # inputs, for example.
                q_logits = batch_squash.flatten(q_logits)
                actions = batch_squash.flatten(actions)
                next_time_steps = tf.nest.map_structure(
                    batch_squash.flatten, next_time_steps)

            next_q_distribution = self._next_q_distribution(next_time_steps)

            if actions.shape.rank > 1:
                actions = tf.squeeze(actions,
                                     list(range(1, actions.shape.rank)))

            # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
            # support of Z_{\theta} (see Figure 1 in paper).
            batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
            tiled_support = tf.tile(self._support, [batch_size])
            tiled_support = tf.reshape(tiled_support,
                                       [batch_size, self._num_atoms])

            if self._n_step_update == 1:
                discount = next_time_steps.discount
                if discount.shape.rank == 1:
                    # We expect discount to have a shape of [batch_size], while
                    # tiled_support will have a shape of [batch_size, num_atoms]. To
                    # multiply these, we add a second dimension of 1 to the discount.
                    discount = tf.expand_dims(discount, -1)
                next_value_term = tf.multiply(discount,
                                              tiled_support,
                                              name='next_value_term')

                reward = next_time_steps.reward
                if reward.shape.rank == 1:
                    # See the explanation above.
                    reward = tf.expand_dims(reward, -1)
                reward_term = tf.multiply(reward_scale_factor,
                                          reward,
                                          name='reward_term')

                target_support = tf.add(reward_term,
                                        gamma * next_value_term,
                                        name='target_support')
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                discounted_returns = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=tf.zeros([batch_size], dtype=discounts.dtype),
                    time_major=False,
                    provide_all_returns=False)

                # Convert discounted_returns from [batch_size] to [batch_size, 1]
                discounted_returns = tf.expand_dims(discounted_returns, -1)

                final_value_discount = tf.reduce_prod(discounts, axis=1)
                final_value_discount = tf.expand_dims(final_value_discount, -1)

                # Save the values of discounted_returns and final_value_discount in
                # order to check them in unit tests.
                self._discounted_returns = discounted_returns
                self._final_value_discount = final_value_discount

                target_support = tf.add(discounted_returns,
                                        final_value_discount * tiled_support,
                                        name='target_support')

            target_distribution = tf.stop_gradient(
                project_distribution(target_support, next_q_distribution,
                                     self._support))

            # Obtain the current Q-value logits for the selected actions.
            indices = tf.range(batch_size)
            indices = tf.cast(indices, actions.dtype)
            reshaped_actions = tf.stack([indices, actions], axis=-1)
            chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

            # Compute the cross-entropy loss between the logits. If inputs have
            # a time dimension, compute the sum over the time dimension before
            # computing the mean over the batch dimension.
            if batch_squash is not None:
                target_distribution = batch_squash.unflatten(
                    target_distribution)
                chosen_action_logits = batch_squash.unflatten(
                    chosen_action_logits)
                critic_loss = tf.reduce_sum(
                    tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                        labels=target_distribution,
                        logits=chosen_action_logits),
                    axis=1)
            else:
                critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                    labels=target_distribution, logits=chosen_action_logits)

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            dict_losses = {
                'critic_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(dict_losses,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._debug_summaries:
                distribution_errors = target_distribution - chosen_action_logits
                with tf.name_scope('distribution_errors'):
                    common.generate_tensor_summaries(
                        'distribution_errors',
                        distribution_errors,
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean',
                        tf.reduce_mean(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean_abs',
                        tf.reduce_mean(tf.abs(distribution_errors)),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'max',
                        tf.reduce_max(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'min',
                        tf.reduce_min(distribution_errors),
                        step=self.train_step_counter)
                with tf.name_scope('target_distribution'):
                    common.generate_tensor_summaries(
                        'target_distribution',
                        target_distribution,
                        step=self.train_step_counter)

            # TODO(b/127318640): Give appropriate values for td_loss and td_error for
            # prioritized replay.
            return tf_agent.LossInfo(
                total_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
Example 13
  def _train(self, experience, weights, augmented_obs=None,
             augmented_next_obs=None):
    """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object. If augmentations > 1 then a
        tuple of the form: ``` (trajectory, [augmentation_1, ... ,
          augmentation_{K-1}]) ``` is expected.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      augmented_obs: List of length num_augmentations - 1 of random crops of the
        trajectory's observation.
      augmented_next_obs: List of length num_augmentations - 1 of random crops
        of the trajectory's next_observation.

    Returns:
      A train_op.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """
    squeeze_time_dim = not self._critic_network_1.state_spec

    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))

    actions = policy_steps.action

    trainable_critic_variables = (
        self._critic_network_1.trainable_variables +
        self._critic_network_2.trainable_variables)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_critic_variables, ('No trainable critic variables to '
                                          'optimize.')
      tape.watch(trainable_critic_variables)

      critic_loss = self._critic_loss_weight * self.critic_loss(
          time_steps,
          actions,
          next_time_steps,
          augmented_obs,
          augmented_next_obs,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          reward_scale_factor=self._reward_scale_factor,
          weights=weights,
          training=True)

    tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
    critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
    self._apply_gradients(critic_grads, trainable_critic_variables,
                          self._critic_optimizer)

    total_loss = critic_loss
    actor_loss = tf.constant(0.0, tf.float32)
    alpha_loss = tf.constant(0.0, tf.float32)

    with tf.name_scope('Losses'):
      tf.compat.v2.summary.scalar(
          name='critic_loss', data=critic_loss, step=self.train_step_counter)

    # Only perform actor and alpha updates periodically
    if self.train_step_counter % self._actor_update_frequency == 0:
      trainable_actor_variables = self._actor_network.trainable_variables
      with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert trainable_actor_variables, ('No trainable actor variables to '
                                           'optimize.')
        tape.watch(trainable_actor_variables)
        actor_loss = self._actor_loss_weight * self.actor_loss(
            time_steps, weights=weights)
      tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
      actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
      self._apply_gradients(actor_grads, trainable_actor_variables,
                            self._actor_optimizer)

      alpha_variable = [self._log_alpha]
      with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert alpha_variable, 'No alpha variable to optimize.'
        tape.watch(alpha_variable)
        alpha_loss = self._alpha_loss_weight * self.alpha_loss(
            time_steps, weights=weights)
      tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
      alpha_grads = tape.gradient(alpha_loss, alpha_variable)
      self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)

      with tf.name_scope('Losses'):
        tf.compat.v2.summary.scalar(
            name='actor_loss', data=actor_loss, step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            name='alpha_loss', data=alpha_loss, step=self.train_step_counter)

      total_loss = critic_loss + actor_loss + alpha_loss

    self.train_step_counter.assign_add(1)
    self._update_target()

    # NOTE: Consider keeping track of previous actor/alpha loss.
    extra = sac_agent.SacLossInfo(
        critic_loss=critic_loss, actor_loss=actor_loss, alpha_loss=alpha_loss)

    return tf_agent.LossInfo(loss=total_loss, extra=extra)