Exemplo n.º 1
0
def _compute_critic_loss(sampled_q_t_all: tf.Tensor, q_tm1_all: tf.Tensor,
                         r_t_all: tf.Tensor, d_t: tf.Tensor, discount: float,
                         num_samples: int, num_critic_heads: int):
    """Compute loss and sampled Q-values for (non-distributional) critics.

    Shapes below use N = num_samples, B = batch size (taken from the leading
    dimension of `r_t_all`) and C = num_critic_heads.

    Args:
      sampled_q_t_all: target critic values of the sampled actions; reshaped
        here to [N, B, C].
      q_tm1_all: online critic values; flattened here to [B*C] (assumed to
        arrive as [B, C] -- confirm against the caller).
      r_t_all: per-head rewards; flattened here to [B*C].
      d_t: rank-1 environment discount with one entry per batch element.
      discount: additional discount, cast to the dtype of `d_t`.
      num_samples: number of sampled actions per batch element.
      num_critic_heads: number of critic heads.

    Returns:
      A `(critic_loss, sampled_q_t)` tuple: the mean TD-learning loss over
      all batch elements and heads, and the sampled Q-values reshaped to
      [N, B, C].
    """
    # Reshape Q-value samples back to original batch dimensions and average
    # them to compute the TD-learning bootstrap target.
    batch_size = r_t_all.get_shape()[0]
    sampled_q_t = tf.reshape(
        sampled_q_t_all,
        (num_samples, batch_size, num_critic_heads))  # [N,B,C]
    q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

    # Flatten q_t and q_tm1; necessary for trfl.td_learning
    q_t = tf.reshape(q_t, [-1])  # [B*C]
    q_tm1 = tf.reshape(q_tm1_all, [-1])  # [B*C]

    # Flatten r_t_all; necessary for trfl.td_learning
    r_t_all = tf.reshape(r_t_all, [-1])  # [B*C]

    # Broadcast d_t across critic heads to [B, C] and only then flatten, so
    # that entry b*C + c holds d_t[b], exactly like the row-major flattening
    # of q_t, q_tm1 and r_t_all above. Tiling the rank-1 d_t directly would
    # instead produce a head-major layout (entry c*B + b) and pair each TD
    # target with another transition's discount whenever C > 1.
    d_t = tf.reshape(
        tf.tile(tf.expand_dims(d_t, axis=-1), [1, num_critic_heads]),
        [-1])  # [B*C]
    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(discount, dtype=d_t.dtype)

    # Critic loss.
    critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t, q_t).loss
    critic_loss = tf.reduce_mean(critic_loss)
    return critic_loss, sampled_q_t
Exemplo n.º 2
0
  def dev_critic_loss(self, dev_dataset=None):
    """Return the critic loss averaged over the batches of `dev_dataset`.

    Every element of `dev_dataset` must expose a `data` sequence whose first
    five entries unpack as (o_tm1, a_tm1, r_t, d_t, o_t). The loss of each
    batch is reduced over axis 0 and the per-batch results are averaged.
    """
    running_total = 0.
    num_batches = 0.
    for batch in dev_dataset:
      o_tm1, a_tm1, r_t, d_t, o_t = batch.data[:5]

      # Combine the learner's discount (cast to the environment discount
      # dtype) with the per-transition environment discount.
      pcont = tf.cast(self._discount, dtype=d_t.dtype) * d_t

      # Value of the policy's action at o_t, then of the logged action at
      # o_tm1; both come from the online critic.
      next_value = self._critic_network(o_t, self._policy_network(o_t))
      prev_value = self._critic_network(o_tm1, a_tm1)

      if not self._distributional:
        # Drop the trailing axis to get the shape td_learning expects.
        prev_value = tf.squeeze(prev_value, axis=-1)  # [B]
        next_value = tf.squeeze(next_value, axis=-1)  # [B]
        batch_loss = trfl.td_learning(prev_value, r_t, pcont, next_value).loss
      else:
        batch_loss = losses.categorical(prev_value, r_t, pcont, next_value)

      running_total = running_total + tf.reduce_mean(batch_loss, axis=[0])
      num_batches = num_batches + 1.
    return running_total / num_batches
