Example #1
def _compute_distributional_critic_loss(
    sampled_q_t_all: List[tf.Tensor],
    q_tm1_all: List[tf.Tensor],
    r_t_all: tf.Tensor,
    d_t: tf.Tensor,
    discount: float,
    num_samples: int):
  """Compute loss and sampled Q-values for distributional critics."""
  # Compute average logits by first reshaping them and normalizing them
  # across atoms.
  batch_size = r_t_all.get_shape()[0]
  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(discount, dtype=d_t.dtype)
  critic_losses = []
  sampled_q_ts = []
  for idx, (sampled_q_t_distributions, q_tm1_distribution) in enumerate(
      zip(sampled_q_t_all, q_tm1_all)):
    # Compute loss for distributional critic for objective c
    sampled_logits = tf.reshape(
        sampled_q_t_distributions.logits,
        [num_samples, batch_size, -1])  # [N, B, A]
    sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
    averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

    # Construct the expected distributional value for bootstrapping.
    q_t_distribution = networks.DiscreteValuedDistribution(
        values=sampled_q_t_distributions.values, logits=averaged_logits)

    # Compute critic distributional loss.
    critic_loss = losses.categorical(
        q_tm1_distribution, r_t_all[:, idx], discount * d_t,
        q_t_distribution)
    critic_losses.append(tf.reduce_mean(critic_loss))

    # Compute Q-values of sampled actions and reshape to [N, B].
    sampled_q_ts.append(tf.reshape(
        sampled_q_t_distributions.mean(), (num_samples, -1)))

  critic_loss = tf.reduce_mean(critic_losses)
  sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
  return critic_loss, sampled_q_t
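
The reduce_logsumexp-over-log_softmax step above (reused in several later examples) averages the N sampled categorical value distributions in probability space. A minimal sketch in plain TensorFlow (toy shapes, no Acme dependencies) showing that the resulting logits represent exactly the per-sample mean distribution:

import tensorflow as tf

num_samples, batch_size, num_atoms = 4, 2, 51
sampled_logits = tf.random.normal([num_samples, batch_size, num_atoms])  # [N, B, A]

# Normalize across atoms, then log-sum-exp across the sample axis. The result
# differs from log(mean probability) only by an additive constant log(N),
# which the softmax below ignores.
sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)  # [B, A]

mean_probs = tf.reduce_mean(tf.nn.softmax(sampled_logits, axis=-1), axis=0)
tf.debugging.assert_near(tf.nn.softmax(averaged_logits, axis=-1), mean_probs)
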
Example #2
  def dev_critic_loss(self, dev_dataset=None):
    critic_loss_sum = 0.
    count = 0.
    for sample in dev_dataset:
      o_tm1, a_tm1, r_t, d_t, o_t = sample.data[:5]

      # Cast the additional discount to match the environment discount dtype.
      discount = tf.cast(self._discount, dtype=d_t.dtype)

      q_t = self._critic_network(o_t, self._policy_network(o_t))
      q_tm1 = self._critic_network(o_tm1, a_tm1)

      # Critic loss.
      if self._distributional:
        critic_loss = losses.categorical(q_tm1, r_t, discount * d_t, q_t)
      else:
        # Squeeze into the shape expected by the td_learning implementation.
        q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]
        q_t = tf.squeeze(q_t, axis=-1)  # [B]
        critic_loss = trfl.td_learning(q_tm1, r_t, discount * d_t, q_t).loss

      critic_loss_sum += tf.reduce_mean(critic_loss, axis=[0])
      count += 1.
    return critic_loss_sum / count
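
For reference, the non-distributional branch above delegates to trfl.td_learning. A minimal sketch with made-up tensors of the quantity that branch optimizes (to my understanding, the returned loss is one half of the squared TD error against the bootstrapped target r_t + discount * d_t * q_t):

import tensorflow as tf

q_tm1 = tf.constant([1.0, 2.0])      # online Q(o_tm1, a_tm1), shape [B]
r_t = tf.constant([0.5, 0.0])        # rewards, shape [B]
pcont_t = tf.constant([0.99, 0.0])   # discount * d_t (zero at terminal steps), shape [B]
q_t = tf.constant([1.5, 3.0])        # bootstrap value Q(o_t, pi(o_t)), shape [B]

target = tf.stop_gradient(r_t + pcont_t * q_t)
critic_loss = tf.reduce_mean(0.5 * tf.square(target - q_tm1))
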
Example #3
    def _step(self) -> types.NestedTensor:
        # Update target network.
        online_policy_variables = self._policy_network.variables
        target_policy_variables = self._target_policy_network.variables
        online_critic_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
        )
        target_critic_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
        )

        # Make online policy -> target policy network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for src, dest in zip(online_policy_variables,
                                 target_policy_variables):
                dest.assign(src)
        # Make online critic -> target critic network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for src, dest in zip(online_critic_variables,
                                 target_critic_variables):
                dest.assign(src)

        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

        # Get batch size and scalar dtype.
        batch_size = r_t.shape[0]

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=d_t.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            # This stop_gradient prevents gradients from propagating into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(self._target_observation_network(o_t))

            # Get online and target action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Sample actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)

            # Tile embedded observations to feed into the target critic network.
            # Note: this is more efficient than tiling before the embedding layer.
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute target-estimated distributional value of sampled actions at o_t.
            sampled_q_t_distributions = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Compute average logits by first reshaping them and normalizing them
            # across atoms.
            new_shape = [self._num_samples, batch_size, -1]  # [N, B, A]
            sampled_logits = tf.reshape(sampled_q_t_distributions.logits,
                                        new_shape)
            sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
            averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

            # Construct the expected distributional value for bootstrapping.
            q_t_distribution = networks.DiscreteValuedDistribution(
                values=sampled_q_t_distributions.values,
                logits=averaged_logits)

            # Compute online critic value distribution of a_tm1 in state o_tm1.
            q_tm1_distribution = self._critic_network(o_tm1, a_tm1)

            # Compute critic distributional loss.
            critic_loss = losses.categorical(q_tm1_distribution, r_t,
                                             discount * d_t, q_t_distribution)
            critic_loss = tf.reduce_mean(critic_loss)

            # Compute Q-values of sampled actions and reshape to [N, B].
            sampled_q_values = sampled_q_t_distributions.mean()
            sampled_q_values = tf.reshape(sampled_q_values,
                                          (self._num_samples, -1))

            # Compute MPO policy loss.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_values)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)  # Log MPO stats.

        return fetches
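
The tile-then-merge pattern above lets a single critic call score all N sampled actions for every batch element. A sketch with plain TF ops (tf2_utils.tile_tensor and snt.merge_leading_dims are assumed to behave like the stand-ins below; shapes are toy values):

import tensorflow as tf

num_samples, batch_size, obs_dim, act_dim = 3, 2, 8, 4
o_t = tf.random.normal([batch_size, obs_dim])                            # [B, ...]
sampled_actions = tf.random.normal([num_samples, batch_size, act_dim])   # [N, B, ...]

# Tile observations to [N, B, ...], then merge the leading two dims so the
# critic sees a single [N*B, ...] batch; per-sample values are recovered
# later by reshaping back to [N, B].
tiled_o_t = tf.tile(o_t[None], [num_samples, 1, 1])
merged_o_t = tf.reshape(tiled_o_t, [num_samples * batch_size, obs_dim])
merged_actions = tf.reshape(sampled_actions, [num_samples * batch_size, act_dim])
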
Example #4
    def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]:
        # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...].
        sample = tf2_utils.batch_to_sequence(sample)
        observations = sample.observation
        actions = sample.action
        rewards = sample.reward
        discounts = sample.discount

        dtype = rewards.dtype

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=discounts.dtype)

        # Loss cumulants across time. These cannot be python mutable objects.
        critic_loss = 0.
        policy_loss = 0.

        # Each transition induces a policy loss, which we then weight using
        # the `policy_loss_coef_t`; shape [B], see https://arxiv.org/abs/2006.15134.
        # `policy_loss_coef` is a scalar average of these coefficients across
        # the batch and sequence length dimensions.
        policy_loss_coef = 0.

        per_device_batch_size = actions.shape[1]

        # Initialize recurrent states.
        critic_state = self._critic_network.initial_state(
            per_device_batch_size)
        target_critic_state = critic_state
        policy_state = self._policy_network.initial_state(
            per_device_batch_size)
        target_policy_state = policy_state

        with tf.GradientTape(persistent=True) as tape:
            for t in range(1, self._sequence_length):
                o_tm1 = tree.map_structure(operator.itemgetter(t - 1),
                                           observations)
                a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions)
                r_t = tree.map_structure(operator.itemgetter(t - 1), rewards)
                d_t = tree.map_structure(operator.itemgetter(t - 1), discounts)
                o_t = tree.map_structure(operator.itemgetter(t), observations)

                if t != 1:
                    # By only updating the target critic state here, we force the
                    # target critic to ignore observations[0]. Otherwise, the
                    # target_critic would be unrolled for one more timestep than the
                    # critic. The smaller the sequence length, the more problematic
                    # this is: with an RNN on sequences of length 2, you would expect
                    # the code to never use recurrent connections. But if you don't
                    # skip updating the target_critic_state on observations[0] here,
                    # that won't be the case.
                    _, target_critic_state = self._target_critic_network(
                        o_tm1, a_tm1, target_critic_state)

                # ========================= Critic learning ============================
                q_tm1, next_critic_state = self._critic_network(
                    o_tm1, a_tm1, critic_state)
                target_action_distribution, target_policy_state = self._target_policy_network(
                    o_t, target_policy_state)

                sampled_actions_t = target_action_distribution.sample(
                    self._num_action_samples_td_learning)
                # [N, B, ...]
                tiled_o_t = tf2_utils.tile_nested(
                    o_t, self._num_action_samples_td_learning)
                tiled_target_critic_state = tf2_utils.tile_nested(
                    target_critic_state, self._num_action_samples_td_learning)

                # Compute the target critic's Q-value of the sampled actions.
                sampled_q_t, _ = snt.BatchApply(self._target_critic_network)(
                    tiled_o_t, sampled_actions_t, tiled_target_critic_state)

                # Compute average logits by first reshaping them to [N, B, A] and then
                # normalizing them across atoms.
                new_shape = [
                    self._num_action_samples_td_learning, r_t.shape[0], -1
                ]
                sampled_logits = tf.reshape(sampled_q_t.logits, new_shape)
                sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
                averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

                # Construct the expected distributional value for bootstrapping.
                q_t = networks.DiscreteValuedDistribution(
                    values=sampled_q_t.values, logits=averaged_logits)
                critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t,
                                                   q_t)
                critic_loss_t = tf.reduce_mean(critic_loss_t)

                # ========================= Actor learning =============================
                action_distribution_tm1, policy_state = self._policy_network(
                    o_tm1, policy_state)
                q_tm1_mean = q_tm1.mean()

                # Compute the estimate of the value function based on
                # self._num_action_samples_policy_weight samples from the policy.
                tiled_o_tm1 = tf2_utils.tile_nested(
                    o_tm1, self._num_action_samples_policy_weight)
                tiled_critic_state = tf2_utils.tile_nested(
                    critic_state, self._num_action_samples_policy_weight)
                action_tm1 = action_distribution_tm1.sample(
                    self._num_action_samples_policy_weight)
                tiled_z_tm1, _ = snt.BatchApply(self._critic_network)(
                    tiled_o_tm1, action_tm1, tiled_critic_state)
                tiled_v_tm1 = tf.reshape(
                    tiled_z_tm1.mean(),
                    [self._num_action_samples_policy_weight, -1])

                # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the
                # final estimate of the value function.
                if self._baseline_reduce_function == 'mean':
                    v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0)
                elif self._baseline_reduce_function == 'max':
                    v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0)
                elif self._baseline_reduce_function == 'min':
                    v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0)

                # Assert that action_distribution_tm1 is a batch of multivariate
                # distributions (in contrast to e.g. a [batch, action_size] collection
                # of 1d distributions).
                assert len(action_distribution_tm1.batch_shape) == 1
                policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1)

                advantage = q_tm1_mean - v_tm1_estimate
                if self._policy_improvement_modes == 'exp':
                    policy_loss_coef_t = tf.math.minimum(
                        tf.math.exp(advantage / self._beta),
                        self._ratio_upper_bound)
                elif self._policy_improvement_modes == 'binary':
                    policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype)
                elif self._policy_improvement_modes == 'all':
                    # Regress against all actions (effectively pure BC).
                    policy_loss_coef_t = 1.
                policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t)

                policy_loss_batch *= policy_loss_coef_t
                policy_loss_t = tf.reduce_mean(policy_loss_batch)

                critic_state = next_critic_state

                critic_loss += critic_loss_t
                policy_loss += policy_loss_t
                policy_loss_coef += tf.reduce_mean(
                    policy_loss_coef_t)  # For logging.

            # Divide by sequence length to get mean losses.
            critic_loss /= tf.cast(self._sequence_length, dtype=dtype)
            policy_loss /= tf.cast(self._sequence_length, dtype=dtype)
            policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype)

        # Compute gradients.
        critic_gradients = tape.gradient(
            critic_loss, self._critic_network.trainable_variables)
        policy_gradients = tape.gradient(
            policy_loss, self._policy_network.trainable_variables)

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Sync gradients across GPUs or TPUs.
        ctx = tf.distribute.get_replica_context()
        critic_gradients = ctx.all_reduce('mean', critic_gradients)
        policy_gradients = ctx.all_reduce('mean', policy_gradients)

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
            critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     self._critic_network.trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     self._policy_network.trainable_variables)

        source_variables = (self._critic_network.variables +
                            self._policy_network.variables)
        target_variables = (self._target_critic_network.variables +
                            self._target_policy_network.variables)

        # Make online -> target network update ops.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(source_variables, target_variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        return {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
            'policy_loss_coef': policy_loss_coef,
        }
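
The policy_loss_coef_t above implements the advantage weighting described in the paper cited in the comments (https://arxiv.org/abs/2006.15134). A small sketch of the three modes with made-up advantages; beta and the ratio upper bound are learner hyperparameters:

import tensorflow as tf

advantage = tf.constant([-0.5, 0.2, 3.0])   # Q(s, a) - V(s) estimates, shape [B]
beta, ratio_upper_bound = 1.0, 20.0

coef_exp = tf.minimum(tf.exp(advantage / beta), ratio_upper_bound)  # 'exp' mode
coef_binary = tf.cast(advantage > 0, tf.float32)                    # 'binary' mode
coef_all = tf.ones_like(advantage)                                  # 'all' mode (pure behavior cloning)
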
Example #5
  def _step(self) -> Dict[str, tf.Tensor]:
    # Update target network
    online_variables = (
        *self._observation_network.variables,
        *self._critic_network.variables,
        *self._policy_network.variables,
    )
    target_variables = (
        *self._target_observation_network.variables,
        *self._target_critic_network.variables,
        *self._target_policy_network.variables,
    )

    # Make online -> target network update ops.
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
      for src, dest in zip(online_variables, target_variables):
        dest.assign(src)
    self._num_steps.assign_add(1)

    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    sample = next(self._iterator)
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data  # Assuming ReverbSample.

    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(self._discount, dtype=d_t.dtype)

    with tf.GradientTape(persistent=True) as tape:
      # Maybe transform the observation before feeding into policy and critic.
      # Transforming the observations this way at the start of the learning
      # step effectively means that the policy and critic share observation
      # network weights.
      o_tm1 = self._observation_network(o_tm1)
      o_t = self._target_observation_network(o_t)
      # This stop_gradient prevents gradients from propagating into the target
      # observation network. In addition, since the online policy network is
      # evaluated at o_t, this also means the policy loss does not influence
      # the observation network training.
      o_t = tree.map_structure(tf.stop_gradient, o_t)

      # Critic learning.
      q_tm1 = self._critic_network(o_tm1, a_tm1)
      q_t = self._target_critic_network(o_t, self._target_policy_network(o_t))

      # Critic loss.
      critic_loss = losses.categorical(q_tm1, r_t, discount * d_t, q_t)
      critic_loss = tf.reduce_mean(critic_loss, axis=[0])

      # Actor learning.
      dpg_a_t = self._policy_network(o_t)
      dpg_z_t = self._critic_network(o_t, dpg_a_t)
      dpg_q_t = dpg_z_t.mean()

      # Actor loss. If clipping is true use dqda clipping and clip the norm.
      dqda_clipping = 1.0 if self._clipping else None
      policy_loss = losses.dpg(
          dpg_q_t,
          dpg_a_t,
          tape=tape,
          dqda_clipping=dqda_clipping,
          clip_norm=self._clipping)
      policy_loss = tf.reduce_mean(policy_loss, axis=[0])

    # Get trainable variables.
    policy_variables = self._policy_network.trainable_variables
    critic_variables = (
        # In this agent, the critic loss trains the observation network.
        self._observation_network.trainable_variables +
        self._critic_network.trainable_variables)

    # Compute gradients.
    policy_gradients = tape.gradient(policy_loss, policy_variables)
    critic_gradients = tape.gradient(critic_loss, critic_variables)

    # Delete the tape manually because of the persistent=True flag.
    del tape

    # Maybe clip gradients.
    if self._clipping:
      policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
      critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

    # Apply gradients.
    self._policy_optimizer.apply(policy_gradients, policy_variables)
    self._critic_optimizer.apply(critic_gradients, critic_variables)

    # Losses to track.
    return {
        'critic_loss': critic_loss,
        'policy_loss': policy_loss,
    }
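
For intuition: in the unclipped case, the actor update above yields the same policy gradients as minimizing -Q(o_t, pi(o_t)) with gradients flowing through the critic into the policy only. A self-contained sketch with toy Sonnet modules (module sizes and names are made up; this is not the implementation of losses.dpg):

import sonnet as snt
import tensorflow as tf

policy = snt.nets.MLP([16, 2])   # toy deterministic policy: obs -> action
critic = snt.nets.MLP([16, 1])   # toy critic: concat(obs, action) -> value
o_t = tf.random.normal([8, 4])   # toy batch of observations

with tf.GradientTape() as tape:
  a_t = policy(o_t)
  q_t = critic(tf.concat([o_t, a_t], axis=-1))
  policy_loss = -tf.reduce_mean(q_t)

# Only the policy is updated from this loss; the critic is trained separately
# from its own TD or categorical loss.
policy_gradients = tape.gradient(policy_loss, policy.trainable_variables)
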
Example #6
    def _forward(self, inputs: Any) -> None:
        """Trainer forward pass

        Args:
            inputs (Any): input data from the data table (transitions)
        """

        # Unpack input data as follows:
        # o_tm1 = dictionary of observations, one for each agent
        # a_tm1 = dictionary of actions taken from obs in o_tm1
        # e_tm1 [Optional] = extra data for timestep t-1
        #   that the agents persist in replay.
        # r_t = dictionary of rewards or reward sequences
        #   (if using N-step transitions) ensuing from actions a_tm1
        # d_t = environment discount ensuing from actions a_tm1.
        #   This discount is applied to future rewards after r_t.
        # o_t = dictionary of next observations or next observation sequences
        # e_t [Optional] = extra data for timestep t that the agents persist in replay.
        o_tm1, a_tm1, e_tm1, r_t, d_t, o_t, e_t = inputs.data

        # Do forward passes through the networks and calculate the losses
        self.policy_losses = {}
        self.critic_losses = {}
        with tf.GradientTape(persistent=True) as tape:
            o_tm1_trans, o_t_trans = self._transform_observations(o_tm1, o_t)
            a_t = self._target_policy_actions(o_t_trans)

            for agent in self._agents:
                agent_key = self.agent_net_keys[agent]

                # Get critic feed
                o_tm1_feed, o_t_feed, a_tm1_feed, a_t_feed = self._get_critic_feed(
                    o_tm1_trans=o_tm1_trans,
                    o_t_trans=o_t_trans,
                    a_tm1=a_tm1,
                    a_t=a_t,
                    e_tm1=e_tm1,
                    e_t=e_t,
                    agent=agent,
                )

                # Critic learning.
                q_tm1 = self._critic_networks[agent_key](o_tm1_feed,
                                                         a_tm1_feed)
                q_t = self._target_critic_networks[agent_key](o_t_feed,
                                                              a_t_feed)

                # Cast the additional discount to match the environment discount dtype.
                discount = tf.cast(self._discount, dtype=d_t[agent].dtype)

                # Critic loss.
                critic_loss = losses.categorical(q_tm1, r_t[agent],
                                                 discount * d_t[agent], q_t)
                self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0)
                # Actor learning.
                o_t_agent_feed = o_t_trans[agent]
                dpg_a_t = self._policy_networks[agent_key](o_t_agent_feed)

                # Get dpg actions
                dpg_a_t_feed = self._get_dpg_feed(a_t, dpg_a_t, agent)

                # Get dpg Q values.
                dpg_z_t = self._critic_networks[agent_key](o_t_feed,
                                                           dpg_a_t_feed)
                dpg_q_t = dpg_z_t.mean()

                # Actor loss. If clipping is true use dqda clipping and clip the norm.
                dqda_clipping = 1.0 if self._max_gradient_norm is not None else None
                clip_norm = self._max_gradient_norm is not None

                policy_loss = losses.dpg(
                    dpg_q_t,
                    dpg_a_t,
                    tape=tape,
                    dqda_clipping=dqda_clipping,
                    clip_norm=clip_norm,
                )
                self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0)
        self.tape = tape
Example #7
    def _step(self) -> types.NestedTensor:
        # Update target network.
        online_policy_variables = self._policy_network.variables
        target_policy_variables = self._target_policy_network.variables
        online_critic_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
        )
        target_critic_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
        )

        # Make online policy -> target policy network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for src, dest in zip(online_policy_variables,
                                 target_policy_variables):
                dest.assign(src)
        # Make online critic -> target critic network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for src, dest in zip(online_critic_variables,
                                 target_critic_variables):
                dest.assign(src)

        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

        # Get batch size and scalar dtype.
        batch_size = r_t.shape[0]

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=d_t.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            # This stop_gradient prevents gradients from propagating into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(self._target_observation_network(o_t))

            # Get online and target action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Sample actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)

            # Tile embedded observations to feed into the target critic network.
            # Note: this is more efficient than tiling before the embedding layer.
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute target-estimated distributional value of sampled actions at o_t.
            sampled_q_t_all = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Compute online critic value distribution of a_tm1 in state o_tm1.
            q_tm1_all = self._critic_network(o_tm1, a_tm1)

            # Compute rewards for objectives with defined reward_fn
            reward_stats = {}
            r_t_all = []
            for objective in self._objectives:
                if hasattr(objective, 'reward_fn'):
                    r = objective.reward_fn(o_tm1, a_tm1, r_t)
                    reward_stats['{}_reward'.format(
                        objective.name)] = tf.reduce_mean(r)
                    r_t_all.append(r)
            r_t_all = tf.stack(r_t_all, axis=-1)
            r_t_all.get_shape().assert_has_rank(2)  # [B, C]

            if isinstance(sampled_q_t_all, list):  # Distributional critics
                # Compute average logits by first reshaping them and normalizing them
                # across atoms.
                critic_losses = []
                sampled_q_ts = []
                for idx, (sampled_q_t_distributions,
                          q_tm1_distribution) in enumerate(
                              zip(sampled_q_t_all, q_tm1_all)):
                    # Compute loss for distributional critic for objective c
                    sampled_logits = tf.reshape(
                        sampled_q_t_distributions.logits,
                        [self._num_samples, batch_size, -1])  # [N, B, A]
                    sampled_logprobs = tf.math.log_softmax(sampled_logits,
                                                           axis=-1)
                    averaged_logits = tf.reduce_logsumexp(sampled_logprobs,
                                                          axis=0)

                    # Construct the expected distributional value for bootstrapping.
                    q_t_distribution = networks.DiscreteValuedDistribution(
                        values=sampled_q_t_distributions.values,
                        logits=averaged_logits)

                    # Compute critic distributional loss.
                    critic_loss = losses.categorical(q_tm1_distribution,
                                                     r_t_all[:, idx],
                                                     discount * d_t,
                                                     q_t_distribution)
                    critic_losses.append(tf.reduce_mean(critic_loss))

                    # Compute Q-values of sampled actions and reshape to [N, B].
                    sampled_q_ts.append(
                        tf.reshape(sampled_q_t_distributions.mean(),
                                   (self._num_samples, -1)))

                critic_loss = tf.reduce_mean(critic_losses)
                sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
            else:
                # Reshape Q-value samples back to original batch dimensions and average
                # them to compute the TD-learning bootstrap target.
                sampled_q_t = tf.reshape(sampled_q_t_all,
                                         (self._num_samples, batch_size,
                                          self._num_critic_heads))  # [N,B,C]
                q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

                # Flatten q_t and q_tm1; necessary for trfl.td_learning
                q_t = tf.reshape(q_t, [-1])  # [B*C]
                q_tm1 = tf.reshape(q_tm1_all, [-1])  # [B*C]

                # Flatten r_t_all; necessary for trfl.td_learning
                r_t_all = tf.reshape(r_t_all, [-1])  # [B*C]

                # Repeat d_t per critic head and flatten, so that its ordering
                # matches the row-major flattening of q_t and q_tm1 above.
                d_t = tf.repeat(d_t, self._num_critic_heads)  # [B*C]

                # Critic loss.
                critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t,
                                               q_t).loss
                critic_loss = tf.reduce_mean(critic_loss)

            # Add sampled Q-values for objectives with defined objective_fn
            sampled_q_idx = 0
            sampled_q_t_k = []
            for objective in self._objectives:
                if hasattr(objective, 'reward_fn'):
                    sampled_q_t_k.append(
                        tf.stop_gradient(sampled_q_t[..., sampled_q_idx]))
                    sampled_q_idx += 1
                if hasattr(objective, 'objective_fn'):
                    sampled_q_t_k.append(
                        tf.stop_gradient(
                            objective.objective_fn(sampled_actions,
                                                   sampled_q_t)))
            sampled_q_t_k = tf.stack(sampled_q_t_k, axis=-1)  # [N, B, K]

            # Compute MPO policy loss.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_t_k)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)  # Log MPO stats.
        fetches.update(reward_stats)  # Log reward stats.

        return fetches
Example #8
  def _step(self, sample) -> Dict[str, tf.Tensor]:
    transitions: types.Transition = sample.data  # Assuming ReverbSample.

    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(self._discount, dtype=transitions.discount.dtype)

    with tf.GradientTape(persistent=True) as tape:
      # Maybe transform the observation before feeding into policy and critic.
      # Transforming the observations this way at the start of the learning
      # step effectively means that the policy and critic share observation
      # network weights.
      o_tm1 = self._observation_network(transitions.observation)
      o_t = self._target_observation_network(transitions.next_observation)
      # This stop_gradient prevents gradients from propagating into the target
      # observation network. In addition, since the online policy network is
      # evaluated at o_t, this also means the policy loss does not influence
      # the observation network training.
      o_t = tree.map_structure(tf.stop_gradient, o_t)

      # Critic learning.
      q_tm1 = self._critic_network(o_tm1, transitions.action)
      q_t = self._target_critic_network(o_t, self._target_policy_network(o_t))

      # Critic loss.
      critic_loss = losses.categorical(q_tm1, transitions.reward,
                                       discount * transitions.discount, q_t)
      critic_loss = tf.reduce_mean(critic_loss, axis=[0])

      # Actor learning.
      dpg_a_t = self._policy_network(o_t)
      dpg_z_t = self._critic_network(o_t, dpg_a_t)
      dpg_q_t = dpg_z_t.mean()

      # Actor loss. If clipping is true use dqda clipping and clip the norm.
      dqda_clipping = 1.0 if self._clipping else None
      policy_loss = losses.dpg(
          dpg_q_t,
          dpg_a_t,
          tape=tape,
          dqda_clipping=dqda_clipping,
          clip_norm=self._clipping)
      policy_loss = tf.reduce_mean(policy_loss, axis=[0])

    # Get trainable variables.
    policy_variables = self._policy_network.trainable_variables
    critic_variables = (
        # In this agent, the critic loss trains the observation network.
        self._observation_network.trainable_variables +
        self._critic_network.trainable_variables)

    # Compute gradients.
    replica_context = tf.distribute.get_replica_context()
    policy_gradients = _average_gradients_across_replicas(
        replica_context,
        tape.gradient(policy_loss, policy_variables))
    critic_gradients = _average_gradients_across_replicas(
        replica_context,
        tape.gradient(critic_loss, critic_variables))

    # Delete the tape manually because of the persistent=True flag.
    del tape

    # Maybe clip gradients.
    if self._clipping:
      policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
      critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

    # Apply gradients.
    self._policy_optimizer.apply(policy_gradients, policy_variables)
    self._critic_optimizer.apply(critic_gradients, critic_variables)

    # Losses to track.
    return {
        'critic_loss': critic_loss,
        'policy_loss': policy_loss,
    }
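
_average_gradients_across_replicas is not shown in this snippet. A plausible (hypothetical) implementation, mirroring the explicit ctx.all_reduce('mean', ...) calls in Example #4, would mean-reduce each gradient across replicas before applying it:

import tensorflow as tf

def _average_gradients_across_replicas(replica_context, gradients):
  """Sketch: mean-reduce a structure of gradients across replicas."""
  if replica_context is None:
    # No per-replica context available (e.g. not called via strategy.run).
    return gradients
  return replica_context.all_reduce(tf.distribute.ReduceOp.MEAN, gradients)
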
Example #9
  def _step(self) -> Dict[str, tf.Tensor]:
    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    sample = next(self._iterator)
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data[:5]

    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(self._discount, dtype=d_t.dtype)

    q_t = self._target_critic_network(o_t,
                                      self._policy_network(o_t))
    if not self._distributional and self._vmin is not None:
      q_t = tf.clip_by_value(q_t, self._vmin, self._vmax)
      logging.info('Clip target critic network output with [%f, %f]',
                   self._vmin, self._vmax)

    with tf.GradientTape() as tape:
      # Critic learning.
      q_tm1 = self._critic_network(o_tm1, a_tm1)

      # Critic loss.
      if self._distributional:
        critic_loss = losses.categorical(q_tm1, r_t, discount * d_t, q_t)
      else:
        # Squeeze into the shape expected by the td_learning implementation.
        q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]
        q_t = tf.squeeze(q_t, axis=-1)  # [B]
        critic_loss = trfl.td_learning(q_tm1, r_t, discount * d_t, q_t).loss

      critic_loss = tf.reduce_mean(critic_loss, axis=[0])

    # Get trainable variables.
    critic_variables = self._critic_network.trainable_variables

    # Compute gradients.
    critic_gradients = tape.gradient(critic_loss, critic_variables)

    # Maybe clip gradients.
    if self._clipping:
      critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

    # Apply gradients.
    self._critic_optimizer.apply(critic_gradients, critic_variables)

    source_variables = self._critic_network.variables
    target_variables = self._target_critic_network.variables

    # Make online -> target network update ops.
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
      for src, dest in zip(source_variables, target_variables):
        dest.assign(src)

    if self._init_observations is not None:
      if tf.math.mod(self._num_steps, 100) == 0:
        # init_obs = tf.convert_to_tensor(self._init_observations, tf.float32)
        init_obs = tree.map_structure(tf.convert_to_tensor,
                                      self._init_observations)
        init_actions = self._policy_network(init_obs)
        init_critic = tf.reduce_mean(self._critic_mean(init_obs, init_actions))
      else:
        init_critic = tf.constant(0.)
    else:
      init_critic = tf.constant(0.)

    self._num_steps.assign_add(1)

    # Losses to track.
    return {
        'critic_loss': critic_loss,
        'q_s0': init_critic,
    }