Example #1
def _compute_distributional_critic_loss(
    sampled_q_t_all: List[tf.Tensor],
    q_tm1_all: List[tf.Tensor],
    r_t_all: tf.Tensor,
    d_t: tf.Tensor,
    discount: float,
    num_samples: int):
  """Compute loss and sampled Q-values for distributional critics."""
  # Compute average logits by first reshaping them and normalizing them
  # across atoms.
  batch_size = r_t_all.get_shape()[0]
  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(discount, dtype=d_t.dtype)
  critic_losses = []
  sampled_q_ts = []
  for idx, (sampled_q_t_distributions, q_tm1_distribution) in enumerate(
      zip(sampled_q_t_all, q_tm1_all)):
    # Compute loss for distributional critic for objective c
    sampled_logits = tf.reshape(
        sampled_q_t_distributions.logits,
        [num_samples, batch_size, -1])  # [N, B, A]
    sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
    averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

    # Construct the expected distributional value for bootstrapping.
    q_t_distribution = networks.DiscreteValuedDistribution(
        values=sampled_q_t_distributions.values, logits=averaged_logits)

    # Compute critic distributional loss.
    critic_loss = losses.categorical(
        q_tm1_distribution, r_t_all[:, idx], discount * d_t,
        q_t_distribution)
    critic_losses.append(tf.reduce_mean(critic_loss))

    # Compute Q-values of sampled actions and reshape to [N, B].
    sampled_q_ts.append(tf.reshape(
        sampled_q_t_distributions.mean(), (num_samples, -1)))

  critic_loss = tf.reduce_mean(critic_losses)
  sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
  return critic_loss, sampled_q_t
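
The log_softmax / reduce_logsumexp pair above is how the N sampled categorical
distributions are averaged in probability space: summing the per-sample
probabilities in log space gives the logits of the mean distribution, up to an
additive constant (log N) that a later softmax absorbs. A minimal
self-contained check of that identity in plain TensorFlow, with illustrative
shapes only:

import tensorflow as tf

num_samples, batch_size, num_atoms = 4, 2, 5  # toy [N, B, A] sizes
logits = tf.random.normal([num_samples, batch_size, num_atoms])

# Average the N probability distributions directly.
mean_probs = tf.reduce_mean(tf.nn.softmax(logits, axis=-1), axis=0)  # [B, A]

# Same average via the logsumexp-of-log-softmax trick used in the learner.
logprobs = tf.math.log_softmax(logits, axis=-1)
averaged_logits = tf.reduce_logsumexp(logprobs, axis=0)  # [B, A]
recovered_probs = tf.nn.softmax(averaged_logits, axis=-1)

# Both paths give the same distribution; softmax absorbs the log N offset.
tf.debugging.assert_near(mean_probs, recovered_probs, atol=1e-5)
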
Example #2
    def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]:
        # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...].
        sample = tf2_utils.batch_to_sequence(sample)
        observations = sample.observation
        actions = sample.action
        rewards = sample.reward
        discounts = sample.discount

        dtype = rewards.dtype

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=discounts.dtype)

        # Loss cumulants across time. These cannot be python mutable objects.
        critic_loss = 0.
        policy_loss = 0.

        # Each transition induces a policy loss, which we then weight using
        # `policy_loss_coef_t` (shape [B]); see https://arxiv.org/abs/2006.15134.
        # `policy_loss_coef` is a scalar average of these coefficients across
        # the batch and sequence length dimensions.
        policy_loss_coef = 0.

        per_device_batch_size = actions.shape[1]

        # Initialize recurrent states.
        critic_state = self._critic_network.initial_state(
            per_device_batch_size)
        target_critic_state = critic_state
        policy_state = self._policy_network.initial_state(
            per_device_batch_size)
        target_policy_state = policy_state

        with tf.GradientTape(persistent=True) as tape:
            for t in range(1, self._sequence_length):
                o_tm1 = tree.map_structure(operator.itemgetter(t - 1),
                                           observations)
                a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions)
                r_t = tree.map_structure(operator.itemgetter(t - 1), rewards)
                d_t = tree.map_structure(operator.itemgetter(t - 1), discounts)
                o_t = tree.map_structure(operator.itemgetter(t), observations)

                if t != 1:
                    # By only updating the target critic state here we force the
                    # target critic to ignore observations[0]. Otherwise, the target
                    # critic would be unrolled for one more timestep than the online
                    # critic. The shorter the sequence, the more problematic this is:
                    # with an RNN on sequences of length 2 you would expect the code to
                    # never use recurrent connections, but if you don't skip updating
                    # target_critic_state on observations[0] here, that is not the case.
                    _, target_critic_state = self._target_critic_network(
                        o_tm1, a_tm1, target_critic_state)

                # ========================= Critic learning ============================
                q_tm1, next_critic_state = self._critic_network(
                    o_tm1, a_tm1, critic_state)
                target_action_distribution, target_policy_state = self._target_policy_network(
                    o_t, target_policy_state)

                sampled_actions_t = target_action_distribution.sample(
                    self._num_action_samples_td_learning)
                # [N, B, ...]
                tiled_o_t = tf2_utils.tile_nested(
                    o_t, self._num_action_samples_td_learning)
                tiled_target_critic_state = tf2_utils.tile_nested(
                    target_critic_state, self._num_action_samples_td_learning)

                # Compute the target critic's Q-value of the sampled actions.
                sampled_q_t, _ = snt.BatchApply(self._target_critic_network)(
                    tiled_o_t, sampled_actions_t, tiled_target_critic_state)

                # Compute average logits by first reshaping them to [N, B, A] and then
                # normalizing them across atoms.
                new_shape = [
                    self._num_action_samples_td_learning, r_t.shape[0], -1
                ]
                sampled_logits = tf.reshape(sampled_q_t.logits, new_shape)
                sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
                averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

                # Construct the expected distributional value for bootstrapping.
                q_t = networks.DiscreteValuedDistribution(
                    values=sampled_q_t.values, logits=averaged_logits)
                critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t,
                                                   q_t)
                critic_loss_t = tf.reduce_mean(critic_loss_t)

                # ========================= Actor learning =============================
                action_distribution_tm1, policy_state = self._policy_network(
                    o_tm1, policy_state)
                q_tm1_mean = q_tm1.mean()

                # Compute the estimate of the value function based on
                # self._num_action_samples_policy_weight samples from the policy.
                tiled_o_tm1 = tf2_utils.tile_nested(
                    o_tm1, self._num_action_samples_policy_weight)
                tiled_critic_state = tf2_utils.tile_nested(
                    critic_state, self._num_action_samples_policy_weight)
                action_tm1 = action_distribution_tm1.sample(
                    self._num_action_samples_policy_weight)
                tiled_z_tm1, _ = snt.BatchApply(self._critic_network)(
                    tiled_o_tm1, action_tm1, tiled_critic_state)
                tiled_v_tm1 = tf.reshape(
                    tiled_z_tm1.mean(),
                    [self._num_action_samples_policy_weight, -1])

                # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the
                # final estimate of the value function.
                if self._baseline_reduce_function == 'mean':
                    v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0)
                elif self._baseline_reduce_function == 'max':
                    v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0)
                elif self._baseline_reduce_function == 'min':
                    v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0)

                # Assert that action_distribution_tm1 is a batch of multivariate
                # distributions (in contrast to e.g. a [batch, action_size] collection
                # of 1d distributions).
                assert len(action_distribution_tm1.batch_shape) == 1
                policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1)

                advantage = q_tm1_mean - v_tm1_estimate
                if self._policy_improvement_modes == 'exp':
                    policy_loss_coef_t = tf.math.minimum(
                        tf.math.exp(advantage / self._beta),
                        self._ratio_upper_bound)
                elif self._policy_improvement_modes == 'binary':
                    policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype)
                elif self._policy_improvement_modes == 'all':
                    # Regress against all actions (effectively pure BC).
                    policy_loss_coef_t = 1.
                policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t)

                policy_loss_batch *= policy_loss_coef_t
                policy_loss_t = tf.reduce_mean(policy_loss_batch)

                critic_state = next_critic_state

                critic_loss += critic_loss_t
                policy_loss += policy_loss_t
                policy_loss_coef += tf.reduce_mean(
                    policy_loss_coef_t)  # For logging.

            # Divide by sequence length to get mean losses.
            critic_loss /= tf.cast(self._sequence_length, dtype=dtype)
            policy_loss /= tf.cast(self._sequence_length, dtype=dtype)
            policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype)

        # Compute gradients.
        critic_gradients = tape.gradient(
            critic_loss, self._critic_network.trainable_variables)
        policy_gradients = tape.gradient(
            policy_loss, self._policy_network.trainable_variables)

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Sync gradients across GPUs or TPUs.
        ctx = tf.distribute.get_replica_context()
        critic_gradients = ctx.all_reduce('mean', critic_gradients)
        policy_gradients = ctx.all_reduce('mean', policy_gradients)

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
            critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     self._critic_network.trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     self._policy_network.trainable_variables)

        source_variables = (self._critic_network.variables +
                            self._policy_network.variables)
        target_variables = (self._target_critic_network.variables +
                            self._target_policy_network.variables)

        # Make online -> target network update ops.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(source_variables, target_variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        return {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
            'policy_loss_coef': policy_loss_coef,
        }
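
The advantage-weighting block above is the core of the CRR update: each
behaviour-cloning term is scaled by a function of the advantage, with the
gradient stopped through the weight. Below is a standalone sketch of that
coefficient, assuming the advantage has already been computed; `mode`, `beta`,
and `ratio_upper_bound` are ordinary arguments standing in for the learner
attributes of the same names, and their default values are illustrative only.

import tensorflow as tf

def crr_policy_loss_coef(advantage: tf.Tensor,
                         mode: str = 'exp',
                         beta: float = 1.0,
                         ratio_upper_bound: float = 20.0) -> tf.Tensor:
  """Returns per-example weights for the behaviour-cloning policy loss."""
  if mode == 'exp':
    # Exponentiated, clipped advantage.
    coef = tf.minimum(tf.exp(advantage / beta), ratio_upper_bound)
  elif mode == 'binary':
    # Indicator of positive advantage.
    coef = tf.cast(advantage > 0, advantage.dtype)
  elif mode == 'all':
    # Weight every action equally (pure behaviour cloning).
    coef = tf.ones_like(advantage)
  else:
    raise ValueError(f'Unknown policy improvement mode: {mode}')
  # The coefficient only reweights the loss; no gradient flows through it.
  return tf.stop_gradient(coef)

# Example usage: weights = crr_policy_loss_coef(q_tm1_mean - v_tm1_estimate)
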
Example #3
    def _step(self) -> types.NestedTensor:
        # Update target network.
        online_policy_variables = self._policy_network.variables
        target_policy_variables = self._target_policy_network.variables
        online_critic_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
        )
        target_critic_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
        )

        # Make online policy -> target policy network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for src, dest in zip(online_policy_variables,
                                 target_policy_variables):
                dest.assign(src)
        # Make online critic -> target critic network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for src, dest in zip(online_critic_variables,
                                 target_critic_variables):
                dest.assign(src)

        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

        # Get batch size and scalar dtype.
        batch_size = r_t.shape[0]

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=d_t.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            # This stop_gradient prevents gradients from propagating into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(self._target_observation_network(o_t))

            # Get online and target action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Sample actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)

            # Tile embedded observations to feed into the target critic network.
            # Note: this is more efficient than tiling before the embedding layer.
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute target-estimated distributional value of sampled actions at o_t.
            sampled_q_t_distributions = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Compute average logits by first reshaping them and normalizing them
            # across atoms.
            new_shape = [self._num_samples, batch_size, -1]  # [N, B, A]
            sampled_logits = tf.reshape(sampled_q_t_distributions.logits,
                                        new_shape)
            sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
            averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

            # Construct the expected distributional value for bootstrapping.
            q_t_distribution = networks.DiscreteValuedDistribution(
                values=sampled_q_t_distributions.values,
                logits=averaged_logits)

            # Compute online critic value distribution of a_tm1 in state o_tm1.
            q_tm1_distribution = self._critic_network(o_tm1, a_tm1)

            # Compute critic distributional loss.
            critic_loss = losses.categorical(q_tm1_distribution, r_t,
                                             discount * d_t, q_t_distribution)
            critic_loss = tf.reduce_mean(critic_loss)

            # Compute Q-values of sampled actions and reshape to [N, B].
            sampled_q_values = sampled_q_t_distributions.mean()
            sampled_q_values = tf.reshape(sampled_q_values,
                                          (self._num_samples, -1))

            # Compute MPO policy loss.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_values)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)  # Log MPO stats.

        return fetches
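
Both learners above refresh their target networks with periodic hard updates:
every `target_*_update_period` learner steps, the online variables are copied
into the corresponding target variables. A minimal sketch of that pattern in
isolation, assuming `num_steps` is the learner's `tf.Variable` step counter:

from typing import Sequence

import tensorflow as tf

def maybe_update_target(online_vars: Sequence[tf.Variable],
                        target_vars: Sequence[tf.Variable],
                        num_steps: tf.Variable,
                        period: int) -> None:
  """Copies online variables into target variables every `period` steps."""
  if tf.math.mod(num_steps, period) == 0:
    for src, dest in zip(online_vars, target_vars):
      dest.assign(src)

# Usage inside a learner step, before incrementing the counter:
#   maybe_update_target(online_critic_variables, target_critic_variables,
#                       self._num_steps, self._target_critic_update_period)
#   self._num_steps.assign_add(1)
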
Example #4
    def _step(self) -> types.NestedTensor:
        # Update target network.
        online_policy_variables = self._policy_network.variables
        target_policy_variables = self._target_policy_network.variables
        online_critic_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
        )
        target_critic_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
        )

        # Make online policy -> target policy network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for src, dest in zip(online_policy_variables,
                                 target_policy_variables):
                dest.assign(src)
        # Make online critic -> target critic network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for src, dest in zip(online_critic_variables,
                                 target_critic_variables):
                dest.assign(src)

        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

        # Get batch size and scalar dtype.
        batch_size = r_t.shape[0]

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=d_t.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            # This stop_gradient prevents gradients from propagating into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(self._target_observation_network(o_t))

            # Get online and target action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Sample actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)

            # Tile embedded observations to feed into the target critic network.
            # Note: this is more efficient than tiling before the embedding layer.
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute target-estimated distributional value of sampled actions at o_t.
            sampled_q_t_all = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Compute online critic value distribution of a_tm1 in state o_tm1.
            q_tm1_all = self._critic_network(o_tm1, a_tm1)

            # Compute rewards for objectives with defined reward_fn
            reward_stats = {}
            r_t_all = []
            for objective in self._objectives:
                if hasattr(objective, 'reward_fn'):
                    r = objective.reward_fn(o_tm1, a_tm1, r_t)
                    reward_stats['{}_reward'.format(
                        objective.name)] = tf.reduce_mean(r)
                    r_t_all.append(r)
            r_t_all = tf.stack(r_t_all, axis=-1)
            r_t_all.get_shape().assert_has_rank(2)  # [B, C]

            if isinstance(sampled_q_t_all, list):  # Distributional critics
                # Compute average logits by first reshaping them and normalizing them
                # across atoms.
                critic_losses = []
                sampled_q_ts = []
                for idx, (sampled_q_t_distributions,
                          q_tm1_distribution) in enumerate(
                              zip(sampled_q_t_all, q_tm1_all)):
                    # Compute loss for distributional critic for objective c
                    sampled_logits = tf.reshape(
                        sampled_q_t_distributions.logits,
                        [self._num_samples, batch_size, -1])  # [N, B, A]
                    sampled_logprobs = tf.math.log_softmax(sampled_logits,
                                                           axis=-1)
                    averaged_logits = tf.reduce_logsumexp(sampled_logprobs,
                                                          axis=0)

                    # Construct the expected distributional value for bootstrapping.
                    q_t_distribution = networks.DiscreteValuedDistribution(
                        values=sampled_q_t_distributions.values,
                        logits=averaged_logits)

                    # Compute critic distributional loss.
                    critic_loss = losses.categorical(q_tm1_distribution,
                                                     r_t_all[:, idx],
                                                     discount * d_t,
                                                     q_t_distribution)
                    critic_losses.append(tf.reduce_mean(critic_loss))

                    # Compute Q-values of sampled actions and reshape to [N, B].
                    sampled_q_ts.append(
                        tf.reshape(sampled_q_t_distributions.mean(),
                                   (self._num_samples, -1)))

                critic_loss = tf.reduce_mean(critic_losses)
                sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
            else:
                # Reshape Q-value samples back to original batch dimensions and average
                # them to compute the TD-learning bootstrap target.
                sampled_q_t = tf.reshape(sampled_q_t_all,
                                         (self._num_samples, batch_size,
                                          self._num_critic_heads))  # [N,B,C]
                q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

                # Flatten q_t and q_tm1; necessary for trfl.td_learning
                q_t = tf.reshape(q_t, [-1])  # [B*C]
                q_tm1 = tf.reshape(q_tm1_all, [-1])  # [B*C]

                # Flatten r_t_all; necessary for trfl.td_learning
                r_t_all = tf.reshape(r_t_all, [-1])  # [B*C]

                # Repeat each d_t entry across the C heads so the flattened [B*C]
                # ordering matches the row-major flattening of q_t and q_tm1 above.
                d_t = tf.repeat(d_t, self._num_critic_heads)  # [B*C]

                # Critic loss.
                critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t,
                                               q_t).loss
                critic_loss = tf.reduce_mean(critic_loss)

            # Add sampled Q-values for objectives with defined objective_fn
            sampled_q_idx = 0
            sampled_q_t_k = []
            for objective in self._objectives:
                if hasattr(objective, 'reward_fn'):
                    sampled_q_t_k.append(
                        tf.stop_gradient(sampled_q_t[..., sampled_q_idx]))
                    sampled_q_idx += 1
                if hasattr(objective, 'objective_fn'):
                    sampled_q_t_k.append(
                        tf.stop_gradient(
                            objective.objective_fn(sampled_actions,
                                                   sampled_q_t)))
            sampled_q_t_k = tf.stack(sampled_q_t_k, axis=-1)  # [N, B, K]

            # Compute MPO policy loss.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_t_k)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)  # Log MPO stats.
        fetches.update(reward_stats)  # Log reward stats.

        return fetches
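
In the non-distributional branch above, the per-objective Q-heads, rewards,
and discounts are flattened to a single [B*C] vector so one scalar TD loss
covers all objectives at once. The sketch below reproduces that reshaping with
toy shapes and random values, using a plain squared TD error in place of
trfl.td_learning:

import tensorflow as tf

num_samples, batch_size, num_heads = 3, 4, 2  # N, B, C (illustrative)
sampled_q_t_all = tf.random.normal([num_samples * batch_size, num_heads])
q_tm1_all = tf.random.normal([batch_size, num_heads])
r_t_all = tf.random.normal([batch_size, num_heads])
d_t = tf.ones([batch_size])
discount = 0.99

# Average the N sampled Q-values to form the bootstrap target per head.
sampled_q_t = tf.reshape(
    sampled_q_t_all, (num_samples, batch_size, num_heads))  # [N, B, C]
q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

# Flatten everything to [B*C]; repeat each discount across the C heads so the
# ordering matches the row-major flattening of the Q-values and rewards.
q_t = tf.reshape(q_t, [-1])
q_tm1 = tf.reshape(q_tm1_all, [-1])
r_t_flat = tf.reshape(r_t_all, [-1])
d_t_flat = tf.repeat(d_t, num_heads)

# One scalar TD loss over all batch elements and objectives.
td_target = tf.stop_gradient(r_t_flat + discount * d_t_flat * q_t)
critic_loss = tf.reduce_mean(0.5 * tf.square(td_target - q_tm1))
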