Example 1
    def _call(
        self, observation_and_state: Tuple[types.NestedTensor,
                                           PolicyCriticRNNState]
    ) -> Tuple[types.NestedTensor, PolicyCriticRNNState]:
        """Computes a forward step for a single element.

    The observation and state are packed together in order to use
    `tf.vectorized_map` to handle batches of observations.
    See this module's __call__() function.

    Args:
      observation_and_state: the observation and state packed in a tuple.

    Returns:
      The selected action and the corresponding state.
    """
        observation, prev_state = observation_and_state

        # Tile input observations and states to allow multiple policy predictions.
        tiled_observation, tiled_prev_state = utils.tile_nested(
            (observation, prev_state), self._num_action_samples)
        actions, policy_states = self._policy_network(tiled_observation,
                                                      tiled_prev_state.policy)

        # Evaluate multiple critic predictions with the sampled actions.
        value_distribution, critic_states = self._critic_network(
            tiled_observation, actions, tiled_prev_state.critic)
        value_estimate = value_distribution.mean()

        # Resample a single action from the sampled actions, with probabilities
        # given by a softmax over the tempered Q-values.
        selected_action_idx = tfp.distributions.Categorical(
            probs=tf.nn.softmax(value_estimate /
                                self._temperature_beta)).sample()
        selected_action = actions[selected_action_idx]

        # Select and return the RNN state that corresponds to the selected action.
        states = PolicyCriticRNNState(policy=policy_states,
                                      critic=critic_states)
        selected_state = tree.map_structure(lambda x: x[selected_action_idx],
                                            states)

        return selected_action, selected_state
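
The docstring above refers to a batched __call__ that packs observations and states into a tuple and maps the per-element _call over the batch with tf.vectorized_map. Below is a minimal, self-contained sketch of that pattern; the shapes and the per-element computation are placeholders, not the actual policy/critic networks.

    import tensorflow as tf

    # Toy per-element function standing in for _call: it receives one
    # (observation, state) pair, not a batch.
    def _call(observation_and_state):
        observation, prev_state = observation_and_state
        return observation * 2.0, prev_state + 1.0

    batched_observations = tf.random.normal([8, 4])  # [B, obs_dim]
    batched_states = tf.zeros([8, 3])                # [B, state_dim]
    # tf.vectorized_map slices the packed tuple along the batch dimension and
    # stacks the per-element outputs back into batched tensors.
    actions, next_states = tf.vectorized_map(
        _call, (batched_observations, batched_states))
    print(actions.shape, next_states.shape)  # (8, 4) (8, 3)
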
Example 2
    def __call__(self, inputs: types.NestedTensor) -> tf.Tensor:
        # Inputs are of size [B, ...]. Here we tile them to be of shape [N, B, ...].
        tiled_inputs = tf2_utils.tile_nested(inputs, self._num_action_samples)
        shape = tf.shape(tree.flatten(tiled_inputs)[0])
        n, b = shape[0], shape[1]
        tf.debugging.assert_equal(
            n, self._num_action_samples,
            'Internal Error. Unexpected tiled_inputs shape.')
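        # Zeros of shape [N, B]; used below only to carry the leading dims for
        # snt.split_leading_dim.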
        dummy_zeros_n_b = tf.zeros((n, b))
        # Reshape to [N * B, ...].
        merge = lambda x: snt.merge_leading_dims(x, 2)
        tiled_inputs = tree.map_structure(merge, tiled_inputs)

        tiled_actions = self._actor_network(tiled_inputs)

        # Compute Q-values and the resulting tempered probabilities.
        q = self._critic_network(tiled_inputs, tiled_actions)
        boltzmann_logits = q / self._beta

        boltzmann_logits = snt.split_leading_dim(boltzmann_logits,
                                                 dummy_zeros_n_b, 2)
        # [B, N]
        boltzmann_logits = tf.transpose(boltzmann_logits, perm=(1, 0))
        # Resample one action per batch according to the Boltzmann distribution.
        action_idx = tfp.distributions.Categorical(
            logits=boltzmann_logits).sample()
        # [B, 2], where the first column is 0, 1, 2, ... giving indices into
        # the batch dimension.
        action_idx = tf.stack((tf.range(b), action_idx), axis=1)

        tiled_actions = snt.split_leading_dim(tiled_actions, dummy_zeros_n_b,
                                              2)
        action_dim = len(tiled_actions.get_shape().as_list())
        tiled_actions = tf.transpose(tiled_actions,
                                     perm=[1, 0] + list(range(2, action_dim)))
        # [B, ...]
        action_sample = tf.gather_nd(tiled_actions, action_idx)

        return action_sample
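
The shape bookkeeping above is the heart of the example: N candidate actions are produced per batch element, one index per row is sampled from a Categorical over the tempered Q-values, and tf.gather_nd picks out the chosen action. A hedged toy with illustrative shapes:

    import tensorflow as tf
    import tensorflow_probability as tfp

    N, B, A = 5, 3, 2
    tiled_actions = tf.random.normal([N, B, A])  # N candidate actions per batch element
    boltzmann_logits = tf.random.normal([B, N])  # tempered Q-values, already [B, N]
    # Sample one candidate index per batch row.
    action_idx = tfp.distributions.Categorical(logits=boltzmann_logits).sample()  # [B]
    # Build [B, 2] indices of (batch_index, sample_index) for gather_nd.
    gather_idx = tf.stack((tf.range(B), action_idx), axis=1)
    actions_b_n = tf.transpose(tiled_actions, perm=[1, 0, 2])  # [B, N, A]
    action_sample = tf.gather_nd(actions_b_n, gather_idx)      # [B, A]
    print(action_sample.shape)  # (3, 2)
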
Example 3
    def _step(self) -> Dict[str, tf.Tensor]:
        # Update target network
        online_variables = [
            *self._critic_network.variables,
            *self._policy_network.variables,
        ]
        if self._prior_network is not None:
            online_variables += [*self._prior_network.variables]
        online_variables = tuple(online_variables)

        target_variables = [
            *self._target_critic_network.variables,
            *self._target_policy_network.variables,
        ]
        if self._prior_network is not None:
            target_variables += [*self._target_prior_network.variables]
        target_variables = tuple(target_variables)

        # Make online -> target network update ops.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(online_variables, target_variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        # Get data from replay (dropping extras if any) and flip to `[T, B, ...]`.
        sample: reverb.ReplaySample = next(self._iterator)
        data = tf2_utils.batch_to_sequence(sample.data)
        observations, actions, rewards, discounts, extra = (data.observation,
                                                            data.action,
                                                            data.reward,
                                                            data.discount,
                                                            data.extras)
        online_target_pi_q = svg0_utils.OnlineTargetPiQ(
            online_pi=self._policy_network,
            online_q=self._critic_network,
            target_pi=self._target_policy_network,
            target_q=self._target_critic_network,
            num_samples=self._num_action_samples,
            online_prior=self._prior_network,
            target_prior=self._target_prior_network,
        )
        with tf.GradientTape(persistent=True) as tape:
            step_outputs = svg0_utils.static_rnn(
                core=online_target_pi_q,
                inputs=(observations, actions),
                unroll_length=rewards.shape[0])

            # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the
            # number of action samples taken.
            target_pi_samples = tf2_utils.batch_to_sequence(
                step_outputs.target_samples)
            # Tile observations to have shape [S, T+1, B, ...].
            tiled_observations = tf2_utils.tile_nested(
                observations, self._num_action_samples)

            # Finally compute target Q values on the new action samples.
            # Shape: [S, T+1, B, 1]
            target_q_target_pi_samples = snt.BatchApply(
                self._target_critic_network, 3)(tiled_observations,
                                                target_pi_samples)
            # Compute the value estimate by averaging over the action dimension.
            # Shape: [T+1, B, 1].
            target_v_target_pi = tf.reduce_mean(target_q_target_pi_samples,
                                                axis=0)

            # Split the target V's into the target for learning
            # `value_function_target` and the bootstrap value. Shape: [T, B].
            value_function_target = tf.squeeze(target_v_target_pi[:-1],
                                               axis=-1)
            # Shape: [B].
            bootstrap_value = tf.squeeze(target_v_target_pi[-1], axis=-1)

            # When learning with a prior, add entropy terms to value targets.
            if self._prior_network is not None:
                value_function_target -= self._distillation_cost * tf.stop_gradient(
                    step_outputs.analytic_kl_to_target[:-1])
                bootstrap_value -= self._distillation_cost * tf.stop_gradient(
                    step_outputs.analytic_kl_to_target[-1])

            # Get target log probs and behavior log probs from rollout.
            # Shape: [T+1, B].
            target_log_probs_behavior_actions = (
                step_outputs.target_log_probs_behavior_actions)
            behavior_log_probs = extra['log_prob']
            # Calculate importance weights. Shape: [T+1, B].
            rhos = tf.exp(target_log_probs_behavior_actions -
                          behavior_log_probs)

            # Filter the importance weights to mask out episode restarts. Ignore the
            # last action and consider the step type of the next step for masking.
            # Shape: [T, B].
            episode_start_mask = tf2_utils.batch_to_sequence(
                sample.data.start_of_episode)[1:]

            rhos = svg0_utils.mask_out_restarting(rhos[:-1],
                                                  episode_start_mask)

            # Compute the log importance weights with a small value added for
            # stability.
            # Shape: [T, B]
            log_rhos = tf.math.log(rhos + _MIN_LOG_VAL)

            # Retrieve the target and online Q values and throw away the last action.
            # Shape: [T, B].
            target_q_values = tf.squeeze(step_outputs.target_q[:-1], -1)
            online_q_values = tf.squeeze(step_outputs.online_q[:-1], -1)

            # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the
            # number of action samples taken.
            online_pi_samples = tf2_utils.batch_to_sequence(
                step_outputs.online_samples)
            target_q_online_pi_samples = snt.BatchApply(
                self._target_critic_network, 3)(tiled_observations,
                                                online_pi_samples)
            expected_q = tf.reduce_mean(tf.squeeze(target_q_online_pi_samples,
                                                   -1),
                                        axis=0)

            # Flip online_log_probs to be of shape [S, T+1, B] and then compute
            # entropy by averaging over num samples. Final shape: [T+1, B].
            online_log_probs = tf2_utils.batch_to_sequence(
                step_outputs.online_log_probs)
            sample_based_entropy = tf.reduce_mean(-online_log_probs, axis=0)
            retrace_outputs = continuous_retrace_ops.retrace_from_importance_weights(
                log_rhos=log_rhos,
                discounts=self._discount * discounts[:-1],
                rewards=rewards[:-1],
                q_values=target_q_values,
                values=value_function_target,
                bootstrap_value=bootstrap_value,
                lambda_=self._lambda,
            )

            # Critic loss. Shape: [T, B].
            critic_loss = 0.5 * tf.math.squared_difference(
                tf.stop_gradient(retrace_outputs.qs), online_q_values)

            # Policy loss (SVG0 with sample-based entropy). Shape: [T, B].
            policy_loss = -(expected_q + self._entropy_regularizer_cost *
                            sample_based_entropy)
            policy_loss = policy_loss[:-1]

            if self._prior_network is not None:
                # When training the prior, also add the per-timestep KL cost.
                policy_loss += (self._distillation_cost *
                                step_outputs.analytic_kl_to_target[:-1])

            # Ensure episode restarts are masked out when computing the losses.
            critic_loss = svg0_utils.mask_out_restarting(
                critic_loss, episode_start_mask)
            critic_loss = tf.reduce_mean(critic_loss)

            policy_loss = svg0_utils.mask_out_restarting(
                policy_loss, episode_start_mask)
            policy_loss = tf.reduce_mean(policy_loss)

            if self._prior_network is not None:
                prior_loss = step_outputs.analytic_kl_divergence[:-1]
                prior_loss = svg0_utils.mask_out_restarting(
                    prior_loss, episode_start_mask)
                prior_loss = tf.reduce_mean(prior_loss)

        # Get trainable variables.
        policy_variables = self._policy_network.trainable_variables
        critic_variables = self._critic_network.trainable_variables

        # Compute gradients.
        policy_gradients = tape.gradient(policy_loss, policy_variables)
        critic_gradients = tape.gradient(critic_loss, critic_variables)
        if self._prior_network is not None:
            prior_variables = self._prior_network.trainable_variables
            prior_gradients = tape.gradient(prior_loss, prior_variables)

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Apply gradients.
        self._policy_optimizer.apply(policy_gradients, policy_variables)
        self._critic_optimizer.apply(critic_gradients, critic_variables)
        losses = {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
        }

        if self._prior_network is not None:
            self._prior_optimizer.apply(prior_gradients, prior_variables)
            losses['prior_loss'] = prior_loss

        # Losses to track.
        return losses
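
A step worth isolating from the loss computation above is how the retrace importance weights are formed: rho = exp(log pi_target(a|s) - log mu(a|s)), followed by log(rho + eps) for numerical stability. A minimal illustration; the epsilon and all values here are assumptions, and the real _MIN_LOG_VAL is defined elsewhere in the module:

    import tensorflow as tf

    _MIN_LOG_VAL = 1e-20  # assumed value
    target_log_probs = tf.constant([[-1.2, -0.3], [-0.7, -2.0]])    # [T, B]
    behavior_log_probs = tf.constant([[-1.0, -0.5], [-0.7, -1.5]])  # [T, B]
    # Importance weights and their stabilized logarithm.
    rhos = tf.exp(target_log_probs - behavior_log_probs)
    log_rhos = tf.math.log(rhos + _MIN_LOG_VAL)
    print(rhos.numpy())
    print(log_rhos.numpy())
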
Example 4
    def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]:
        # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...].
        sample = tf2_utils.batch_to_sequence(sample)
        observations = sample.observation
        actions = sample.action
        rewards = sample.reward
        discounts = sample.discount

        dtype = rewards.dtype

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=discounts.dtype)

        # Loss cumulants across time. These cannot be python mutable objects.
        critic_loss = 0.
        policy_loss = 0.

        # Each transition induces a policy loss, which we then weight using
        # the `policy_loss_coef_t`; shape [B], see https://arxiv.org/abs/2006.15134.
        # `policy_loss_coef` is a scalar average of these coefficients across
        # the batch and sequence length dimensions.
        policy_loss_coef = 0.

        per_device_batch_size = actions.shape[1]

        # Initialize recurrent states.
        critic_state = self._critic_network.initial_state(
            per_device_batch_size)
        target_critic_state = critic_state
        policy_state = self._policy_network.initial_state(
            per_device_batch_size)
        target_policy_state = policy_state

        with tf.GradientTape(persistent=True) as tape:
            for t in range(1, self._sequence_length):
                o_tm1 = tree.map_structure(operator.itemgetter(t - 1),
                                           observations)
                a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions)
                r_t = tree.map_structure(operator.itemgetter(t - 1), rewards)
                d_t = tree.map_structure(operator.itemgetter(t - 1), discounts)
                o_t = tree.map_structure(operator.itemgetter(t), observations)

                if t != 1:
                    # By only updating the target critic state here we are forcing
                    # the target critic to ignore observations[0]. Otherwise, the
                    # target_critic will be unrolled for one more timestep than critic.
                    # The smaller the sequence length, the more problematic this is: if
                    # you use RNN on sequences of length 2, you would expect the code to
                    # never use recurrent connections. But if you don't skip updating the
                    # target_critic_state on observation[0] here, it won't be the case.
                    _, target_critic_state = self._target_critic_network(
                        o_tm1, a_tm1, target_critic_state)

                # ========================= Critic learning ============================
                q_tm1, next_critic_state = self._critic_network(
                    o_tm1, a_tm1, critic_state)
                target_action_distribution, target_policy_state = self._target_policy_network(
                    o_t, target_policy_state)

                sampled_actions_t = target_action_distribution.sample(
                    self._num_action_samples_td_learning)
                # [N, B, ...]
                tiled_o_t = tf2_utils.tile_nested(
                    o_t, self._num_action_samples_td_learning)
                tiled_target_critic_state = tf2_utils.tile_nested(
                    target_critic_state, self._num_action_samples_td_learning)

                # Compute the target critic's Q-value of the sampled actions.
                sampled_q_t, _ = snt.BatchApply(self._target_critic_network)(
                    tiled_o_t, sampled_actions_t, tiled_target_critic_state)

                # Compute average logits by first reshaping them to [N, B, A] and then
                # normalizing them across atoms.
                new_shape = [
                    self._num_action_samples_td_learning, r_t.shape[0], -1
                ]
                sampled_logits = tf.reshape(sampled_q_t.logits, new_shape)
                sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
                averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

                # Construct the expected distributional value for bootstrapping.
                q_t = networks.DiscreteValuedDistribution(
                    values=sampled_q_t.values, logits=averaged_logits)
                critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t,
                                                   q_t)
                critic_loss_t = tf.reduce_mean(critic_loss_t)

                # ========================= Actor learning =============================
                action_distribution_tm1, policy_state = self._policy_network(
                    o_tm1, policy_state)
                q_tm1_mean = q_tm1.mean()

                # Compute the estimate of the value function based on
                # self._num_action_samples_policy_weight samples from the policy.
                tiled_o_tm1 = tf2_utils.tile_nested(
                    o_tm1, self._num_action_samples_policy_weight)
                tiled_critic_state = tf2_utils.tile_nested(
                    critic_state, self._num_action_samples_policy_weight)
                action_tm1 = action_distribution_tm1.sample(
                    self._num_action_samples_policy_weight)
                tiled_z_tm1, _ = snt.BatchApply(self._critic_network)(
                    tiled_o_tm1, action_tm1, tiled_critic_state)
                tiled_v_tm1 = tf.reshape(
                    tiled_z_tm1.mean(),
                    [self._num_action_samples_policy_weight, -1])

                # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the
                # final estimate of the value function.
                if self._baseline_reduce_function == 'mean':
                    v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0)
                elif self._baseline_reduce_function == 'max':
                    v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0)
                elif self._baseline_reduce_function == 'min':
                    v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0)

                # Assert that action_distribution_tm1 is a batch of multivariate
                # distributions (in contrast to e.g. a [batch, action_size] collection
                # of 1d distributions).
                assert len(action_distribution_tm1.batch_shape) == 1
                policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1)

                advantage = q_tm1_mean - v_tm1_estimate
                if self._policy_improvement_modes == 'exp':
                    policy_loss_coef_t = tf.math.minimum(
                        tf.math.exp(advantage / self._beta),
                        self._ratio_upper_bound)
                elif self._policy_improvement_modes == 'binary':
                    policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype)
                elif self._policy_improvement_modes == 'all':
                    # Regress against all actions (effectively pure BC).
                    policy_loss_coef_t = 1.
                policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t)

                policy_loss_batch *= policy_loss_coef_t
                policy_loss_t = tf.reduce_mean(policy_loss_batch)

                critic_state = next_critic_state

                critic_loss += critic_loss_t
                policy_loss += policy_loss_t
                policy_loss_coef += tf.reduce_mean(
                    policy_loss_coef_t)  # For logging.

            # Divide by sequence length to get mean losses.
            critic_loss /= tf.cast(self._sequence_length, dtype=dtype)
            policy_loss /= tf.cast(self._sequence_length, dtype=dtype)
            policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype)

        # Compute gradients.
        critic_gradients = tape.gradient(
            critic_loss, self._critic_network.trainable_variables)
        policy_gradients = tape.gradient(
            policy_loss, self._policy_network.trainable_variables)

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Sync gradients across GPUs or TPUs.
        ctx = tf.distribute.get_replica_context()
        critic_gradients = ctx.all_reduce('mean', critic_gradients)
        policy_gradients = ctx.all_reduce('mean', policy_gradients)

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
            critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

        # Apply gradients.
        self._critic_optimizer.apply(critic_gradients,
                                     self._critic_network.trainable_variables)
        self._policy_optimizer.apply(policy_gradients,
                                     self._policy_network.trainable_variables)

        source_variables = (self._critic_network.variables +
                            self._policy_network.variables)
        target_variables = (self._target_critic_network.variables +
                            self._target_policy_network.variables)

        # Make online -> target network update ops.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(source_variables, target_variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        return {
            'critic_loss': critic_loss,
            'policy_loss': policy_loss,
            'policy_loss_coef': policy_loss_coef,
        }
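
The 'exp' policy improvement mode above is the CRR-style weighting from https://arxiv.org/abs/2006.15134: each transition's behaviour-cloning loss is scaled by exp(advantage / beta), clipped at an upper bound and excluded from the gradient. A hedged toy with illustrative numbers:

    import tensorflow as tf

    advantage = tf.constant([-1.0, 0.5, 3.5])  # q_tm1_mean - v_tm1_estimate, shape [B]
    beta = 1.0
    ratio_upper_bound = 20.0
    # Tempered, clipped exponential advantage weight, treated as a constant
    # during backpropagation.
    policy_loss_coef_t = tf.minimum(tf.exp(advantage / beta), ratio_upper_bound)
    policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t)
    print(policy_loss_coef_t.numpy())  # approx. [0.37, 1.65, 20.0]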