Example #1
    def __call__(self, inputs: types.NestedTensor) -> tf.Tensor:
        # Inputs are of size [B, ...]. Here we tile them to be of shape [N, B, ...].
        tiled_inputs = tf2_utils.tile_nested(inputs, self._num_action_samples)
        shape = tf.shape(tree.flatten(tiled_inputs)[0])
        n, b = shape[0], shape[1]
        tf.debugging.assert_equal(
            n, self._num_action_samples,
            'Internal Error. Unexpected tiled_inputs shape.')
        dummy_zeros_n_b = tf.zeros((n, b))
        # Reshape to [N * B, ...].
        merge = lambda x: snt.merge_leading_dims(x, 2)
        tiled_inputs = tree.map_structure(merge, tiled_inputs)

        tiled_actions = self._actor_network(tiled_inputs)

        # Compute Q-values and the resulting tempered probabilities.
        q = self._critic_network(tiled_inputs, tiled_actions)
        boltzmann_logits = q / self._beta

        boltzmann_logits = snt.split_leading_dim(boltzmann_logits,
                                                 dummy_zeros_n_b, 2)
        # [B, N]
        boltzmann_logits = tf.transpose(boltzmann_logits, perm=(1, 0))
        # Resample one action per batch according to the Boltzmann distribution.
        action_idx = tfp.distributions.Categorical(
            logits=boltzmann_logits).sample()
        # [B, 2], where the first column is 0, 1, 2, ... giving indices into
        # the batch dimension.
        action_idx = tf.stack((tf.range(b), action_idx), axis=1)

        tiled_actions = snt.split_leading_dim(tiled_actions, dummy_zeros_n_b,
                                              2)
        action_dim = len(tiled_actions.get_shape().as_list())
        tiled_actions = tf.transpose(tiled_actions,
                                     perm=[1, 0] + list(range(2, action_dim)))
        # [B, ...]
        action_sample = tf.gather_nd(tiled_actions, action_idx)

        return action_sample
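
The core of the sampler above is a resampling step: given N candidate actions per batch element and their Q-values, one action is drawn per element from a Boltzmann (softmax over Q / beta) distribution. Below is a minimal, self-contained sketch of just that step; the shapes, the temperature value, and the helper name boltzmann_select are illustrative assumptions, not part of the original class.

import tensorflow as tf
import tensorflow_probability as tfp


def boltzmann_select(q, actions, beta=1.0):
    """q: [N, B] Q-values; actions: [N, B, D] candidate actions."""
    logits = tf.transpose(q / beta)  # [B, N]
    idx = tfp.distributions.Categorical(logits=logits).sample()  # [B]
    batch_idx = tf.range(tf.shape(idx)[0])
    gather_idx = tf.stack([batch_idx, idx], axis=1)  # [B, 2]: (batch, sample)
    actions_bn = tf.transpose(actions, perm=[1, 0, 2])  # [B, N, D]
    return tf.gather_nd(actions_bn, gather_idx)  # [B, D]


# Four candidate actions for a batch of three, two-dimensional action space.
q = tf.random.normal([4, 3])
actions = tf.random.normal([4, 3, 2])
print(boltzmann_select(q, actions).shape)  # (3, 2)
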
Example #2
    def _step(self) -> types.NestedTensor:
        # Update target network.
        online_policy_variables = self._policy_network.variables
        target_policy_variables = self._target_policy_network.variables
        online_critic_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
        )
        target_critic_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
        )

        # Make online policy -> target policy network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for src, dest in zip(online_policy_variables,
                                 target_policy_variables):
                dest.assign(src)
        # Make online critic -> target critic network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for src, dest in zip(online_critic_variables,
                                 target_critic_variables):
                dest.assign(src)

        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

        # Get batch size and scalar dtype.
        batch_size = r_t.shape[0]

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=d_t.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            # This stop_gradient prevents gradients from propagating into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(self._target_observation_network(o_t))

            # Get online and target action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Sample actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)

            # Tile embedded observations to feed into the target critic network.
            # Note: this is more efficient than tiling before the embedding layer.
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute target-estimated distributional value of sampled actions at o_t.
            sampled_q_t_distributions = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Compute average logits by first reshaping them and normalizing them
            # across atoms.
            new_shape = [self._num_samples, batch_size, -1]  # [N, B, A]
            sampled_logits = tf.reshape(sampled_q_t_distributions.logits,
                                        new_shape)
            sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
            averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

            # Construct the expected distributional value for bootstrapping.
            q_t_distribution = networks.DiscreteValuedDistribution(
                values=sampled_q_t_distributions.values,
                logits=averaged_logits)

            # Compute online critic value distribution of a_tm1 in state o_tm1.
            q_tm1_distribution = self._critic_network(o_tm1, a_tm1)

            # Compute critic distributional loss.
            critic_loss = losses.categorical(q_tm1_distribution, r_t,
                                             discount * d_t, q_t_distribution)
            critic_loss = tf.reduce_mean(critic_loss)

            # Compute Q-values of sampled actions and reshape to [N, B].
            sampled_q_values = sampled_q_t_distributions.mean()
            sampled_q_values = tf.reshape(sampled_q_values,
                                          (self._num_samples, -1))

            # Compute MPO policy loss.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_values)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)  # Log MPO stats.

        return fetches
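
The least obvious step above is the averaging of the N sampled return distributions: per-sample log-probabilities are combined with reduce_logsumexp, which (up to an irrelevant constant shift of log N) is the log of the mean probability per atom. A small, self-contained illustration with made-up shapes:

import tensorflow as tf

N, B, A = 5, 2, 7  # action samples, batch, atoms
logits = tf.random.normal([N, B, A])  # per-sample distributional logits
log_probs = tf.math.log_softmax(logits, axis=-1)
mixture_logits = tf.reduce_logsumexp(log_probs, axis=0)  # [B, A]

# The result parameterizes the same mixture as averaging probabilities directly,
# because logits are invariant to a constant shift (here -log(N)).
probs_direct = tf.reduce_mean(tf.nn.softmax(logits, axis=-1), axis=0)
tf.debugging.assert_near(tf.nn.softmax(mixture_logits, axis=-1), probs_direct)
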
Example #3
    def _step(self) -> types.NestedTensor:
        # Update target network.
        online_policy_variables = self._policy_network.variables
        target_policy_variables = self._target_policy_network.variables
        online_critic_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
        )
        target_critic_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
        )

        # Make online policy -> target policy network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for src, dest in zip(online_policy_variables,
                                 target_policy_variables):
                dest.assign(src)
        # Make online critic -> target critic network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for src, dest in zip(online_critic_variables,
                                 target_critic_variables):
                dest.assign(src)

        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            # This stop_gradient prevents gradients from propagating into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(self._target_observation_network(o_t))

            # Get online and target action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Sample actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)

            # Tile embedded observations to feed into the target critic network.
            # Note: this is more efficient than tiling before the embedding layer.
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute target-estimated distributional value of sampled actions at o_t.
            sampled_q_t_all = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Compute online critic value distribution of a_tm1 in state o_tm1.
            q_tm1_all = self._critic_network(o_tm1, a_tm1)

            # Compute rewards for objectives with defined reward_fn
            reward_stats = {}
            r_t_all = []
            for objective in self._reward_objectives:
                r = objective.reward_fn(o_tm1, a_tm1, r_t)
                reward_stats['{}_reward'.format(
                    objective.name)] = tf.reduce_mean(r)
                r_t_all.append(r)
            r_t_all = tf.stack(r_t_all, axis=-1)
            r_t_all.get_shape().assert_has_rank(2)  # [B, C]

            if isinstance(sampled_q_t_all, list):  # Distributional critics
                critic_loss, sampled_q_t = _compute_distributional_critic_loss(
                    sampled_q_t_all, q_tm1_all, r_t_all, d_t, self._discount,
                    self._num_samples)
            else:
                critic_loss, sampled_q_t = _compute_critic_loss(
                    sampled_q_t_all, q_tm1_all, r_t_all, d_t, self._discount,
                    self._num_samples, self._num_critic_heads)

            # Add sampled Q-values for objectives with defined qvalue_fn
            sampled_q_t_k = [sampled_q_t]
            for objective in self._qvalue_objectives:
                sampled_q_t_k.append(
                    tf.expand_dims(tf.stop_gradient(
                        objective.qvalue_fn(sampled_actions, sampled_q_t)),
                                   axis=-1))
            sampled_q_t_k = tf.concat(sampled_q_t_k, axis=-1)  # [N, B, K]

            # Compute MPO policy loss.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_t_k)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)  # Log MPO stats.
        fetches.update(reward_stats)  # Log reward stats.

        return fetches
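
For context, a hedged sketch of what the reward-objective loop above assumes: each element of self._reward_objectives exposes a name and a reward_fn(o_tm1, a_tm1, r_t) returning a [B] tensor, and the results are stacked into a [B, C] matrix. The namedtuple and the two example objectives are hypothetical stand-ins, not the library's definitions.

import collections

import tensorflow as tf

RewardObjective = collections.namedtuple('RewardObjective', ['name', 'reward_fn'])

objectives = [
    RewardObjective('task', lambda o, a, r: r),  # pass through the env reward
    RewardObjective('action_penalty',
                    lambda o, a, r: -tf.reduce_sum(tf.square(a), axis=-1)),
]

o_tm1 = tf.random.normal([4, 3])  # [B, obs_dim]
a_tm1 = tf.random.normal([4, 2])  # [B, act_dim]
r_t = tf.random.normal([4])       # [B]

r_t_all = tf.stack([obj.reward_fn(o_tm1, a_tm1, r_t) for obj in objectives],
                   axis=-1)
print(r_t_all.shape)  # (4, 2) -> [B, C]
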
Example #4
    def _step(self) -> types.NestedTensor:
        # Update target network.
        online_policy_variables = self._policy_network.variables
        target_policy_variables = self._target_policy_network.variables
        online_critic_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
        )
        target_critic_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
        )

        # Make online policy -> target policy network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for src, dest in zip(online_policy_variables,
                                 target_policy_variables):
                dest.assign(src)
        # Make online critic -> target critic network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for src, dest in zip(online_critic_variables,
                                 target_critic_variables):
                dest.assign(src)

        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

        # Get batch size and scalar dtype.
        batch_size = r_t.shape[0]

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=d_t.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            # This stop_gradient prevents gradients from propagating into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(self._target_observation_network(o_t))

            # Get online and target action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Sample actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)

            # Tile embedded observations to feed into the target critic network.
            # Note: this is more efficient than tiling before the embedding layer.
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute target-estimated distributional value of sampled actions at o_t.
            sampled_q_t_all = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Compute online critic value distribution of a_tm1 in state o_tm1.
            q_tm1_all = self._critic_network(o_tm1, a_tm1)

            # Compute rewards for objectives with defined reward_fn
            reward_stats = {}
            r_t_all = []
            for objective in self._objectives:
                if hasattr(objective, 'reward_fn'):
                    r = objective.reward_fn(o_tm1, a_tm1, r_t)
                    reward_stats['{}_reward'.format(
                        objective.name)] = tf.reduce_mean(r)
                    r_t_all.append(r)
            r_t_all = tf.stack(r_t_all, axis=-1)
            r_t_all.get_shape().assert_has_rank(2)  # [B, C]

            if isinstance(sampled_q_t_all, list):  # Distributional critics
                # Compute average logits by first reshaping them and normalizing them
                # across atoms.
                critic_losses = []
                sampled_q_ts = []
                for idx, (sampled_q_t_distributions,
                          q_tm1_distribution) in enumerate(
                              zip(sampled_q_t_all, q_tm1_all)):
                    # Compute loss for distributional critic for objective c
                    sampled_logits = tf.reshape(
                        sampled_q_t_distributions.logits,
                        [self._num_samples, batch_size, -1])  # [N, B, A]
                    sampled_logprobs = tf.math.log_softmax(sampled_logits,
                                                           axis=-1)
                    averaged_logits = tf.reduce_logsumexp(sampled_logprobs,
                                                          axis=0)

                    # Construct the expected distributional value for bootstrapping.
                    q_t_distribution = networks.DiscreteValuedDistribution(
                        values=sampled_q_t_distributions.values,
                        logits=averaged_logits)

                    # Compute critic distributional loss.
                    critic_loss = losses.categorical(q_tm1_distribution,
                                                     r_t_all[:, idx],
                                                     discount * d_t,
                                                     q_t_distribution)
                    critic_losses.append(tf.reduce_mean(critic_loss))

                    # Compute Q-values of sampled actions and reshape to [N, B].
                    sampled_q_ts.append(
                        tf.reshape(sampled_q_t_distributions.mean(),
                                   (self._num_samples, -1)))

                critic_loss = tf.reduce_mean(critic_losses)
                sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
            else:
                # Reshape Q-value samples back to original batch dimensions and average
                # them to compute the TD-learning bootstrap target.
                sampled_q_t = tf.reshape(sampled_q_t_all,
                                         (self._num_samples, batch_size,
                                          self._num_critic_heads))  # [N,B,C]
                q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

                # Flatten q_t and q_tm1; necessary for trfl.td_learning
                q_t = tf.reshape(q_t, [-1])  # [B*C]
                q_tm1 = tf.reshape(q_tm1_all, [-1])  # [B*C]

                # Flatten r_t_all; necessary for trfl.td_learning
                r_t_all = tf.reshape(r_t_all, [-1])  # [B*C]

                # Broadcast and then flatten d_t, to match shape of q_t and q_tm1
                d_t = tf.tile(d_t, [self._num_critic_heads])  # [B*C]

                # Critic loss.
                critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t,
                                               q_t).loss
                critic_loss = tf.reduce_mean(critic_loss)

            # Add sampled Q-values for objectives with defined objective_fn
            sampled_q_idx = 0
            sampled_q_t_k = []
            for objective in self._objectives:
                if hasattr(objective, 'reward_fn'):
                    sampled_q_t_k.append(
                        tf.stop_gradient(sampled_q_t[..., sampled_q_idx]))
                    sampled_q_idx += 1
                if hasattr(objective, 'objective_fn'):
                    sampled_q_t_k.append(
                        tf.stop_gradient(
                            objective.objective_fn(sampled_actions,
                                                   sampled_q_t)))
            sampled_q_t_k = tf.stack(sampled_q_t_k, axis=-1)  # [N, B, K]

            # Compute MPO policy loss.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_t_k)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)  # Log MPO stats.
        fetches.update(reward_stats)  # Log reward stats.

        return fetches
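
The policy loss above receives an [N, B, K] tensor that mixes two kinds of columns: Q-values from learned critic heads (objectives with a reward_fn) and values computed directly from the sampled actions (objectives with an objective_fn). Here is a self-contained sketch of that assembly; the namedtuples and the action-norm objective are hypothetical stand-ins.

import collections

import tensorflow as tf

RewardObjective = collections.namedtuple('RewardObjective', ['name', 'reward_fn'])
QValueObjective = collections.namedtuple('QValueObjective', ['name', 'objective_fn'])

objectives = [
    RewardObjective('task', lambda o, a, r: r),
    QValueObjective('action_norm', lambda actions, q: -tf.norm(actions, axis=-1)),
]

N, B, C = 5, 4, 1  # action samples, batch, learned critic heads
sampled_actions = tf.random.normal([N, B, 2])
sampled_q_t = tf.random.normal([N, B, C])

sampled_q_idx = 0
columns = []
for objective in objectives:
    if hasattr(objective, 'reward_fn'):
        columns.append(tf.stop_gradient(sampled_q_t[..., sampled_q_idx]))
        sampled_q_idx += 1
    if hasattr(objective, 'objective_fn'):
        columns.append(
            tf.stop_gradient(objective.objective_fn(sampled_actions, sampled_q_t)))
sampled_q_t_k = tf.stack(columns, axis=-1)  # [N, B, K], here K == 2
print(sampled_q_t_k.shape)
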
Example #5
    def _step(self, inputs: reverb.ReplaySample) -> types.NestedTensor:

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        o_tm1, a_tm1, r_t, d_t, o_t = (inputs.data.observation,
                                       inputs.data.action, inputs.data.reward,
                                       inputs.data.discount,
                                       inputs.data.next_observation)

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=d_t.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            # This stop_gradient prevents gradients from propagating into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(self._target_observation_network(o_t))

            # Get online and target action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Sample actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)

            # Tile embedded observations to feed into the target critic network.
            # Note: this is more efficient than tiling before the embedding layer.
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute target-estimated distributional value of sampled actions at o_t.
            sampled_q_t_distributions = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Compute online critic value distribution of a_tm1 in state o_tm1.
            q_tm1_distribution = self._critic_network(o_tm1, a_tm1)  # [B, ...]

            # Get the return distributions used in the policy evaluation bootstrap.
            if self._policy_evaluation_config.evaluate_stochastic_policy:
                z_distributions = sampled_q_t_distributions
                num_joint_samples = self._num_samples
            else:
                z_distributions = self._target_critic_network(
                    o_t, target_action_distribution.mean())
                num_joint_samples = 1

            num_value_samples = self._policy_evaluation_config.num_value_samples
            num_joint_samples *= num_value_samples
            z_samples = z_distributions.sample(num_value_samples)
            z_samples = tf.reshape(z_samples, (num_joint_samples, -1, 1))

            # Expand dims of reward and discount tensors.
            reward = r_t[..., tf.newaxis]  # [B, 1]
            full_discount = discount * d_t[..., tf.newaxis]
            target_q = reward + full_discount * z_samples  # [N, B, 1]
            target_q = tf.stop_gradient(target_q)

            # Compute sample-based cross-entropy.
            log_probs_q = q_tm1_distribution.log_prob(target_q)  # [N, B, 1]
            critic_loss = -tf.reduce_mean(log_probs_q, axis=0)  # [B, 1]
            critic_loss = tf.reduce_mean(critic_loss)

            # Compute Q-values of sampled actions and reshape to [N, B].
            sampled_q_values = sampled_q_t_distributions.mean()
            sampled_q_values = tf.reshape(sampled_q_values,
                                          (self._num_samples, -1))

            # Compute MPO policy loss.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_values)
            policy_loss = tf.reduce_mean(policy_loss)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        # Log MPO stats.
        fetches.update(policy_stats)

        return fetches
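
The critic loss above is a sample-based cross-entropy: return samples z drawn from the bootstrap distribution form targets r + discount * d * z, and the online critic is trained to assign them high log-probability. A compact sketch using a tfp Normal as a stand-in return distribution; the distribution family, shapes, and numbers are assumptions for illustration only.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

N, B = 5, 4
r_t = tf.random.normal([B])
d_t = tf.ones([B])
discount = 0.99

# Stand-in return distributions for q(.|o_tm1, a_tm1) and the bootstrap z(.|o_t, a_t).
q_tm1_dist = tfd.Normal(loc=tf.random.normal([B, 1]), scale=tf.ones([B, 1]))
z_dist = tfd.Normal(loc=tf.random.normal([B, 1]), scale=tf.ones([B, 1]))

z_samples = z_dist.sample(N)  # [N, B, 1]
target_q = tf.stop_gradient(
    r_t[..., tf.newaxis] + discount * d_t[..., tf.newaxis] * z_samples)

log_probs_q = q_tm1_dist.log_prob(target_q)  # [N, B, 1]
critic_loss = tf.reduce_mean(-tf.reduce_mean(log_probs_q, axis=0))
print(critic_loss)
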
Example #6
    def _forward(self, inputs: Any) -> None:
        """Trainer forward pass

        Args:
            inputs (Any): input data from the data table (transitions)
        """

        # Convert to sequence data
        data = tf2_utils.batch_to_sequence(inputs.data)

        # Unpack input data as follows:
        observations, actions, rewards, discounts, extras = (
            data.observations,
            data.actions,
            data.rewards,
            data.discounts,
            data.extras,
        )

        # Transform observations using the observation networks.
        observations_trans = self._transform_observations(observations)

        # Get log_probs.
        log_probs = extras["log_probs"]

        # Store losses.
        policy_losses: Dict[str, Any] = {}
        critic_losses: Dict[str, Any] = {}

        with tf.GradientTape(persistent=True) as tape:
            for agent in self._agents:

                action, reward, discount, behaviour_log_prob = (
                    actions[agent],
                    rewards[agent],
                    discounts[agent],
                    log_probs[agent],
                )

                actor_observation = observations_trans[agent]
                critic_observation = self._get_critic_feed(
                    observations_trans, agent)

                # Chop off final timestep for bootstrapping value
                reward = reward[:-1]
                discount = discount[:-1]

                # Get agent network
                agent_key = agent.split(
                    "_")[0] if self._shared_weights else agent
                policy_network = self._policy_networks[agent_key]
                critic_network = self._critic_networks[agent_key]

                # Reshape inputs.
                dims = actor_observation.shape[:2]
                actor_observation = snt.merge_leading_dims(actor_observation,
                                                           num_dims=2)
                critic_observation = snt.merge_leading_dims(critic_observation,
                                                            num_dims=2)
                policy = policy_network(actor_observation)
                values = critic_network(critic_observation)

                # Reshape the outputs.
                policy = tfd.BatchReshape(policy,
                                          batch_shape=dims,
                                          name="policy")
                values = tf.reshape(values, dims, name="value")

                # Values along the sequence T.
                bootstrap_value = values[-1]
                state_values = values[:-1]

                # Generalized Return Estimation
                td_loss, td_lambda_extra = trfl.td_lambda(
                    state_values=state_values,
                    rewards=reward,
                    pcontinues=discount,
                    bootstrap_value=bootstrap_value,
                    lambda_=self._lambda_gae,
                    name="CriticLoss",
                )

                # Do not use the loss provided by td_lambda, as it sums the losses
                # over the sequence length rather than averaging them.
                critic_loss = self._baseline_cost * tf.reduce_mean(
                    tf.square(td_lambda_extra.temporal_differences),
                    name="CriticLoss")

                # Compute importance sampling weights: current policy / behavior policy.
                log_rhos = policy.log_prob(action) - behaviour_log_prob
                importance_ratio = tf.exp(log_rhos)[:-1]
                clipped_importance_ratio = tf.clip_by_value(
                    importance_ratio,
                    1.0 - self._clipping_epsilon,
                    1.0 + self._clipping_epsilon,
                )

                # Generalized Advantage Estimation
                gae = tf.stop_gradient(td_lambda_extra.temporal_differences)
                mean, variance = tf.nn.moments(gae, axes=[0, 1], keepdims=True)
                normalized_gae = (gae - mean) / tf.sqrt(variance)

                policy_gradient_loss = tf.reduce_mean(
                    -tf.minimum(
                        tf.multiply(importance_ratio, normalized_gae),
                        tf.multiply(clipped_importance_ratio, normalized_gae),
                    ),
                    name="PolicyGradientLoss",
                )

                # Entropy regularization. Only implemented for categorical dist.
                try:
                    policy_entropy = tf.reduce_mean(policy.entropy())
                except NotImplementedError:
                    policy_entropy = tf.convert_to_tensor(0.0)

                entropy_loss = -self._entropy_cost * policy_entropy

                # Combine the policy gradient loss and the entropy regularizer.
                policy_loss = policy_gradient_loss + entropy_loss

                policy_losses[agent] = policy_loss
                critic_losses[agent] = critic_loss

        self.policy_losses = policy_losses
        self.critic_losses = critic_losses
        self.tape = tape
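
The policy part of the loop above is a PPO-style clipped surrogate: importance ratios between the current and behaviour policies are clipped to [1 - eps, 1 + eps], and the pessimistic minimum of the clipped and unclipped advantage-weighted terms is used as the (negated) objective. A minimal sketch with dummy tensors; the shapes and the epsilon value are assumptions.

import tensorflow as tf

T, B = 10, 4  # sequence length, batch size
clipping_epsilon = 0.2
log_rhos = 0.1 * tf.random.normal([T, B])  # log pi(a|s) - log mu(a|s)
advantages = tf.random.normal([T, B])      # e.g. normalized GAE

importance_ratio = tf.exp(log_rhos)
clipped_ratio = tf.clip_by_value(importance_ratio,
                                 1.0 - clipping_epsilon,
                                 1.0 + clipping_epsilon)
policy_gradient_loss = tf.reduce_mean(
    -tf.minimum(importance_ratio * advantages, clipped_ratio * advantages))
print(policy_gradient_loss)
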
Example #7
    def _step(self) -> types.Nest:
        # Update target network.
        online_policy_variables = self._policy_network.variables
        target_policy_variables = self._target_policy_network.variables
        online_critic_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
        )
        target_critic_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
        )

        # Make online policy -> target policy network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_policy_update_period) == 0:
            for src, dest in zip(online_policy_variables,
                                 target_policy_variables):
                dest.assign(src)
        # Make online critic -> target critic network update ops.
        if tf.math.mod(self._num_steps,
                       self._target_critic_update_period) == 0:
            for src, dest in zip(online_critic_variables,
                                 target_critic_variables):
                dest.assign(src)

        # Increment number of learner steps for periodic update bookkeeping.
        self._num_steps.assign_add(1)

        # Get next batch of data.
        inputs = next(self._iterator)

        # Get data from replay (dropping extras if any). Note there is no
        # extra data here because we do not insert any into Reverb.
        transitions: types.Transition = inputs.data

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=transitions.discount.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(transitions.observation)
            # This stop_gradient prevents gradients from propagating into the target
            # observation network. In addition, since the online policy network is
            # evaluated at o_t, this also means the policy loss does not influence
            # the observation network training.
            o_t = tf.stop_gradient(
                self._target_observation_network(transitions.next_observation))

            # Get action distributions from policy networks.
            online_action_distribution = self._policy_network(o_t)
            target_action_distribution = self._target_policy_network(o_t)

            # Get sampled actions to evaluate policy; of size [N, B, ...].
            sampled_actions = target_action_distribution.sample(
                self._num_samples)
            tiled_o_t = tf2_utils.tile_tensor(o_t,
                                              self._num_samples)  # [N, B, ...]

            # Compute the target critic's Q-value of the sampled actions in state o_t.
            sampled_q_t = self._target_critic_network(
                # Merge batch dimensions; to shape [N*B, ...].
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2))

            # Reshape Q-value samples back to original batch dimensions and average
            # them to compute the TD-learning bootstrap target.
            sampled_q_t = tf.reshape(sampled_q_t,
                                     (self._num_samples, -1))  # [N, B]
            q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B]

            # Compute online critic value of a_tm1 in state o_tm1.
            q_tm1 = self._critic_network(o_tm1, transitions.action)  # [B, 1]
            q_tm1 = tf.squeeze(q_tm1,
                               axis=-1)  # [B]; necessary for trfl.td_learning.

            # Critic loss.
            critic_loss = trfl.td_learning(q_tm1, transitions.reward,
                                           discount * transitions.discount,
                                           q_t).loss
            critic_loss = tf.reduce_mean(critic_loss)

            # Actor learning.
            policy_loss, policy_stats = self._policy_loss_module(
                online_action_distribution=online_action_distribution,
                target_action_distribution=target_action_distribution,
                actions=sampled_actions,
                q_values=sampled_q_t)

        # For clarity, explicitly define which variables are trained by which loss.
        critic_trainable_variables = (
            # In this agent, the critic loss trains the observation network.
            self._observation_network.trainable_variables +
            self._critic_network.trainable_variables)
        policy_trainable_variables = self._policy_network.trainable_variables
        # The following are the MPO dual variables, stored in the loss module.
        dual_trainable_variables = self._policy_loss_module.trainable_variables

        # Compute gradients.
        critic_gradients = tape.gradient(critic_loss,
                                         critic_trainable_variables)
        policy_gradients, dual_gradients = tape.gradient(
            policy_loss,
            (policy_trainable_variables, dual_trainable_variables))

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tuple(
                tf.clip_by_global_norm(policy_gradients, 40.)[0])
            critic_gradients = tuple(
                tf.clip_by_global_norm(critic_gradients, 40.)[0])

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     critic_trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     policy_trainable_variables)
        self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

        # Losses to track.
        fetches = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }
        fetches.update(policy_stats)  # Log MPO stats.

        return fetches
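
The critic update above reduces to a standard TD(0) loss once the bootstrap value is formed by averaging the target critic's Q-values over the N sampled actions. A tiny, runnable sketch with made-up shapes and an assumed discount:

import tensorflow as tf
import trfl

N, B = 5, 4
sampled_q_t = tf.random.normal([N, B])  # target critic Q of N sampled actions
q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B] expected Q under the target policy
q_tm1 = tf.random.normal([B])  # online critic Q(o_tm1, a_tm1)
r_t = tf.random.normal([B])
d_t = tf.ones([B])

critic_loss = tf.reduce_mean(trfl.td_learning(q_tm1, r_t, 0.99 * d_t, q_t).loss)
print(critic_loss)
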
Example #8
def combine_dim(tensor: tf.Tensor) -> Tuple[tf.Tensor, tf.TensorShape]:
    # Merge the two leading dimensions and also return their original shape.
    dims = tensor.shape[:2]
    return snt.merge_leading_dims(tensor, num_dims=2), dims
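
A brief usage sketch: merge the leading [T, B] dimensions before a per-timestep network call, then use the returned dims to restore the original shape afterwards. The shapes below are illustrative.

import sonnet as snt  # required by combine_dim above
import tensorflow as tf

x = tf.random.normal([10, 4, 8])  # [T, B, feature]
merged, dims = combine_dim(x)     # merged: [40, 8], dims: (10, 4)
restored = tf.reshape(merged, list(dims) + merged.shape[1:].as_list())
print(merged.shape, restored.shape)  # (40, 8) (10, 4, 8)
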