Example 1
    def _step(self) -> Dict[str, tf.Tensor]:
        """Do a step of SGD and update the priorities."""

        # Pull out the data needed for updates/priorities.
        inputs = next(self._iterator)
        transitions: types.Transition = inputs.data
        keys, probs = inputs.info[:2]

        with tf.GradientTape() as tape:
            # Evaluate our networks.
            q_tm1 = self._network(transitions.observation)
            q_t_value = self._target_network(transitions.next_observation)
            q_t_selector = self._network(transitions.next_observation)

            # The rewards and discounts have to have the same type as network values.
            r_t = tf.cast(transitions.reward, q_tm1.dtype)
            if self._max_abs_reward:
                r_t = tf.clip_by_value(r_t, -self._max_abs_reward,
                                       self._max_abs_reward)
            d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast(
                self._discount, q_tm1.dtype)

            # Compute the loss.
            _, extra = trfl.double_qlearning(q_tm1, transitions.action, r_t,
                                             d_t, q_t_value, q_t_selector)
            loss = losses.huber(extra.td_error, self._huber_loss_parameter)

            # Get the importance weights.
            importance_weights = 1. / probs  # [B]
            importance_weights **= self._importance_sampling_exponent
            importance_weights /= tf.reduce_max(importance_weights)

            # Reweight.
            loss *= tf.cast(importance_weights, loss.dtype)  # [B]
            loss = tf.reduce_mean(loss, axis=[0])  # []

        # Do a step of SGD.
        gradients = tape.gradient(loss, self._network.trainable_variables)
        gradients, _ = tf.clip_by_global_norm(gradients,
                                              self._max_gradient_norm)
        self._optimizer.apply(gradients, self._network.trainable_variables)

        # Get the priorities that we'll use to update.
        priorities = tf.abs(extra.td_error)

        # Periodically update the target network.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(self._network.variables,
                                 self._target_network.variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        # Report loss & statistics for logging.
        fetches = {
            'loss': loss,
            'keys': keys,
            'priorities': priorities,
        }

        return fetches
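
Example 1 only returns the sampled keys and the new priorities; writing them back to the replay buffer is left to the caller. A minimal sketch of such a caller, assuming the learner also holds a Reverb client (self._replay_client) and a logger (self._logger); both names are assumptions rather than part of the example above:

    def step(self):
        # Run one SGD step (typically wrapped in a tf.function elsewhere).
        fetches = self._step()

        # Push the updated priorities back to the prioritized replay table.
        if self._replay_client is not None:
            keys = fetches.pop('keys').numpy()
            priorities = fetches.pop('priorities').numpy()
            self._replay_client.mutate_priorities(
                table=adders.DEFAULT_PRIORITY_TABLE,
                updates=dict(zip(keys, priorities)))

        # Log the remaining statistics (e.g. the loss).
        self._logger.write(fetches)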
Example 2
    def _step(self) -> Dict[str, tf.Tensor]:
        """Do a step of SGD and update the priorities."""

        # Pull out the data needed for updates/priorities.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data
        keys, probs = inputs.info[:2]

        with tf.GradientTape() as tape:
            # Evaluate our networks.
            q_tm1 = self._network(o_tm1)
            q_t_value = self._target_network(o_t)
            q_t_selector = self._network(o_t)

            # The rewards and discounts have to have the same type as network values.
            r_t = tf.cast(r_t, q_tm1.dtype)
            r_t = tf.clip_by_value(r_t, -1., 1.)
            d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount,
                                                      q_tm1.dtype)

            # Compute the loss.
            _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                             q_t_selector)
            loss = losses.huber(extra.td_error, self._huber_loss_parameter)

            # Get the importance weights.
            importance_weights = 1. / probs  # [B]
            importance_weights **= self._importance_sampling_exponent
            importance_weights /= tf.reduce_max(importance_weights)

            # Reweight.
            loss *= tf.cast(importance_weights, loss.dtype)  # [B]
            loss = tf.reduce_mean(loss, axis=[0])  # []

        # Do a step of SGD.
        gradients = tape.gradient(loss, self._network.trainable_variables)
        self._optimizer.apply(gradients, self._network.trainable_variables)

        # Update the priorities in the replay buffer.
        if self._replay_client:
            priorities = tf.cast(tf.abs(extra.td_error), tf.float64)
            self._replay_client.update_priorities(
                table=adders.DEFAULT_PRIORITY_TABLE,
                keys=keys,
                priorities=priorities)

        # Periodically update the target network.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(self._network.variables,
                                 self._target_network.variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        # Report loss & statistics for logging.
        fetches = {
            'loss': loss,
        }

        return fetches
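
In both examples above, keys and probs come from inputs.info, i.e. from sampling a prioritized Reverb table. A hedged sketch of how such a table might be configured; the capacity and priority exponent below are illustrative, not taken from the examples:

    import reverb
    from acme.adders import reverb as adders

    replay_table = reverb.Table(
        name=adders.DEFAULT_PRIORITY_TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent=0.6),
        remover=reverb.selectors.Fifo(),
        max_size=1_000_000,
        rate_limiter=reverb.rate_limiters.MinSize(1))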
Example 3
    def _step(self):
        # Update target network.
        online_variables = (
            *self._observation_network.variables,
            *self._critic_network.variables,
            *self._policy_network.variables,
        )
        target_variables = (
            *self._target_observation_network.variables,
            *self._target_critic_network.variables,
            *self._target_policy_network.variables,
        )
        # Make online -> target network update ops.
        if self._target_update_period > 0 and \
           tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(online_variables, target_variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        # Get data from replay, together with the extra behaviour-policy
        # information (log-probabilities and policy parameters) that was
        # inserted into Reverb alongside each transition.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t, extra = inputs.data
        behavior_logP_tm1 = extra['logP']
        behavior_tm1 = extra['policy']

        # Cast the additional discount to match the environment discount dtype.
        discount = tf.cast(self._discount, dtype=d_t.dtype)

        with tf.GradientTape(persistent=True) as tape:
            # Maybe transform the observation before feeding into policy and critic.
            # Transforming the observations this way at the start of the learning
            # step effectively means that the policy and critic share observation
            # network weights.
            o_tm1 = self._observation_network(o_tm1)
            o_t = self._target_observation_network(o_t)
            o_t = tree.map_structure(tf.stop_gradient, o_t)

            # Policy and state-value heads.
            pol_tm1, v_tm1 = self._policy_network(o_tm1)
            pol_t, v_t = self._target_policy_network(o_t)
            pol_t = tree.map_structure(tf.stop_gradient, pol_t)
            v_t = tree.map_structure(tf.stop_gradient, v_t)

            # Actor (DPG) loss. If clipping is enabled, use dqda clipping and clip the norm.
            # TODO: use two critic networks, e.g. q1_tm1 and q2_tm1, and take their min as the target.
            dqda_clipping = 1.0 if self._clipping else None
            onpol_a_tm1, onpol_logP_tm1 = self._sampling_head(pol_tm1)
            onpol_q_tm1 = self._critic_network(o_tm1, onpol_a_tm1)
            onpol_q_tm1 = tf.squeeze(onpol_q_tm1, axis=-1)  # [B]

            logP_tm1 = self._sampling_head.log_prob(a_tm1, pol_tm1)
            ReFER_params_loss = self._ReFER.loss(behavior_logP_tm1, logP_tm1)

            dpg_loss = losses.dpg(onpol_q_tm1,
                                  onpol_a_tm1,
                                  tape=tape,
                                  dqda_clipping=dqda_clipping,
                                  clip_norm=self._clipping)
            dpg_loss = tf.reduce_mean(dpg_loss, axis=0)
            entropy_loss = self._entropy_coeff * tf.reduce_mean(onpol_logP_tm1,
                                                                axis=0)

            KL_coef = self._ReFER.DKL_coef()
            #behavior_P_tm1 = tf.math.exp(behavior_logP_tm1)
            #KL_loss = KL_coef * behavior_P_tm1 * (behavior_logP_tm1 - logP_tm1)
            KL_loss = tf.reduce_sum((behavior_tm1 - pol_tm1)**2, axis=-1)
            KL_loss = KL_coef * tf.reduce_mean(KL_loss, axis=0)

            # V(s) loss
            value_target = tf.stop_gradient(onpol_q_tm1 - self._entropy_coeff *
                                            onpol_logP_tm1)

            value_loss = losses.huber(value_target - v_tm1, 1.0)
            #value_loss = 0.5 * (value_target - v_tm1) ** 2
            value_loss = tf.reduce_mean(value_loss, axis=0)

            # Critic learning with TD loss
            q_tm1 = self._critic_network(o_tm1, a_tm1)
            q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]

            onpol_a_t, logP_t = self._sampling_head(pol_t)
            onpol_q_t = self._target_critic_network(o_t, onpol_a_t)
            onpol_q_t = tf.squeeze(onpol_q_t, axis=-1)  # [B]
            onpol_q_t = tree.map_structure(tf.stop_gradient, onpol_q_t)

            R_t = self._observation_network.scale_rewards(r_t)
            critic_target = tf.stop_gradient(R_t +
                                             d_t * tf.minimum(v_t, onpol_q_t))
            #critic_target = tf.stop_gradient(R_t + d_t * 0.5*(v_t + onpol_q_t))

            critic_loss = losses.huber(critic_target - q_tm1, 1.0)
            #critic_loss = 0.5 * (critic_target - q_tm1) ** 2
            critic_loss = tf.reduce_mean(critic_loss, axis=0)

            encoder_loss = self._observation_network.compute_loss(o_tm1, r_t)

            policy_loss = value_loss + entropy_loss + dpg_loss + encoder_loss + KL_loss

        # Compute gradients.
        policy_gradients = tape.gradient(policy_loss, self._policy_variables)
        critic_gradients = tape.gradient(critic_loss, self._critic_variables)
        ReFER_gradient = tape.gradient(ReFER_params_loss,
                                       self._ReFER.trainable_variables)

        # Delete the tape manually because of the persistent=True flag.
        del tape

        # Maybe clip gradients.
        if self._clipping:
            policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
            critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

        # Apply gradients.
        self._policy_optimizer.apply(policy_gradients, self._policy_variables)
        self._critic_optimizer.apply(critic_gradients, self._critic_variables)
        self._ReFER_optimizer.apply(ReFER_gradient,
                                    self._ReFER.trainable_variables)

        # Losses to track.
        return {
            'critic_loss': critic_loss,
            'svalue_loss': value_loss,
            'entropy_loss': entropy_loss,
            'dpg_loss': dpg_loss,
            'avg_q': tf.reduce_mean(onpol_q_t, axis=0),
            'KL_loss': KL_loss,
            #'frac_off_pol': self._ReFER._last_frac_off_pol,
            'beta': self._ReFER._beta,
            'r_mean': self._observation_network._ret_mean,
            'r_scale': self._observation_network._ret_scale,
        }
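
Example 3 applies its gradients to self._policy_variables and self._critic_variables, which are not shown. One plausible construction, assumed rather than taken from the source, and consistent with encoder_loss being folded into policy_loss above (so the shared observation network is trained through the policy optimizer):

        # Assumed to be set up once in __init__.
        self._policy_variables = (
            *self._observation_network.trainable_variables,
            *self._policy_network.trainable_variables,
        )
        self._critic_variables = self._critic_network.trainable_variables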
Example 4
    def _step(self) -> Dict[str, tf.Tensor]:
        """Do a step of SGD and update the priorities."""

        # Pull out the data needed for updates/priorities.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data
        keys, probs = inputs.info[:2]

        with tf.GradientTape() as tape:
            # Evaluate our networks.
            q_tm1 = self._network(o_tm1)
            q_t_value = self._target_network(o_t)
            q_t_selector = self._network(o_t)

            # The rewards and discounts have to have the same type as network values.
            r_t = tf.cast(r_t, q_tm1.dtype)
            r_t = tf.clip_by_value(r_t, -1., 1.)
            d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount,
                                                      q_tm1.dtype)

            # Compute the loss.
            _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                             q_t_selector)
            loss = losses.huber(extra.td_error, self._huber_loss_parameter)

            if self._alpha:
                policy_probs = self._emp_policy.lookup([str(o) for o in o_tm1])

                push_down = tf.reduce_logsumexp(
                    q_tm1 * self._tr,
                    axis=1) / self._tr  # soft-maximum of the q func
                push_up = tf.reduce_sum(
                    policy_probs * q_tm1,
                    axis=1)  # expected q value under behavioural policy

                cql_loss = loss + self._alpha * (push_down - push_up)
            else:
                cql_loss = loss

            cql_loss = tf.reduce_mean(cql_loss, axis=0)

        # Do a step of SGD.
        gradients = tape.gradient(cql_loss, self._network.trainable_variables)
        self._optimizer.apply(gradients, self._network.trainable_variables)

        # Update the priorities in the replay buffer.
        if self._replay_client:
            priorities = tf.cast(tf.abs(extra.td_error), tf.float64)
            self._replay_client.update_priorities(
                table=adders.DEFAULT_PRIORITY_TABLE,
                keys=keys,
                priorities=priorities)

        # Periodically update the target network.
        if tf.math.mod(self._counter.get_counts()['learner_steps'],
                       self._target_update_period) == 0:
            for src, dest in zip(self._network.variables,
                                 self._target_network.variables):
                dest.assign(src)

        # Report loss & statistics for logging.
        fetches = {
            'critic_loss': tf.reduce_mean(loss, axis=0),
            'q_variance': tf.reduce_mean(
                tf.math.reduce_variance(q_tm1, axis=1), axis=0),
            'q_average': tf.reduce_mean(q_tm1),
        }
        if self._alpha:
            fetches.update({
                'push_up': tf.reduce_mean(push_up, axis=0),
                'push_down': tf.reduce_mean(push_down, axis=0),
                'regularizer': tf.reduce_mean(push_down - push_up, axis=0),
                'cql_loss': cql_loss,
            })
        return fetches
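
The push_down/push_up terms in Example 4 form the conservative (CQL) regularizer: a temperature-scaled soft maximum of the Q-values minus the expected Q-value under the behaviour policy. A small self-contained sketch with illustrative numbers shows the computation and why the regularizer is non-negative:

    import tensorflow as tf

    q = tf.constant([[1.0, 2.0, 3.0],
                     [0.0, 0.0, 4.0]])                   # [B, A] Q-values
    policy_probs = tf.constant([[0.2, 0.3, 0.5],
                                [0.4, 0.4, 0.2]])        # [B, A] behaviour probabilities
    tr = 1.0                                             # temperature

    push_down = tf.reduce_logsumexp(q * tr, axis=1) / tr   # soft maximum of Q(s, .)
    push_up = tf.reduce_sum(policy_probs * q, axis=1)      # expected Q under the behaviour policy
    regularizer = push_down - push_up   # non-negative: logsumexp >= max >= expectation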