def _step(self) -> Dict[str, tf.Tensor]:
  """Do a step of SGD and update the priorities."""

  # Pull out the data needed for updates/priorities.
  inputs = next(self._iterator)
  transitions: types.Transition = inputs.data
  keys, probs = inputs.info[:2]

  with tf.GradientTape() as tape:
    # Evaluate our networks.
    q_tm1 = self._network(transitions.observation)
    q_t_value = self._target_network(transitions.next_observation)
    q_t_selector = self._network(transitions.next_observation)

    # The rewards and discounts have to have the same type as network values.
    r_t = tf.cast(transitions.reward, q_tm1.dtype)
    if self._max_abs_reward:
      r_t = tf.clip_by_value(r_t, -self._max_abs_reward, self._max_abs_reward)
    d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast(
        self._discount, q_tm1.dtype)

    # Compute the loss.
    _, extra = trfl.double_qlearning(q_tm1, transitions.action, r_t, d_t,
                                     q_t_value, q_t_selector)
    loss = losses.huber(extra.td_error, self._huber_loss_parameter)

    # Get the importance weights.
    importance_weights = 1. / probs  # [B]
    importance_weights **= self._importance_sampling_exponent
    importance_weights /= tf.reduce_max(importance_weights)

    # Reweight.
    loss *= tf.cast(importance_weights, loss.dtype)  # [B]
    loss = tf.reduce_mean(loss, axis=[0])  # []

  # Do a step of SGD.
  gradients = tape.gradient(loss, self._network.trainable_variables)
  gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  # Get the priorities that we'll use to update.
  priorities = tf.abs(extra.td_error)

  # Periodically update the target network.
  if tf.math.mod(self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(self._network.variables,
                         self._target_network.variables):
      dest.assign(src)
  self._num_steps.assign_add(1)

  # Report loss & statistics for logging.
  fetches = {
      'loss': loss,
      'keys': keys,
      'priorities': priorities,
  }

  return fetches
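
# --- Standalone sketch (not part of the learner above) of the prioritized
# replay importance weighting used in _step: w_i = (1 / p_i) ** beta,
# max-normalized so the largest weight is 1. The sampling probabilities and
# exponent below are illustrative values only.
import tensorflow as tf

probs = tf.constant([0.5, 0.25, 0.125, 0.125])  # sampling probabilities p_i
beta = 0.6                                      # importance sampling exponent

weights = (1. / probs) ** beta                  # undo the sampling bias
weights /= tf.reduce_max(weights)               # max-normalize: rarest -> 1.0
print(weights.numpy())  # the rarest transitions get the largest weights
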
def _step(self) -> Dict[str, tf.Tensor]:
  """Do a step of SGD and update the priorities."""

  # Pull out the data needed for updates/priorities.
  inputs = next(self._iterator)
  o_tm1, a_tm1, r_t, d_t, o_t = inputs.data
  keys, probs = inputs.info[:2]

  with tf.GradientTape() as tape:
    # Evaluate our networks.
    q_tm1 = self._network(o_tm1)
    q_t_value = self._target_network(o_t)
    q_t_selector = self._network(o_t)

    # The rewards and discounts have to have the same type as network values.
    r_t = tf.cast(r_t, q_tm1.dtype)
    r_t = tf.clip_by_value(r_t, -1., 1.)
    d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount, q_tm1.dtype)

    # Compute the loss.
    _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                     q_t_selector)
    loss = losses.huber(extra.td_error, self._huber_loss_parameter)

    # Get the importance weights.
    importance_weights = 1. / probs  # [B]
    importance_weights **= self._importance_sampling_exponent
    importance_weights /= tf.reduce_max(importance_weights)

    # Reweight.
    loss *= tf.cast(importance_weights, loss.dtype)  # [B]
    loss = tf.reduce_mean(loss, axis=[0])  # []

  # Do a step of SGD.
  gradients = tape.gradient(loss, self._network.trainable_variables)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  # Update the priorities in the replay buffer.
  if self._replay_client:
    priorities = tf.cast(tf.abs(extra.td_error), tf.float64)
    self._replay_client.update_priorities(
        table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities)

  # Periodically update the target network.
  if tf.math.mod(self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(self._network.variables,
                         self._target_network.variables):
      dest.assign(src)
  self._num_steps.assign_add(1)

  # Report loss & statistics for logging.
  fetches = {
      'loss': loss,
  }

  return fetches
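
# --- Standalone sketch of the double Q-learning target that
# trfl.double_qlearning computes inside the step above: the online network
# selects the bootstrap action, the target network evaluates it. All tensor
# values below are illustrative.
import tensorflow as tf

q_tm1 = tf.constant([[1.0, 2.0]])         # online Q(s, .)   [B, A]
q_t_selector = tf.constant([[0.5, 1.5]])  # online Q(s', .)  [B, A]
q_t_value = tf.constant([[0.7, 0.3]])     # target Q(s', .)  [B, A]
a_tm1 = tf.constant([0])                  # behaviour action [B]
r_t = tf.constant([1.0])                  # reward           [B]
d_t = tf.constant([0.99])                 # discount         [B]

best_a = tf.argmax(q_t_selector, axis=1)                   # select online
boot = tf.gather(q_t_value, best_a, axis=1, batch_dims=1)  # evaluate w/ target
target = tf.stop_gradient(r_t + d_t * boot)
qa_tm1 = tf.gather(q_tm1, a_tm1, axis=1, batch_dims=1)
td_error = target - qa_tm1                                 # drives the loss
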
def _step(self):
  # Collect the online and target network variables for the periodic update.
  online_variables = (
      *self._observation_network.variables,
      *self._critic_network.variables,
      *self._policy_network.variables,
  )
  target_variables = (
      *self._target_observation_network.variables,
      *self._target_critic_network.variables,
      *self._target_policy_network.variables,
  )

  # Make online -> target network update ops.
  if self._target_update_period > 0 and tf.math.mod(
      self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(online_variables, target_variables):
      dest.assign(src)
  self._num_steps.assign_add(1)

  # Get data from replay, including the extras (behaviour log-probabilities
  # and policy outputs) that were inserted into Reverb alongside each
  # transition.
  inputs = next(self._iterator)
  o_tm1, a_tm1, r_t, d_t, o_t, extra = inputs.data
  behavior_logP_tm1 = extra['logP']
  behavior_tm1 = extra['policy']

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(self._discount, dtype=d_t.dtype)

  with tf.GradientTape(persistent=True) as tape:
    # Maybe transform the observation before feeding into policy and critic.
    # Transforming the observations this way at the start of the learning
    # step effectively means that the policy and critic share observation
    # network weights.
    o_tm1 = self._observation_network(o_tm1)
    o_t = self._target_observation_network(o_t)
    o_t = tree.map_structure(tf.stop_gradient, o_t)

    # Policy heads: the policy network outputs both a policy and a state
    # value.
    pol_tm1, v_tm1 = self._policy_network(o_tm1)
    pol_t, v_t = self._target_policy_network(o_t)
    pol_t = tree.map_structure(tf.stop_gradient, pol_t)
    v_t = tree.map_structure(tf.stop_gradient, v_t)

    # Actor (DPG) loss. If clipping is enabled, use dqda clipping and clip
    # the norm.
    # TODO: two critic nets, e.g. q1_tm1 and q2_tm1, pick the min as target.
    dqda_clipping = 1.0 if self._clipping else None
    onpol_a_tm1, onpol_logP_tm1 = self._sampling_head(pol_tm1)
    onpol_q_tm1 = self._critic_network(o_tm1, onpol_a_tm1)
    onpol_q_tm1 = tf.squeeze(onpol_q_tm1, axis=-1)  # [B]
    logP_tm1 = self._sampling_head.log_prob(a_tm1, pol_tm1)
    ReFER_params_loss = self._ReFER.loss(behavior_logP_tm1, logP_tm1)
    dpg_loss = losses.dpg(
        onpol_q_tm1,
        onpol_a_tm1,
        tape=tape,
        dqda_clipping=dqda_clipping,
        clip_norm=self._clipping)
    dpg_loss = tf.reduce_mean(dpg_loss, axis=0)
    entropy_loss = self._entropy_coeff * tf.reduce_mean(
        onpol_logP_tm1, axis=0)

    # Penalty keeping the learned policy close to the behaviour policy,
    # weighted by the ReFER coefficient.
    KL_coef = self._ReFER.DKL_coef()
    # Alternative, likelihood-based penalty:
    # behavior_P_tm1 = tf.math.exp(behavior_logP_tm1)
    # KL_loss = KL_coef * behavior_P_tm1 * (behavior_logP_tm1 - logP_tm1)
    KL_loss = tf.reduce_sum((behavior_tm1 - pol_tm1)**2, axis=-1)
    KL_loss = KL_coef * tf.reduce_mean(KL_loss, axis=0)

    # V(s) loss.
    value_target = tf.stop_gradient(
        onpol_q_tm1 - self._entropy_coeff * onpol_logP_tm1)
    value_loss = losses.huber(value_target - v_tm1, 1.0)
    # value_loss = 0.5 * (value_target - v_tm1) ** 2
    value_loss = tf.reduce_mean(value_loss, axis=0)

    # Critic learning with a TD loss.
    q_tm1 = self._critic_network(o_tm1, a_tm1)
    q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]
    onpol_a_t, logP_t = self._sampling_head(pol_t)
    onpol_q_t = self._target_critic_network(o_t, onpol_a_t)
    onpol_q_t = tf.squeeze(onpol_q_t, axis=-1)  # [B]
    onpol_q_t = tree.map_structure(tf.stop_gradient, onpol_q_t)
    R_t = self._observation_network.scale_rewards(r_t)
    # Apply the additional discount alongside the environment discount.
    critic_target = tf.stop_gradient(
        R_t + discount * d_t * tf.minimum(v_t, onpol_q_t))
    # critic_target = tf.stop_gradient(
    #     R_t + discount * d_t * 0.5 * (v_t + onpol_q_t))
    critic_loss = losses.huber(critic_target - q_tm1, 1.0)
    # critic_loss = 0.5 * (critic_target - q_tm1) ** 2
    critic_loss = tf.reduce_mean(critic_loss, axis=0)

    encoder_loss = self._observation_network.compute_loss(o_tm1, r_t)
    policy_loss = (
        value_loss + entropy_loss + dpg_loss + encoder_loss + KL_loss)

  # Compute gradients.
  policy_gradients = tape.gradient(policy_loss, self._policy_variables)
  critic_gradients = tape.gradient(critic_loss, self._critic_variables)
  ReFER_gradient = tape.gradient(ReFER_params_loss,
                                 self._ReFER.trainable_variables)

  # Delete the tape manually because of the persistent=True flag.
  del tape

  # Maybe clip gradients.
  if self._clipping:
    policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
    critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

  # Apply gradients.
  self._policy_optimizer.apply(policy_gradients, self._policy_variables)
  self._critic_optimizer.apply(critic_gradients, self._critic_variables)
  self._ReFER_optimizer.apply(ReFER_gradient, self._ReFER.trainable_variables)

  # Losses to track.
  return {
      'critic_loss': critic_loss,
      'svalue_loss': value_loss,
      'entropy_loss': entropy_loss,
      'dpg_loss': dpg_loss,
      'avg_q': tf.reduce_mean(onpol_q_t, axis=0),
      'KL_loss': KL_loss,
      # 'frac_off_pol': self._ReFER._last_frac_off_pol,
      'beta': self._ReFER._beta,
      'r_mean': self._observation_network._ret_mean,
      'r_scale': self._observation_network._ret_scale,
  }
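
# --- Hypothetical sketch of the coefficient annealing assumed to live inside
# self._ReFER above, following the ReFER scheme of Novati & Koumoutsakos
# ("Remember and Forget for Experience Replay"): samples whose importance
# ratio exp(logP - behavior_logP) leaves the trust region [1/c_max, c_max]
# count as far-policy, and beta is annealed down when their fraction exceeds
# a target D, up otherwise. The function name, arguments and defaults are
# illustrative, not the actual implementation behind self._ReFER.
import tensorflow as tf

def refer_beta_step(beta, logp, behavior_logp, c_max=4.0, D=0.1, eta=1e-4):
  rho = tf.exp(logp - behavior_logp)                   # importance ratios
  far = tf.logical_or(rho > c_max, rho < 1.0 / c_max)  # outside trust region
  frac_far = tf.reduce_mean(tf.cast(far, tf.float32))  # cf. frac_off_pol
  # Too many far-policy samples: shrink beta (strengthening the penalty term
  # weighted by 1 - beta); otherwise grow beta back toward 1.
  return tf.where(frac_far > D, (1. - eta) * beta, (1. - eta) * beta + eta)
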
def _step(self) -> Dict[str, tf.Tensor]:
  """Do a step of SGD and update the priorities."""

  # Pull out the data needed for updates/priorities.
  inputs = next(self._iterator)
  o_tm1, a_tm1, r_t, d_t, o_t = inputs.data
  keys, probs = inputs.info[:2]

  with tf.GradientTape() as tape:
    # Evaluate our networks.
    q_tm1 = self._network(o_tm1)
    q_t_value = self._target_network(o_t)
    q_t_selector = self._network(o_t)

    # The rewards and discounts have to have the same type as network values.
    r_t = tf.cast(r_t, q_tm1.dtype)
    r_t = tf.clip_by_value(r_t, -1., 1.)
    d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount, q_tm1.dtype)

    # Compute the loss.
    _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                     q_t_selector)
    loss = losses.huber(extra.td_error, self._huber_loss_parameter)

    if self._alpha:
      policy_probs = self._emp_policy.lookup([str(o) for o in o_tm1])
      # Soft-maximum of the q function.
      push_down = tf.reduce_logsumexp(q_tm1 * self._tr, axis=1) / self._tr
      # Expected q value under the behavioural policy.
      push_up = tf.reduce_sum(policy_probs * q_tm1, axis=1)
      cql_loss = loss + self._alpha * (push_down - push_up)
    else:
      cql_loss = loss

    cql_loss = tf.reduce_mean(cql_loss, axis=0)

  # Do a step of SGD.
  gradients = tape.gradient(cql_loss, self._network.trainable_variables)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  # Update the priorities in the replay buffer.
  if self._replay_client:
    priorities = tf.cast(tf.abs(extra.td_error), tf.float64)
    self._replay_client.update_priorities(
        table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities)

  # Periodically update the target network.
  if tf.math.mod(self._counter.get_counts()['learner_steps'],
                 self._target_update_period) == 0:
    for src, dest in zip(self._network.variables,
                         self._target_network.variables):
      dest.assign(src)

  # Report loss & statistics for logging.
  fetches = {
      'critic_loss': tf.reduce_mean(loss, axis=0),
      'q_variance': tf.reduce_mean(
          tf.math.reduce_variance(q_tm1, axis=1), axis=0),
      'q_average': tf.reduce_mean(q_tm1),
  }
  if self._alpha:
    fetches.update({
        'push_up': tf.reduce_mean(push_up, axis=0),
        'push_down': tf.reduce_mean(push_down, axis=0),
        'regularizer': tf.reduce_mean(push_down - push_up, axis=0),
        'cql_loss': cql_loss,
    })

  return fetches
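
# --- Standalone sketch of the CQL regularizer computed above: a
# temperature-scaled soft-maximum of the Q values is pushed down while the
# expected Q value under the empirical behaviour policy is pushed up. All
# tensors and coefficients below are illustrative.
import tensorflow as tf

q_tm1 = tf.constant([[1.0, 2.0, 0.5],
                     [0.0, 1.5, 1.0]])          # Q(s, .)          [B, A]
policy_probs = tf.constant([[0.2, 0.6, 0.2],
                            [0.5, 0.3, 0.2]])   # behaviour probs  [B, A]
tr = 10.0                                       # soft-max temperature
alpha = 1.0                                     # regularizer weight

push_down = tf.reduce_logsumexp(q_tm1 * tr, axis=1) / tr  # soft-max of Q [B]
push_up = tf.reduce_sum(policy_probs * q_tm1, axis=1)     # E_{pi_b}[Q]   [B]
regularizer = alpha * (push_down - push_up)               # added to TD loss
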