Exemplo n.º 3
0
    def _step(self):
        """Perform a single learner update and return the losses to track.

        Copies the online variables into the target networks every
        `self._target_update_period` steps, draws one batch from the replay
        iterator, and applies one optimizer update to the policy (DPG loss)
        and one to the critic plus observation network (TD-learning loss).
        """
        # Line up online and target variables in matching order:
        # observation, critic, then policy network.
        src_variables = (
            tuple(self._observation_network.variables) +
            tuple(self._critic_network.variables) +
            tuple(self._policy_network.variables))
        dst_variables = (
            tuple(self._target_observation_network.variables) +
            tuple(self._target_critic_network.variables) +
            tuple(self._target_policy_network.variables))

        # Periodic hard copy of the online weights into the target networks.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for online_var, target_var in zip(src_variables, dst_variables):
                target_var.assign(online_var)
        self._num_steps.assign_add(1)

        # Pull the next replay sample; only its transition payload is used.
        sample = next(self._iterator)
        batch: types.Transition = sample.data

        # Express the learner's discount in the environment discount dtype.
        gamma = tf.cast(self._discount, dtype=batch.discount.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Embed the observations once up front, so that policy and
            # critic consume the same observation-network output.
            embed_tm1 = self._observation_network(batch.observation)
            embed_t = self._target_observation_network(
                batch.next_observation)
            # Cut gradients at the target embedding. Because the online
            # policy is evaluated on it below, the policy loss therefore
            # never trains the observation network.
            embed_t = tree.map_structure(tf.stop_gradient, embed_t)

            # Critic: online value of the stored action and target value of
            # the target policy's action.
            online_q = self._critic_network(embed_tm1, batch.action)
            target_action = self._target_policy_network(embed_t)
            target_q = self._target_critic_network(embed_t, target_action)

            # Drop the trailing axis to get the shape td_learning expects.
            online_q = tf.squeeze(online_q, axis=-1)  # [B]
            target_q = tf.squeeze(target_q, axis=-1)  # [B]

            td_output = trfl.td_learning(online_q, batch.reward,
                                         gamma * batch.discount, target_q)
            critic_loss = tf.reduce_mean(td_output.loss, axis=0)

            # Actor: online policy action at the next embedding, scored by
            # the online critic.
            dpg_a_t = self._policy_network(embed_t)
            dpg_q_t = self._critic_network(embed_t, dpg_a_t)

            # With clipping enabled, clip dq/da at 1.0 and clip by norm.
            if self._clipping:
                dqda_clipping = 1.0
            else:
                dqda_clipping = None
            per_example_policy_loss = losses.dpg(
                dpg_q_t,
                dpg_a_t,
                tape=tape,
                dqda_clipping=dqda_clipping,
                clip_norm=self._clipping)
            policy_loss = tf.reduce_mean(per_example_policy_loss, axis=0)

        # Variables trained by each loss; the critic loss is the one that
        # trains the observation network in this agent.
        policy_variables = self._policy_network.trainable_variables
        critic_variables = (
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)

        # The tape is persistent, so it can be queried once per loss.
        policy_gradients = tape.gradient(policy_loss, policy_variables)
        critic_gradients = tape.gradient(critic_loss, critic_variables)

        # A persistent tape has to be released explicitly.
        del tape

        # Optional global-norm clipping of both gradient sets.
        if self._clipping:
            policy_gradients, _ = tf.clip_by_global_norm(
                policy_gradients, 40.)
            critic_gradients, _ = tf.clip_by_global_norm(
                critic_gradients, 40.)

        # Apply the updates.
        self._policy_optimizer.apply(policy_gradients, policy_variables)
        self._critic_optimizer.apply(critic_gradients, critic_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        return fetches
Exemplo n.º 4
0
    def _step(self) -> types.NestedTensor:
        """Run one multi-objective learner update and return logged statistics.

        The step (1) periodically copies online variables into the target
        policy and into the target critic/observation networks, (2) pulls one
        batch from the replay iterator, (3) computes the critic loss
        (distributional when the target critic returns a list, TD-learning
        otherwise) and the policy loss from `self._policy_loss_module`, and
        (4) applies the critic, policy and dual gradients.

        Returns:
          A dict with `critic_loss` and `policy_loss`, extended with the
          statistics returned by the policy loss module and one
          `<objective name>_reward` mean per objective defining `reward_fn`.
        """
        # Update target network.
        online_policy_variables = self._policy_network.variables
        target_policy_variables = self._target_policy_network.variables
        online_critic_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
        )
        target_critic_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
        )

        # Make online policy -> target policy network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for src, dest in zip(online_policy_variables,
                                 target_policy_variables):
                dest.assign(src)
        # Make online critic -> target critic network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for src, dest in zip(online_critic_variables,
                                 target_critic_variables):
                dest.assign(src)

        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

        # Get the batch size from the leading dimension of the rewards.
        batch_size = r_t.shape[0]

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=d_t.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            # This stop_gradient prevents gradients to propagate into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(self._target_observation_network(o_t))

            # Get online and target action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Sample actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)

            # Tile embedded observations to feed into the target critic network.
            # Note: this is more efficient than tiling before the embedding layer.
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute target-estimated distributional value of sampled actions at o_t.
            sampled_q_t_all = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Compute online critic value distribution of a_tm1 in state o_tm1.
            q_tm1_all = self._critic_network(o_tm1, a_tm1)

            # Compute rewards for objectives with defined reward_fn
            reward_stats = {}
            r_t_all = []
            for objective in self._objectives:
                if hasattr(objective, 'reward_fn'):
                    r = objective.reward_fn(o_tm1, a_tm1, r_t)
                    reward_stats['{}_reward'.format(
                        objective.name)] = tf.reduce_mean(r)
                    r_t_all.append(r)
            r_t_all = tf.stack(r_t_all, axis=-1)
            r_t_all.get_shape().assert_has_rank(2)  # [B, C]

            if isinstance(sampled_q_t_all, list):  # Distributional critics
                # Compute average logits by first reshaping them and normalizing them
                # across atoms.
                critic_losses = []
                sampled_q_ts = []
                for idx, (sampled_q_t_distributions,
                          q_tm1_distribution) in enumerate(
                              zip(sampled_q_t_all, q_tm1_all)):
                    # Compute loss for distributional critic for objective c
                    sampled_logits = tf.reshape(
                        sampled_q_t_distributions.logits,
                        [self._num_samples, batch_size, -1])  # [N, B, A]
                    sampled_logprobs = tf.math.log_softmax(sampled_logits,
                                                           axis=-1)
                    averaged_logits = tf.reduce_logsumexp(sampled_logprobs,
                                                          axis=0)

                    # Construct the expected distributional value for bootstrapping.
                    q_t_distribution = networks.DiscreteValuedDistribution(
                        values=sampled_q_t_distributions.values,
                        logits=averaged_logits)

                    # Compute critic distributional loss.
                    critic_loss = losses.categorical(q_tm1_distribution,
                                                     r_t_all[:, idx],
                                                     discount * d_t,
                                                     q_t_distribution)
                    critic_losses.append(tf.reduce_mean(critic_loss))

                    # Compute Q-values of sampled actions and reshape to [N, B].
                    sampled_q_ts.append(
                        tf.reshape(sampled_q_t_distributions.mean(),
                                   (self._num_samples, -1)))

                critic_loss = tf.reduce_mean(critic_losses)
                sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
            else:
                # Reshape Q-value samples back to original batch dimensions and average
                # them to compute the TD-learning bootstrap target.
                sampled_q_t = tf.reshape(sampled_q_t_all,
                                         (self._num_samples, batch_size,
                                          self._num_critic_heads))  # [N,B,C]
                q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

                # Flatten q_t and q_tm1; necessary for trfl.td_learning
                q_t = tf.reshape(q_t, [-1])  # [B*C]
                q_tm1 = tf.reshape(q_tm1_all, [-1])  # [B*C]

                # Flatten r_t_all; necessary for trfl.td_learning
                r_t_all = tf.reshape(r_t_all, [-1])  # [B*C]

                # Broadcast d_t across critic heads to [B, C] and only then
                # flatten, so that entry b*C + c holds d_t[b], matching the
                # row-major flattening of q_t, q_tm1 and r_t_all. Tiling the
                # rank-1 d_t directly would give a head-major layout and pair
                # TD targets with another transition's discount when C > 1.
                d_t_all = tf.reshape(
                    tf.tile(tf.expand_dims(d_t, axis=-1),
                            [1, self._num_critic_heads]),
                    [-1])  # [B*C]

                # Critic loss.
                critic_loss = trfl.td_learning(q_tm1, r_t_all,
                                               discount * d_t_all, q_t).loss
                critic_loss = tf.reduce_mean(critic_loss)

            # Add sampled Q-values for objectives with defined objective_fn
            sampled_q_idx = 0
            sampled_q_t_k = []
            for objective in self._objectives:
                if hasattr(objective, 'reward_fn'):
                    sampled_q_t_k.append(
                        tf.stop_gradient(sampled_q_t[..., sampled_q_idx]))
                    sampled_q_idx += 1
                if hasattr(objective, 'objective_fn'):
                    sampled_q_t_k.append(
                        tf.stop_gradient(
                            objective.objective_fn(sampled_actions,
                                                   sampled_q_t)))
            sampled_q_t_k = tf.stack(sampled_q_t_k, axis=-1)  # [N, B, K]

            # Compute MPO policy loss.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_t_k)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients. The dual gradients are left unclipped.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)  # Log MPO stats.
        fetches.update(reward_stats)  # Log reward stats.

        return fetches
