Example #1
    def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]:
        # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...].
        sample = tf2_utils.batch_to_sequence(sample)
        observations = sample.observation
        actions = sample.action
        rewards = sample.reward
        discounts = sample.discount

        dtype = rewards.dtype

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=discounts.dtype)

        # Loss cumulants across time. These cannot be python mutable objects.
        critic_loss = 0.
        policy_loss = 0.

        # Each transition induces a policy loss, which we then weight using
        # the `policy_loss_coef_t`; shape [B], see https://arxiv.org/abs/2006.15134.
        # `policy_loss_coef` is a scalar average of these coefficients across
        # the batch and sequence length dimensions.
        policy_loss_coef = 0.

        per_device_batch_size = actions.shape[1]

        # Initialize recurrent states.
        critic_state = self._critic_network.initial_state(
            per_device_batch_size)
        target_critic_state = critic_state
        policy_state = self._policy_network.initial_state(
            per_device_batch_size)
        target_policy_state = policy_state

        with tf.GradientTape(persistent=True) as tape:
            for t in range(1, self._sequence_length):
                o_tm1 = tree.map_structure(operator.itemgetter(t - 1),
                                           observations)
                a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions)
                r_t = tree.map_structure(operator.itemgetter(t - 1), rewards)
                d_t = tree.map_structure(operator.itemgetter(t - 1), discounts)
                o_t = tree.map_structure(operator.itemgetter(t), observations)

                if t != 1:
                    # By only updating the target critic state here we are forcing
                    # the target critic to ignore observations[0]. Otherwise, the
                    # target critic would be unrolled for one timestep more than the
                    # online critic. The smaller the sequence length, the more
                    # problematic this is: with an RNN on sequences of length 2 you
                    # would expect the recurrent connections to never be used, and
                    # that only holds if the target_critic_state update on
                    # observations[0] is skipped here.
                    _, target_critic_state = self._target_critic_network(
                        o_tm1, a_tm1, target_critic_state)

                # ========================= Critic learning ============================
                q_tm1, next_critic_state = self._critic_network(
                    o_tm1, a_tm1, critic_state)
                target_action_distribution, target_policy_state = self._target_policy_network(
                    o_t, target_policy_state)

                sampled_actions_t = target_action_distribution.sample(
                    self._num_action_samples_td_learning)
                # [N, B, ...]
                tiled_o_t = tf2_utils.tile_nested(
                    o_t, self._num_action_samples_td_learning)
                tiled_target_critic_state = tf2_utils.tile_nested(
                    target_critic_state, self._num_action_samples_td_learning)

                # Compute the target critic's Q-value of the sampled actions.
                sampled_q_t, _ = snt.BatchApply(self._target_critic_network)(
                    tiled_o_t, sampled_actions_t, tiled_target_critic_state)

                # Compute average logits by first reshaping them to [N, B, A] and then
                # normalizing them across atoms.
                new_shape = [
                    self._num_action_samples_td_learning, r_t.shape[0], -1
                ]
                sampled_logits = tf.reshape(sampled_q_t.logits, new_shape)
                sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
                averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

                # Construct the expected distributional value for bootstrapping.
                q_t = networks.DiscreteValuedDistribution(
                    values=sampled_q_t.values, logits=averaged_logits)
                critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t,
                                                   q_t)
                critic_loss_t = tf.reduce_mean(critic_loss_t)

                # ========================= Actor learning =============================
                action_distribution_tm1, policy_state = self._policy_network(
                    o_tm1, policy_state)
                q_tm1_mean = q_tm1.mean()

                # Compute the estimate of the value function based on
                # self._num_action_samples_policy_weight samples from the policy.
                tiled_o_tm1 = tf2_utils.tile_nested(
                    o_tm1, self._num_action_samples_policy_weight)
                tiled_critic_state = tf2_utils.tile_nested(
                    critic_state, self._num_action_samples_policy_weight)
                action_tm1 = action_distribution_tm1.sample(
                    self._num_action_samples_policy_weight)
                tiled_z_tm1, _ = snt.BatchApply(self._critic_network)(
                    tiled_o_tm1, action_tm1, tiled_critic_state)
                tiled_v_tm1 = tf.reshape(
                    tiled_z_tm1.mean(),
                    [self._num_action_samples_policy_weight, -1])

                # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the
                # final estimate of the value function.
                if self._baseline_reduce_function == 'mean':
                    v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0)
                elif self._baseline_reduce_function == 'max':
                    v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0)
                elif self._baseline_reduce_function == 'min':
                    v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0)

                # Assert that action_distribution_tm1 is a batch of multivariate
                # distributions (in contrast to e.g. a [batch, action_size] collection
                # of 1d distributions).
                assert len(action_distribution_tm1.batch_shape) == 1
                policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1)

                advantage = q_tm1_mean - v_tm1_estimate
                if self._policy_improvement_modes == 'exp':
                    policy_loss_coef_t = tf.math.minimum(
                        tf.math.exp(advantage / self._beta),
                        self._ratio_upper_bound)
                elif self._policy_improvement_modes == 'binary':
                    policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype)
                elif self._policy_improvement_modes == 'all':
                    # Regress against all actions (effectively pure BC).
                    policy_loss_coef_t = 1.
                policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t)

                policy_loss_batch *= policy_loss_coef_t
                policy_loss_t = tf.reduce_mean(policy_loss_batch)

                critic_state = next_critic_state

                critic_loss += critic_loss_t
                policy_loss += policy_loss_t
                policy_loss_coef += tf.reduce_mean(
                    policy_loss_coef_t)  # For logging.

            # Divide by sequence length to get mean losses.
            critic_loss /= tf.cast(self._sequence_length, dtype=dtype)
            policy_loss /= tf.cast(self._sequence_length, dtype=dtype)
            policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype)

        # Compute gradients.
        critic_gradients = tape.gradient(
            critic_loss, self._critic_network.trainable_variables)
        policy_gradients = tape.gradient(
            policy_loss, self._policy_network.trainable_variables)

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Sync gradients across GPUs or TPUs.
        ctx = tf.distribute.get_replica_context()
        critic_gradients = ctx.all_reduce('mean', critic_gradients)
        policy_gradients = ctx.all_reduce('mean', policy_gradients)

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
            critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     self._critic_network.trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     self._policy_network.trainable_variables)

        source_variables = (self._critic_network.variables +
                            self._policy_network.variables)
        target_variables = (self._target_critic_network.variables +
                            self._target_policy_network.variables)

        # Make online -> target network update ops.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(source_variables, target_variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        return {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
            'policy_loss_coef': policy_loss_coef,
        }
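
The critic target above averages a distributional Q-network over several sampled actions in logit space: a per-sample log-softmax over atoms followed by a logsumexp over the sample axis. Below is a minimal, self-contained sketch of just that step; the toy shapes N, B, A and the random logits are assumptions, not values taken from the learner.

    # Sketch: averaging a categorical value distribution over N action samples in
    # logit space, as done for the critic target in Example #1 (toy shapes only).
    import tensorflow as tf

    N, B, A = 4, 2, 51  # action samples, batch size, distribution atoms (assumed)
    logits = tf.random.normal([N, B, A])

    # Per-sample log-probabilities over atoms, then a log-sum-exp across samples.
    logprobs = tf.math.log_softmax(logits, axis=-1)          # [N, B, A]
    averaged_logits = tf.reduce_logsumexp(logprobs, axis=0)  # [B, A]

    # Renormalising recovers the mixture: softmax(averaged_logits) equals the mean
    # of the per-sample probabilities, because the constant log(N) offset cancels.
    mixture = tf.nn.softmax(averaged_logits, axis=-1)
    mean_probs = tf.reduce_mean(tf.nn.softmax(logits, axis=-1), axis=0)
    tf.debugging.assert_near(mixture, mean_probs, atol=1e-5)
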
Example #2
    def _step(self) -> Dict[str, tf.Tensor]:
        # Update target network
        online_variables = [
            *self._critic_network.variables,
            *self._policy_network.variables,
        ]
        if self._prior_network is not None:
            online_variables += [*self._prior_network.variables]
        online_variables = tuple(online_variables)

        target_variables = [
            *self._target_critic_network.variables,
            *self._target_policy_network.variables,
        ]
        if self._prior_network is not None:
            target_variables += [*self._target_prior_network.variables]
        target_variables = tuple(target_variables)

        # Make online -> target network update ops.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(online_variables, target_variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any) and flip to `[T, B, ...]`.
        sample: reverb.ReplaySample = next(self._iterator)
        data = tf2_utils.batch_to_sequence(sample.data)
        observations, actions, rewards, discounts, extra = (data.observation,
                                                            data.action,
                                                            data.reward,
                                                            data.discount,
                                                            data.extras)
        online_target_pi_q = svg0_utils.OnlineTargetPiQ(
            online_pi=self._policy_network,
            online_q=self._critic_network,
            target_pi=self._target_policy_network,
            target_q=self._target_critic_network,
            num_samples=self._num_action_samples,
            online_prior=self._prior_network,
            target_prior=self._target_prior_network,
        )
        with tf.GradientTape(persistent=True) as tape:
            step_outputs = svg0_utils.static_rnn(
                core=online_target_pi_q,
                inputs=(observations, actions),
                unroll_length=rewards.shape[0])

            # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the
            # number of action samples taken.
            target_pi_samples = tf2_utils.batch_to_sequence(
                step_outputs.target_samples)
            # Tile observations to have shape [S, T+1, B,..].
            tiled_observations = tf2_utils.tile_nested(
                observations, self._num_action_samples)

            # Finally compute target Q values on the new action samples.
            # Shape: [S, T+1, B, 1]
            target_q_target_pi_samples = snt.BatchApply(
                self._target_critic_network, 3)(tiled_observations,
                                                target_pi_samples)
            # Compute the value estimate by averaging over the action dimension.
            # Shape: [T+1, B, 1].
            target_v_target_pi = tf.reduce_mean(target_q_target_pi_samples,
                                                axis=0)

            # Split the target V's into the target for learning
            # `value_function_target` and the bootstrap value. Shape: [T, B].
            value_function_target = tf.squeeze(target_v_target_pi[:-1],
                                               axis=-1)
            # Shape: [B].
            bootstrap_value = tf.squeeze(target_v_target_pi[-1], axis=-1)

            # When learning with a prior, subtract the distillation (KL) cost
            # from the value targets.
            if self._prior_network is not None:
                value_function_target -= self._distillation_cost * tf.stop_gradient(
                    step_outputs.analytic_kl_to_target[:-1])
                bootstrap_value -= self._distillation_cost * tf.stop_gradient(
                    step_outputs.analytic_kl_to_target[-1])

            # Get target log probs and behavior log probs from rollout.
            # Shape: [T+1, B].
            target_log_probs_behavior_actions = (
                step_outputs.target_log_probs_behavior_actions)
            behavior_log_probs = extra['log_prob']
            # Calculate importance weights. Shape: [T+1, B].
            rhos = tf.exp(target_log_probs_behavior_actions -
                          behavior_log_probs)

            # Filter the importance weights to mask out episode restarts. Ignore the
            # last action and consider the step type of the next step for masking.
            # Shape: [T, B].
            episode_start_mask = tf2_utils.batch_to_sequence(
                sample.data.start_of_episode)[1:]

            rhos = svg0_utils.mask_out_restarting(rhos[:-1],
                                                  episode_start_mask)

            # Compute the log importance weights with a small value added for
            # stability.
            # Shape: [T, B]
            log_rhos = tf.math.log(rhos + _MIN_LOG_VAL)

            # Retrieve the target and online Q values and throw away the last action.
            # Shape: [T, B].
            target_q_values = tf.squeeze(step_outputs.target_q[:-1], -1)
            online_q_values = tf.squeeze(step_outputs.online_q[:-1], -1)

            # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the
            # number of action samples taken.
            online_pi_samples = tf2_utils.batch_to_sequence(
                step_outputs.online_samples)
            target_q_online_pi_samples = snt.BatchApply(
                self._target_critic_network, 3)(tiled_observations,
                                                online_pi_samples)
            expected_q = tf.reduce_mean(tf.squeeze(target_q_online_pi_samples,
                                                   -1),
                                        axis=0)

            # Flip online_log_probs to be of shape [S, T+1, B] and then compute
            # entropy by averaging over num samples. Final shape: [T+1, B].
            online_log_probs = tf2_utils.batch_to_sequence(
                step_outputs.online_log_probs)
            sample_based_entropy = tf.reduce_mean(-online_log_probs, axis=0)
            retrace_outputs = continuous_retrace_ops.retrace_from_importance_weights(
                log_rhos=log_rhos,
                discounts=self._discount * discounts[:-1],
                rewards=rewards[:-1],
                q_values=target_q_values,
                values=value_function_target,
                bootstrap_value=bootstrap_value,
                lambda_=self._lambda,
            )

            # Critic loss. Shape: [T, B].
            critic_loss = 0.5 * tf.math.squared_difference(
                tf.stop_gradient(retrace_outputs.qs), online_q_values)

            # Policy loss: SVG0 with sample-based entropy. Shape: [T, B].
            policy_loss = -(expected_q + self._entropy_regularizer_cost *
                            sample_based_entropy)
            policy_loss = policy_loss[:-1]

            if self._prior_network is not None:
                # When training the prior, also add the per-timestep KL cost.
                policy_loss += (self._distillation_cost *
                                step_outputs.analytic_kl_to_target[:-1])

            # Ensure episode restarts are masked out when computing the losses.
            critic_loss = svg0_utils.mask_out_restarting(
                critic_loss, episode_start_mask)
            critic_loss = tf.reduce_mean(critic_loss)

            policy_loss = svg0_utils.mask_out_restarting(
                policy_loss, episode_start_mask)
            policy_loss = tf.reduce_mean(policy_loss)

            if self._prior_network is not None:
                prior_loss = step_outputs.analytic_kl_divergence[:-1]
                prior_loss = svg0_utils.mask_out_restarting(
                    prior_loss, episode_start_mask)
                prior_loss = tf.reduce_mean(prior_loss)

        # Get trainable variables.
        policy_variables = self._policy_network.trainable_variables
        critic_variables = self._critic_network.trainable_variables

        # Compute gradients.
        policy_gradients = tape.gradient(policy_loss, policy_variables)
        critic_gradients = tape.gradient(critic_loss, critic_variables)
        if self._prior_network is not None:
            prior_variables = self._prior_network.trainable_variables
            prior_gradients = tape.gradient(prior_loss, prior_variables)

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Apply gradients.
        self._policy_optimizer.apply(policy_gradients, policy_variables)
        self._critic_optimizer.apply(critic_gradients, critic_variables)
        losses = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }

        if self._prior_network is not None:
            self._prior_optimizer.apply(prior_gradients, prior_variables)
            losses['prior_loss'] = prior_loss

        # Losses to track.
        return losses
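
Example #2 feeds its retrace targets with per-step importance weights: the ratio of target-policy to behaviour-policy likelihood of the stored actions, kept in log space for numerical stability. The sketch below shows only that computation with toy Gaussian policies; the distributions, shapes, and the value of _MIN_LOG_VAL are assumptions standing in for the learner's own networks and module constant.

    # Sketch: importance weights for off-policy correction (toy policies, not SVG0).
    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions
    _MIN_LOG_VAL = 1e-20  # assumed small constant for log stability

    T, B, D = 5, 3, 2  # time, batch, action dimension (toy values)
    actions = tf.random.normal([T, B, D])

    behaviour = tfd.MultivariateNormalDiag(
        loc=tf.zeros([T, B, D]), scale_diag=tf.ones([T, B, D]))
    target = tfd.MultivariateNormalDiag(
        loc=0.1 * tf.ones([T, B, D]), scale_diag=tf.ones([T, B, D]))

    # rho_t = pi_target(a_t) / pi_behaviour(a_t), shape [T, B].
    rhos = tf.exp(target.log_prob(actions) - behaviour.log_prob(actions))
    log_rhos = tf.math.log(rhos + _MIN_LOG_VAL)  # what retrace consumes
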
Example #3
  def _step(self) -> Dict[str, tf.Tensor]:

    # Draw a batch of data from replay.
    sample: reverb.ReplaySample = next(self._iterator)

    data = tf2_utils.batch_to_sequence(sample.data)
    observations, actions, rewards, discounts, extra = (data.observation,
                                                        data.action,
                                                        data.reward,
                                                        data.discount,
                                                        data.extras)
    unused_sequence_length, batch_size = actions.shape

    # Get initial state for the LSTM, either from replay or simply use zeros.
    if self._store_lstm_state:
      core_state = tree.map_structure(lambda x: x[0], extra['core_state'])
    else:
      core_state = self._network.initial_state(batch_size)
    target_core_state = tree.map_structure(tf.identity, core_state)

    # Before training, optionally unroll the LSTM for a fixed warmup period.
    burn_in_obs = tree.map_structure(lambda x: x[:self._burn_in_length],
                                     observations)
    _, core_state = self._burn_in(burn_in_obs, core_state)
    _, target_core_state = self._burn_in(burn_in_obs, target_core_state)

    # Don't train on the warmup period.
    observations, actions, rewards, discounts, extra = tree.map_structure(
        lambda x: x[self._burn_in_length:],
        (observations, actions, rewards, discounts, extra))

    with tf.GradientTape() as tape:
      # Unroll the online and target Q-networks on the sequences.
      q_values, _ = self._network.unroll(observations, core_state,
                                         self._sequence_length)
      target_q_values, _ = self._target_network.unroll(observations,
                                                       target_core_state,
                                                       self._sequence_length)

      # Compute the target policy distribution (greedy).
      greedy_actions = tf.argmax(q_values, output_type=tf.int32, axis=-1)
      target_policy_probs = tf.one_hot(
          greedy_actions, depth=self._num_actions, dtype=q_values.dtype)

      # Compute the transformed n-step loss.
      rewards = tree.map_structure(lambda x: x[:-1], rewards)
      discounts = tree.map_structure(lambda x: x[:-1], discounts)
      loss, extra = losses.transformed_n_step_loss(
          qs=q_values,
          targnet_qs=target_q_values,
          actions=actions,
          rewards=rewards,
          pcontinues=discounts * self._discount,
          target_policy_probs=target_policy_probs,
          bootstrap_n=self._n_step,
      )

      # Calculate importance weights and use them to scale the loss.
      sample_info = sample.info
      keys, probs = sample_info.key, sample_info.probability
      probs = tf2_utils.batch_to_sequence(probs)
      importance_weights = 1. / (self._max_replay_size * probs)  # [T, B]
      importance_weights **= self._importance_sampling_exponent
      importance_weights /= tf.reduce_max(importance_weights)
      loss *= tf.cast(importance_weights, tf.float32)  # [T, B]
      loss = tf.reduce_mean(loss)  # []

    # Compute gradients.
    gradients = tape.gradient(loss, self._network.trainable_variables)
    # Optionally clip gradients before applying them.
    if self._clip_grad_norm is not None:
      gradients, _ = tf.clip_by_global_norm(gradients, self._clip_grad_norm)

    self._optimizer.apply(gradients, self._network.trainable_variables)

    # Periodically update the target network.
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
      for src, dest in zip(self._network.variables,
                           self._target_network.variables):
        dest.assign(src)
    self._num_steps.assign_add(1)

    if self._reverb_client:
      # Compute updated priorities.
      priorities = compute_priority(extra.errors, self._max_priority_weight)
      # Update the priorities on the reverb side.
      self._reverb_client.update_priorities(
          table=adders.DEFAULT_PRIORITY_TABLE,
          keys=keys[:, 0],
          priorities=tf.cast(priorities, tf.float64))

    return {'loss': loss}
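
The loss in Example #3 is scaled by prioritised-replay importance weights of the form (1 / (N * p))**beta, normalised by the largest weight in the batch. A small standalone sketch with made-up numbers for the replay size, exponent, and sampling probabilities:

    # Sketch: prioritised-replay importance weights (toy values, not the learner's).
    import tensorflow as tf

    max_replay_size = 100_000                     # N, assumed
    importance_sampling_exponent = 0.6            # beta, assumed
    probs = tf.constant([[0.01, 0.001, 0.0005]])  # [T=1, B=3] sampling probabilities

    importance_weights = 1.0 / (max_replay_size * probs)
    importance_weights **= importance_sampling_exponent
    importance_weights /= tf.reduce_max(importance_weights)  # max weight becomes 1.0
    # Rarely sampled items (low p) keep weight 1.0; frequently sampled items are
    # down-weighted, correcting the bias introduced by prioritised sampling.
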
Example #4
    def _forward(self, inputs: Any) -> None:
        """Trainer forward pass

        Args:
            inputs (Any): input data from the data table (transitions)
        """

        # TODO: Update this forward function to work like MAD4PG
        data = inputs.data

        # Note (dries): The unused variable is start_of_episode.
        observations, actions, rewards, discounts, _, extras = (
            data.observations,
            data.actions,
            data.rewards,
            data.discounts,
            data.start_of_episode,
            data.extras,
        )

        # Get initial state for the LSTM from replay and
        # extract the first state in the sequence.
        core_state = tree.map_structure(lambda s: s[:, 0, :],
                                        extras["core_states"])
        target_core_state = tree.map_structure(tf.identity, core_state)

        # TODO (dries): Take out all the data points that do not need to be
        #  processed here at the start, so that this does not have to be done
        #  later on, which saves processing time.

        self.policy_losses: Dict[str, tf.Tensor] = {}
        self.critic_losses: Dict[str, tf.Tensor] = {}

        # Do forward passes through the networks and calculate the losses
        with tf.GradientTape(persistent=True) as tape:
            # Note (dries): We are assuming that only the policy network
            # is recurrent and not the observation network.
            obs_trans, target_obs_trans = self._transform_observations(
                observations)

            target_actions = self._target_policy_actions(
                target_obs_trans, target_core_state)

            for agent in self._agents:
                agent_key = self.agent_net_keys[agent]

                # Get critic feed
                (
                    obs_trans_feed,
                    target_obs_trans_feed,
                    action_feed,
                    target_actions_feed,
                ) = self._get_critic_feed(
                    obs_trans=obs_trans,
                    target_obs_trans=target_obs_trans,
                    actions=actions,
                    target_actions=target_actions,
                    extras=extras,
                    agent=agent,
                )

                # Critic learning.
                # Remove the last sequence step for the normal network
                obs_comb, dims = train_utils.combine_dim(obs_trans_feed)
                act_comb, _ = train_utils.combine_dim(action_feed)
                q_values = self._critic_networks[agent_key](obs_comb, act_comb)
                q_values.set_dimensions(dims)

                # Remove first sequence step for the target
                obs_comb, _ = train_utils.combine_dim(target_obs_trans_feed)
                act_comb, _ = train_utils.combine_dim(target_actions_feed)
                target_q_values = self._target_critic_networks[agent_key](
                    obs_comb, act_comb)
                target_q_values.set_dimensions(dims)

                # Cast the additional discount to match
                # the environment discount dtype.
                agent_discount = discounts[agent]
                discount = tf.cast(self._discount, dtype=agent_discount.dtype)

                # Critic loss.
                critic_loss = recurrent_n_step_critic_loss(
                    q_values,
                    target_q_values,
                    rewards[agent],
                    discount * agent_discount,
                    bootstrap_n=self._bootstrap_n,
                    loss_fn=losses.categorical,
                )
                self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0)

                # Actor learning.
                obs_agent_feed = target_obs_trans[agent]
                # TODO (dries): Why is there an extra tuple?
                agent_core_state = core_state[agent][0]
                transposed_obs = tf2_utils.batch_to_sequence(obs_agent_feed)
                outputs, updated_states = snt.static_unroll(
                    self._policy_networks[agent_key],
                    transposed_obs,
                    agent_core_state,
                )

                dpg_actions = tf2_utils.batch_to_sequence(outputs)

                # Note (dries): This is done so that losses.dpg can verify,
                # using the gradient tape, that there is a gradient relationship
                # between dpg_q_values and dpg_actions_comb.
                dpg_actions_comb, dim = train_utils.combine_dim(dpg_actions)

                # Note (dries): This seemingly useless line is important!
                # Don't remove it. See above note.
                dpg_actions = train_utils.extract_dim(dpg_actions_comb, dim)

                # Get dpg actions
                dpg_actions_feed = self._get_dpg_feed(target_actions,
                                                      dpg_actions, agent)

                # Get dpg Q values.
                obs_comb, _ = train_utils.combine_dim(target_obs_trans_feed)
                act_comb, _ = train_utils.combine_dim(dpg_actions_feed)
                dpg_z_values = self._critic_networks[agent_key](obs_comb,
                                                                act_comb)
                dpg_q_values = dpg_z_values.mean()

                # Actor loss. If clipping is true use dqda clipping and clip the norm.
                dqda_clipping = 1.0 if self._max_gradient_norm is not None else None
                clip_norm = self._max_gradient_norm is not None

                policy_loss = losses.dpg(
                    dpg_q_values,
                    dpg_actions_comb,
                    tape=tape,
                    dqda_clipping=dqda_clipping,
                    clip_norm=clip_norm,
                )
                self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0)
        self.tape = tape
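
Example #4 repeatedly flattens [T, B, ...] sequences into [T*B, ...] batches (train_utils.combine_dim) so that a feed-forward critic can score every timestep in a single call, then restores the leading dimensions. A minimal sketch of that reshaping pattern, using plain tf.reshape and a toy critic in place of Mava's helpers and networks:

    # Sketch: flatten time and batch dims, run the critic once, reshape back.
    import tensorflow as tf

    T, B, obs_dim, act_dim = 6, 4, 10, 3  # toy shapes
    obs = tf.random.normal([T, B, obs_dim])
    act = tf.random.normal([T, B, act_dim])

    critic = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(1),
    ])

    obs_comb = tf.reshape(obs, [T * B, obs_dim])  # combine_dim equivalent
    act_comb = tf.reshape(act, [T * B, act_dim])
    q_flat = critic(tf.concat([obs_comb, act_comb], axis=-1))  # [T*B, 1]
    q_values = tf.reshape(q_flat, [T, B])  # extract_dim equivalent, back to [T, B]
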
Example #5
    def _step(self) -> Dict[str, tf.Tensor]:
        """Does an SGD step on a batch of sequences."""

        # Retrieve a batch of data from replay.
        inputs: reverb.ReplaySample = next(self._iterator)
        data = tf2_utils.batch_to_sequence(inputs.data)
        observations, actions, rewards, discounts, extra = (data.observation,
                                                            data.action,
                                                            data.reward,
                                                            data.discount,
                                                            data.extras)
        core_state = tree.map_structure(lambda s: s[0], extra['core_state'])

        # Drop the final step of each field: a length-T sequence yields T-1
        # transitions, and the last value is only used for bootstrapping.
        actions = actions[:-1]  # [T-1]
        rewards = rewards[:-1]  # [T-1]
        discounts = discounts[:-1]  # [T-1]

        with tf.GradientTape() as tape:
            # Unroll current policy over observations.
            (logits, values), _ = snt.static_unroll(self._network,
                                                    observations, core_state)

            # Compute importance sampling weights: current policy / behavior policy.
            behaviour_logits = extra['logits']
            pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1])
            pi_target = tfd.Categorical(logits=logits[:-1])
            log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(
                actions)

            # Optionally clip rewards.
            rewards = tf.clip_by_value(
                rewards, tf.cast(-self._max_abs_reward, rewards.dtype),
                tf.cast(self._max_abs_reward, rewards.dtype))

            # Critic loss.
            vtrace_returns = trfl.vtrace_from_importance_weights(
                log_rhos=tf.cast(log_rhos, tf.float32),
                discounts=tf.cast(self._discount * discounts, tf.float32),
                rewards=tf.cast(rewards, tf.float32),
                values=tf.cast(values[:-1], tf.float32),
                bootstrap_value=values[-1],
            )
            critic_loss = tf.square(vtrace_returns.vs - values[:-1])

            # Policy-gradient loss.
            policy_gradient_loss = trfl.policy_gradient(
                policies=pi_target,
                actions=actions,
                action_values=vtrace_returns.pg_advantages,
            )

            # Entropy regulariser.
            entropy_loss = trfl.policy_entropy_loss(pi_target).loss

            # Combine weighted sum of actor & critic losses.
            loss = tf.reduce_mean(policy_gradient_loss +
                                  self._baseline_cost * critic_loss +
                                  self._entropy_cost * entropy_loss)

        # Compute gradients and optionally apply clipping.
        gradients = tape.gradient(loss, self._network.trainable_variables)
        gradients, _ = tf.clip_by_global_norm(gradients,
                                              self._max_gradient_norm)
        self._optimizer.apply(gradients, self._network.trainable_variables)

        metrics = {
            'loss': loss,
            'critic_loss': tf.reduce_mean(critic_loss),
            'entropy_loss': tf.reduce_mean(entropy_loss),
            'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss),
        }

        return metrics
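
The V-trace correction in Example #5 consumes log importance ratios, log pi(a) - log mu(a), for the behaviour actions, computed from the stored behaviour logits and the logits of the current network. A toy sketch of that ratio computation (random logits and actions, not the learner's data):

    # Sketch: log importance ratios for V-trace from categorical policies.
    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    T, B, num_actions = 5, 2, 4  # toy shapes
    behaviour_logits = tf.random.normal([T, B, num_actions])  # logged by the actor
    current_logits = tf.random.normal([T, B, num_actions])    # from the learner net
    actions = tf.random.uniform([T, B], maxval=num_actions, dtype=tf.int32)

    pi_behaviour = tfd.Categorical(logits=behaviour_logits)
    pi_target = tfd.Categorical(logits=current_logits)
    log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(actions)  # [T, B]
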
Example #6
    def _forward(self, inputs: Any) -> None:
        data = tree.map_structure(
            lambda v: tf.expand_dims(v, axis=0)
            if len(v.shape) <= 1 else v, inputs.data)
        data = tf2_utils.batch_to_sequence(data)

        observations, actions, rewards, discounts, _, extra = data

        core_state = tree.map_structure(lambda s: s[:, 0, :],
                                        inputs.data.extras["core_states"])
        core_message = tree.map_structure(lambda s: s[:, 0, :],
                                          inputs.data.extras["core_messages"])
        T = actions[self._agents[0]].shape[0]

        # Use the fact that the end of an episode always carries the reward to
        # find episode lengths. These are used to mask the loss.
        ep_end = tf.argmax(tf.math.abs(rewards[self._agents[0]]), axis=0)

        with tf.GradientTape(persistent=True) as tape:
            q_network_losses: Dict[str, NestedArray] = {
                agent: {
                    "q_value_loss": tf.zeros(())
                }
                for agent in self._agents
            }

            state = {agent: core_state[agent][0] for agent in self._agents}
            target_state = {
                agent: core_state[agent][0]
                for agent in self._agents
            }

            message = {agent: core_message[agent][0] for agent in self._agents}
            target_message = {
                agent: core_message[agent][0]
                for agent in self._agents
            }

            # _target_q_networks must be 1 step ahead
            target_channel = self._communication_module.process_messages(
                target_message)
            for agent in self._agents:
                agent_key = self.agent_net_keys[agent]
                (q_targ, m), s = self._target_q_networks[agent_key](
                    observations[agent].observation[0],
                    target_state[agent],
                    target_channel[agent],
                )
                target_state[agent] = s
                target_message[agent] = m

            for t in range(1, T, 1):
                channel = self._communication_module.process_messages(message)
                target_channel = self._communication_module.process_messages(
                    target_message)

                for agent in self._agents:
                    agent_key = self.agent_net_keys[agent]

                    # Cast the additional discount
                    # to match the environment discount dtype.

                    discount = tf.cast(self._discount,
                                       dtype=discounts[agent][0].dtype)

                    (q_targ, m), s = self._target_q_networks[agent_key](
                        observations[agent].observation[t],
                        target_state[agent],
                        target_channel[agent],
                    )

                    target_state[agent] = s
                    target_message[agent] = tf.math.multiply(
                        m, observations[agent].observation[t][:, :1])

                    (q, m), s = self._q_networks[agent_key](
                        observations[agent].observation[t - 1],
                        state[agent],
                        channel[agent],
                    )

                    state[agent] = s
                    message[agent] = tf.math.multiply(
                        m, observations[agent].observation[t - 1][:, :1])

                    # Mask target
                    q_targ = tf.concat(
                        [[q_targ[i]]
                         if t <= ep_end[i] else [tf.zeros_like(q_targ[i])]
                         for i in range(q_targ.shape[0])],
                        axis=0,
                    )

                    loss, _ = trfl.qlearning(
                        q,
                        actions[agent][t - 1],
                        rewards[agent][t - 1],
                        discount * discounts[agent][t],
                        q_targ,
                    )

                    # Index loss (mask ended episodes)
                    if not tf.reduce_any(t - 1 <= ep_end):
                        continue

                    loss = tf.reduce_mean(loss[t - 1 <= ep_end])
                    q_network_losses[agent]["q_value_loss"] += loss

        self._q_network_losses = q_network_losses
        self.tape = tape
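
Example #6 masks out timesteps after the end of each episode, locating the episode end as the index of the (single) non-zero reward in each sequence. A small single-agent sketch of that masking with hand-written toy rewards:

    # Sketch: per-timestep loss masking from reward-derived episode lengths.
    import tensorflow as tf

    T, B = 6, 3
    rewards = tf.constant([[0., 0., 0.],
                           [0., 1., 0.],
                           [0., 0., 0.],
                           [1., 0., 0.],
                           [0., 0., 1.],
                           [0., 0., 0.]])  # [T, B], one terminal reward per episode
    per_step_loss = tf.random.uniform([T, B])

    ep_end = tf.argmax(tf.math.abs(rewards), axis=0, output_type=tf.int32)  # [B]
    step_index = tf.range(T)[:, None]                                       # [T, 1]
    mask = tf.cast(step_index <= ep_end[None, :], per_step_loss.dtype)      # [T, B]
    masked_loss = tf.reduce_sum(per_step_loss * mask) / tf.reduce_sum(mask)
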
Example #7
    def _forward(self, inputs: Any) -> None:
        """Trainer forward pass

        Args:
            inputs (Any): input data from the data table (transitions)
        """

        # Convert to sequence data
        data = tf2_utils.batch_to_sequence(inputs.data)

        # Unpack input data as follows:
        observations, actions, rewards, discounts, extras = (
            data.observations,
            data.actions,
            data.rewards,
            data.discounts,
            data.extras,
        )

        # Transform observations using the observation networks.
        observations_trans = self._transform_observations(observations)

        # Get log_probs.
        log_probs = extras["log_probs"]

        # Store losses.
        policy_losses: Dict[str, Any] = {}
        critic_losses: Dict[str, Any] = {}

        with tf.GradientTape(persistent=True) as tape:
            for agent in self._agents:

                action, reward, discount, behaviour_log_prob = (
                    actions[agent],
                    rewards[agent],
                    discounts[agent],
                    log_probs[agent],
                )

                actor_observation = observations_trans[agent]
                critic_observation = self._get_critic_feed(
                    observations_trans, agent)

                # Chop off final timestep for bootstrapping value
                reward = reward[:-1]
                discount = discount[:-1]

                # Get agent network
                agent_key = agent.split(
                    "_")[0] if self._shared_weights else agent
                policy_network = self._policy_networks[agent_key]
                critic_network = self._critic_networks[agent_key]

                # Reshape inputs.
                dims = actor_observation.shape[:2]
                actor_observation = snt.merge_leading_dims(actor_observation,
                                                           num_dims=2)
                critic_observation = snt.merge_leading_dims(critic_observation,
                                                            num_dims=2)
                policy = policy_network(actor_observation)
                values = critic_network(critic_observation)

                # Reshape the outputs.
                policy = tfd.BatchReshape(policy,
                                          batch_shape=dims,
                                          name="policy")
                values = tf.reshape(values, dims, name="value")

                # Values along the sequence T.
                bootstrap_value = values[-1]
                state_values = values[:-1]

                # Generalized Return Estimation
                td_loss, td_lambda_extra = trfl.td_lambda(
                    state_values=state_values,
                    rewards=reward,
                    pcontinues=discount,
                    bootstrap_value=bootstrap_value,
                    lambda_=self._lambda_gae,
                    name="CriticLoss",
                )

                # Do not use the loss provided by td_lambda, as it sums the
                # losses over the sequence length rather than averaging them.
                critic_loss = self._baseline_cost * tf.reduce_mean(
                    tf.square(td_lambda_extra.temporal_differences),
                    name="CriticLoss")

                # Compute importance sampling weights: current policy / behavior policy.
                log_rhos = policy.log_prob(action) - behaviour_log_prob
                importance_ratio = tf.exp(log_rhos)[:-1]
                clipped_importance_ratio = tf.clip_by_value(
                    importance_ratio,
                    1.0 - self._clipping_epsilon,
                    1.0 + self._clipping_epsilon,
                )

                # Generalized Advantage Estimation
                gae = tf.stop_gradient(td_lambda_extra.temporal_differences)
                mean, variance = tf.nn.moments(gae, axes=[0, 1], keepdims=True)
                normalized_gae = (gae - mean) / tf.sqrt(variance)

                policy_gradient_loss = tf.reduce_mean(
                    -tf.minimum(
                        tf.multiply(importance_ratio, normalized_gae),
                        tf.multiply(clipped_importance_ratio, normalized_gae),
                    ),
                    name="PolicyGradientLoss",
                )

                # Entropy regularization. Only implemented for categorical dist.
                try:
                    policy_entropy = tf.reduce_mean(policy.entropy())
                except NotImplementedError:
                    policy_entropy = tf.convert_to_tensor(0.0)

                entropy_loss = -self._entropy_cost * policy_entropy

                # Combine the policy-gradient loss and entropy regularization.
                policy_loss = policy_gradient_loss + entropy_loss

                policy_losses[agent] = policy_loss
                critic_losses[agent] = critic_loss

        self.policy_losses = policy_losses
        self.critic_losses = critic_losses
        self.tape = tape
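
The policy update in Example #7 is the clipped surrogate objective: the importance ratio is clipped around 1 and the smaller of the clipped and unclipped advantage-weighted terms is used, with advantages normalised over the sequence and batch dimensions. A compact sketch with toy tensors; the clipping epsilon and the 1e-8 variance guard are assumptions:

    # Sketch: clipped surrogate policy loss with normalised advantages.
    import tensorflow as tf

    clipping_epsilon = 0.2  # assumed
    log_rhos = tf.random.normal([10, 4], stddev=0.1)  # log pi_new(a) - log pi_old(a), [T, B]
    advantages = tf.random.normal([10, 4])            # e.g. td_lambda temporal differences

    # Normalise advantages, as the trainer above does with tf.nn.moments.
    mean, variance = tf.nn.moments(advantages, axes=[0, 1], keepdims=True)
    advantages = (advantages - mean) / tf.sqrt(variance + 1e-8)

    ratio = tf.exp(log_rhos)
    clipped_ratio = tf.clip_by_value(
        ratio, 1.0 - clipping_epsilon, 1.0 + clipping_epsilon)
    policy_gradient_loss = tf.reduce_mean(
        -tf.minimum(ratio * advantages, clipped_ratio * advantages))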