Exemplo n.º 5
0
    def _step(self) -> types.Nest:
        """Run one MPO learner update and return the statistics to log.

        Periodically copies online variables into the target policy and the
        target critic/observation networks, draws one batch from the replay
        iterator, and applies one update each to the critic (plus observation
        network), the policy and the dual variables of the policy loss module.
        """
        # Online/target variable pairs: the policy on its own, the critic
        # together with the observation network.
        policy_sources = self._policy_network.variables
        policy_targets = self._target_policy_network.variables
        critic_sources = (
            tuple(self._observation_network.variables) +
            tuple(self._critic_network.variables))
        critic_targets = (
            tuple(self._target_observation_network.variables) +
            tuple(self._target_critic_network.variables))

        # Periodic hard copy: online policy -> target policy.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for online_var, target_var in zip(policy_sources, policy_targets):
                target_var.assign(online_var)
        # Periodic hard copy: online critic -> target critic.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for online_var, target_var in zip(critic_sources, critic_targets):
                target_var.assign(online_var)

        # Count this learner step for the periodic updates above.
        self._num_steps.assign_add(1)

        # Pull the next replay sample; only its transition payload is used.
        sample = next(self._iterator)
        batch: types.Transition = sample.data

        # Express the learner's discount in the environment discount dtype.
        gamma = tf.cast(self._discount, dtype=batch.discount.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Embed the observations once up front, so that policy and
            # critic consume the same observation-network output.
            embed_tm1 = self._observation_network(batch.observation)
            # Cut gradients at the target embedding. Because the online
            # policy is evaluated on it below, the policy loss therefore
            # never trains the observation network.
            embed_t = tf.stop_gradient(
                self._target_observation_network(batch.next_observation))

            # Online and target action distributions at the next embedding.
            online_action_distribution = self._policy_network(embed_t)
            target_action_distribution = self._target_policy_network(embed_t)

            # Draw actions from the target policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)
            tiled_embed_t = tf2_utils.tile_tensor(
                embed_t, self._num_samples)  # [N, B, ...]

            # Score the sampled actions with the target critic after merging
            # the two leading dimensions into one of size N*B.
            flat_embed_t = snt.merge_leading_dims(tiled_embed_t, num_dims=2)
            flat_actions = snt.merge_leading_dims(sampled_actions, num_dims=2)
            sampled_values = self._target_critic_network(flat_embed_t,
                                                         flat_actions)

            # Restore the sample dimension and average over it to obtain the
            # TD-learning bootstrap target.
            sampled_values = tf.reshape(
                sampled_values, (self._num_samples, -1))  # [N, B]
            bootstrap_value = tf.reduce_mean(sampled_values, axis=0)  # [B]

            # Online critic value of the stored action; the trailing axis is
            # squeezed away because trfl.td_learning needs [B].
            online_value = self._critic_network(
                embed_tm1, batch.action)  # [B, 1]
            online_value = tf.squeeze(online_value, axis=-1)  # [B]

            # Critic loss.
            td_output = trfl.td_learning(online_value, batch.reward,
                                         gamma * batch.discount,
                                         bootstrap_value)
            critic_loss = tf.reduce_mean(td_output.loss)

            # Policy loss and its statistics from the loss module.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_values)

        # Spell out which variables each loss trains. The critic loss is the
        # one that trains the observation network in this agent.
        critic_trainable_variables = (
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # MPO dual variables live in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # The tape is persistent, so it can be queried once per loss.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # A persistent tape has to be released explicitly.
        del tape

        # Optional global-norm clipping of policy and critic gradients.
        if self._clipping:
            clipped_policy, _ = tf.clip_by_global_norm(policy_gradients, 40.)
            policy_gradients = tuple(clipped_policy)
            clipped_critic, _ = tf.clip_by_global_norm(critic_gradients, 40.)
            critic_gradients = tuple(clipped_critic)

        # Apply the updates.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track, followed by the MPO statistics.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)

        return fetches
Exemplo n.º 6
0
  def _step(self) -> Dict[str, tf.Tensor]:
    """Run one critic-only update and return the values to track.

    Draws one batch from the replay iterator, updates the critic network on
    the categorical or TD-learning loss, periodically copies the critic into
    the target critic, and reports under `q_s0` the mean critic value at the
    stored initial observations every 100 steps (0 otherwise).
    """
    # Pull the next replay sample; only its first five fields are used.
    batch = next(self._iterator)
    o_tm1, a_tm1, r_t, d_t, o_t = batch.data[:5]

    # Combine the learner's discount (cast to the environment discount dtype)
    # with the per-transition environment discount.
    pcont = tf.cast(self._discount, dtype=d_t.dtype) * d_t

    # Bootstrap target, computed outside the gradient tape.
    next_action = self._policy_network(o_t)
    q_t = self._target_critic_network(o_t, next_action)
    if not (self._distributional or self._vmin is None):
      q_t = tf.clip_by_value(q_t, self._vmin, self._vmax)
      logging.info('Clip target critic network output with [%f, %f]',
                   self._vmin, self._vmax)

    with tf.GradientTape() as tape:
      # Online critic value of the stored action.
      q_tm1 = self._critic_network(o_tm1, a_tm1)

      if not self._distributional:
        # Drop the trailing axis to get the shape td_learning expects.
        q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]
        q_t = tf.squeeze(q_t, axis=-1)  # [B]
        per_example_loss = trfl.td_learning(q_tm1, r_t, pcont, q_t).loss
      else:
        per_example_loss = losses.categorical(q_tm1, r_t, pcont, q_t)

      critic_loss = tf.reduce_mean(per_example_loss, axis=[0])

    # Differentiate with respect to the critic's trainable variables only.
    critic_variables = self._critic_network.trainable_variables
    critic_gradients = tape.gradient(critic_loss, critic_variables)

    # Optional global-norm clipping.
    if self._clipping:
      critic_gradients, _ = tf.clip_by_global_norm(critic_gradients, 40.)

    # Apply the update.
    self._critic_optimizer.apply(critic_gradients, critic_variables)

    # Periodic hard copy of the online critic into the target critic.
    online_vars = self._critic_network.variables
    target_vars = self._target_critic_network.variables
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
      for online_var, target_var in zip(online_vars, target_vars):
        target_var.assign(online_var)

    # Critic estimate at the initial observations: refreshed every 100 steps
    # when such observations were provided, zero in every other case.
    init_critic = tf.constant(0.)
    if self._init_observations is not None:
      if tf.math.mod(self._num_steps, 100) == 0:
        init_obs = tree.map_structure(tf.convert_to_tensor,
                                      self._init_observations)
        init_actions = self._policy_network(init_obs)
        init_critic = tf.reduce_mean(self._critic_mean(init_obs, init_actions))

    self._num_steps.assign_add(1)

    # Losses to track.
    fetches = {
        'critic_loss': critic_loss,
        'q_s0': init_critic,
    }
    return fetches