def _call(
    self, observation_and_state: Tuple[types.NestedTensor,
                                       PolicyCriticRNNState]
) -> Tuple[types.NestedTensor, PolicyCriticRNNState]:
  """Computes a forward step for a single element.

  The observation and state are packed together in order to use
  `tf.vectorized_map` to handle batches of observations. See this module's
  __call__() function.

  Args:
    observation_and_state: the observation and state packed in a tuple.

  Returns:
    The selected action and the corresponding state.
  """
  observation, prev_state = observation_and_state

  # Tile input observations and states to allow multiple policy predictions.
  tiled_observation, tiled_prev_state = utils.tile_nested(
      (observation, prev_state), self._num_action_samples)
  actions, policy_states = self._policy_network(tiled_observation,
                                                tiled_prev_state.policy)

  # Evaluate multiple critic predictions with the sampled actions.
  value_distribution, critic_states = self._critic_network(
      tiled_observation, actions, tiled_prev_state.critic)
  value_estimate = value_distribution.mean()

  # Resample a single action of the sampled actions according to logits given
  # by the tempered Q-values.
  selected_action_idx = tfp.distributions.Categorical(
      probs=tf.nn.softmax(value_estimate / self._temperature_beta)).sample()
  selected_action = actions[selected_action_idx]

  # Select and return the RNN state that corresponds to the selected action.
  states = PolicyCriticRNNState(policy=policy_states, critic=critic_states)
  selected_state = tree.map_structure(
      lambda x: x[selected_action_idx], states)

  return selected_action, selected_state
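# A minimal, standalone sketch (not part of the module above) of the tempered
# resampling step used in `_call`: N candidate actions are scored by a critic,
# and one is drawn from a softmax over Q / beta. The toy shapes, the random
# tensors and the constants below are assumptions for illustration only.
import tensorflow as tf
import tensorflow_probability as tfp

num_action_samples = 4   # N candidate actions per observation (assumed).
temperature_beta = 0.5   # Lower beta -> sharper (greedier) resampling.

actions = tf.random.normal([num_action_samples, 2])       # [N, action_dim]
value_estimate = tf.random.normal([num_action_samples])   # Q(s, a_i), [N]

selected_action_idx = tfp.distributions.Categorical(
    probs=tf.nn.softmax(value_estimate / temperature_beta)).sample()
selected_action = actions[selected_action_idx]             # [action_dim]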
def __call__(self, inputs: types.NestedTensor) -> tf.Tensor:
  # Inputs are of size [B, ...]. Here we tile them to be of shape [N, B, ...].
  tiled_inputs = tf2_utils.tile_nested(inputs, self._num_action_samples)
  shape = tf.shape(tree.flatten(tiled_inputs)[0])
  n, b = shape[0], shape[1]
  tf.debugging.assert_equal(
      n, self._num_action_samples,
      'Internal Error. Unexpected tiled_inputs shape.')
  dummy_zeros_n_b = tf.zeros((n, b))
  # Reshape to [N * B, ...].
  merge = lambda x: snt.merge_leading_dims(x, 2)
  tiled_inputs = tree.map_structure(merge, tiled_inputs)

  tiled_actions = self._actor_network(tiled_inputs)

  # Compute Q-values and the resulting tempered probabilities.
  q = self._critic_network(tiled_inputs, tiled_actions)
  boltzmann_logits = q / self._beta

  boltzmann_logits = snt.split_leading_dim(boltzmann_logits, dummy_zeros_n_b,
                                           2)
  # [B, N]
  boltzmann_logits = tf.transpose(boltzmann_logits, perm=(1, 0))
  # Resample one action per batch according to the Boltzmann distribution.
  action_idx = tfp.distributions.Categorical(
      logits=boltzmann_logits).sample()
  # [B, 2], where the first column is 0, 1, 2,... corresponding to indices to
  # the batch dimension.
  action_idx = tf.stack((tf.range(b), action_idx), axis=1)

  tiled_actions = snt.split_leading_dim(tiled_actions, dummy_zeros_n_b, 2)
  action_dim = len(tiled_actions.get_shape().as_list())
  tiled_actions = tf.transpose(tiled_actions,
                               perm=[1, 0] + list(range(2, action_dim)))
  # [B, ...]
  action_sample = tf.gather_nd(tiled_actions, action_idx)

  return action_sample
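# Illustrative sketch (assumed toy shapes, not the class above) of the batched
# index construction used in `__call__`: after the Categorical draw produces
# one candidate index per batch element, `tf.stack` with `tf.range` builds
# [B, 2] indices so `tf.gather_nd` can pick one action per row out of a
# [B, N, ...] tensor of candidates.
import tensorflow as tf

b, n, action_dim = 3, 5, 2
tiled_actions = tf.random.normal([b, n, action_dim])  # candidates per row.
action_idx = tf.constant([4, 0, 2])                   # one draw per row, [B].

gather_idx = tf.stack((tf.range(b), action_idx), axis=1)  # [B, 2]
action_sample = tf.gather_nd(tiled_actions, gather_idx)   # [B, action_dim]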
def _step(self) -> Dict[str, tf.Tensor]:
  # Update target network.
  online_variables = [
      *self._critic_network.variables,
      *self._policy_network.variables,
  ]
  if self._prior_network is not None:
    online_variables += [*self._prior_network.variables]
  online_variables = tuple(online_variables)

  target_variables = [
      *self._target_critic_network.variables,
      *self._target_policy_network.variables,
  ]
  if self._prior_network is not None:
    target_variables += [*self._target_prior_network.variables]
  target_variables = tuple(target_variables)

  # Make online -> target network update ops.
  if tf.math.mod(self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(online_variables, target_variables):
      dest.assign(src)
  self._num_steps.assign_add(1)

  # Get data from replay and flip to `[T, B, ...]`.
  sample: reverb.ReplaySample = next(self._iterator)
  data = tf2_utils.batch_to_sequence(sample.data)
  observations, actions, rewards, discounts, extra = (data.observation,
                                                      data.action,
                                                      data.reward,
                                                      data.discount,
                                                      data.extras)

  online_target_pi_q = svg0_utils.OnlineTargetPiQ(
      online_pi=self._policy_network,
      online_q=self._critic_network,
      target_pi=self._target_policy_network,
      target_q=self._target_critic_network,
      num_samples=self._num_action_samples,
      online_prior=self._prior_network,
      target_prior=self._target_prior_network,
  )

  with tf.GradientTape(persistent=True) as tape:
    step_outputs = svg0_utils.static_rnn(
        core=online_target_pi_q,
        inputs=(observations, actions),
        unroll_length=rewards.shape[0])

    # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the
    # number of action samples taken.
    target_pi_samples = tf2_utils.batch_to_sequence(
        step_outputs.target_samples)
    # Tile observations to have shape [S, T+1, B, ...].
    tiled_observations = tf2_utils.tile_nested(observations,
                                               self._num_action_samples)

    # Finally compute target Q values on the new action samples.
    # Shape: [S, T+1, B, 1].
    target_q_target_pi_samples = snt.BatchApply(self._target_critic_network,
                                                3)(tiled_observations,
                                                   target_pi_samples)
    # Compute the value estimate by averaging over the action dimension.
    # Shape: [T+1, B, 1].
    target_v_target_pi = tf.reduce_mean(target_q_target_pi_samples, axis=0)

    # Split the target V's into the target for learning
    # `value_function_target` and the bootstrap value. Shape: [T, B].
    value_function_target = tf.squeeze(target_v_target_pi[:-1], axis=-1)
    # Shape: [B].
    bootstrap_value = tf.squeeze(target_v_target_pi[-1], axis=-1)

    # When learning with a prior, add entropy terms to value targets.
    if self._prior_network is not None:
      value_function_target -= self._distillation_cost * tf.stop_gradient(
          step_outputs.analytic_kl_to_target[:-1])
      bootstrap_value -= self._distillation_cost * tf.stop_gradient(
          step_outputs.analytic_kl_to_target[-1])

    # Get target log probs and behavior log probs from rollout.
    # Shape: [T+1, B].
    target_log_probs_behavior_actions = (
        step_outputs.target_log_probs_behavior_actions)
    behavior_log_probs = extra['log_prob']
    # Calculate importance weights. Shape: [T+1, B].
    rhos = tf.exp(target_log_probs_behavior_actions - behavior_log_probs)

    # Filter the importance weights to mask out episode restarts. Ignore the
    # last action and consider the step type of the next step for masking.
    # Shape: [T, B].
    episode_start_mask = tf2_utils.batch_to_sequence(
        sample.data.start_of_episode)[1:]
    rhos = svg0_utils.mask_out_restarting(rhos[:-1], episode_start_mask)

    # Compute the log importance weights with a small value added for
    # stability. Shape: [T, B].
    log_rhos = tf.math.log(rhos + _MIN_LOG_VAL)

    # Retrieve the target and online Q values and throw away the last action.
    # Shape: [T, B].
    target_q_values = tf.squeeze(step_outputs.target_q[:-1], -1)
    online_q_values = tf.squeeze(step_outputs.online_q[:-1], -1)

    # Flip online samples to have shape [S, T+1, B, ...] where 'S' is the
    # number of action samples taken.
    online_pi_samples = tf2_utils.batch_to_sequence(
        step_outputs.online_samples)
    target_q_online_pi_samples = snt.BatchApply(self._target_critic_network,
                                                3)(tiled_observations,
                                                   online_pi_samples)
    expected_q = tf.reduce_mean(
        tf.squeeze(target_q_online_pi_samples, -1), axis=0)

    # Flip online_log_probs to be of shape [S, T+1, B] and then compute
    # entropy by averaging over num samples. Final shape: [T+1, B].
    online_log_probs = tf2_utils.batch_to_sequence(
        step_outputs.online_log_probs)
    sample_based_entropy = tf.reduce_mean(-online_log_probs, axis=0)

    retrace_outputs = continuous_retrace_ops.retrace_from_importance_weights(
        log_rhos=log_rhos,
        discounts=self._discount * discounts[:-1],
        rewards=rewards[:-1],
        q_values=target_q_values,
        values=value_function_target,
        bootstrap_value=bootstrap_value,
        lambda_=self._lambda,
    )

    # Critic loss. Shape: [T, B].
    critic_loss = 0.5 * tf.math.squared_difference(
        tf.stop_gradient(retrace_outputs.qs), online_q_values)

    # Policy loss: SVG0 with sample-based entropy. Shape: [T, B].
    policy_loss = -(expected_q +
                    self._entropy_regularizer_cost * sample_based_entropy)
    policy_loss = policy_loss[:-1]

    if self._prior_network is not None:
      # When training the prior, also add the per-timestep KL cost.
      policy_loss += (self._distillation_cost *
                      step_outputs.analytic_kl_to_target[:-1])

    # Ensure episode restarts are masked out when computing the losses.
    critic_loss = svg0_utils.mask_out_restarting(critic_loss,
                                                 episode_start_mask)
    critic_loss = tf.reduce_mean(critic_loss)

    policy_loss = svg0_utils.mask_out_restarting(policy_loss,
                                                 episode_start_mask)
    policy_loss = tf.reduce_mean(policy_loss)

    if self._prior_network is not None:
      prior_loss = step_outputs.analytic_kl_divergence[:-1]
      prior_loss = svg0_utils.mask_out_restarting(prior_loss,
                                                  episode_start_mask)
      prior_loss = tf.reduce_mean(prior_loss)

  # Get trainable variables.
  policy_variables = self._policy_network.trainable_variables
  critic_variables = self._critic_network.trainable_variables

  # Compute gradients.
  policy_gradients = tape.gradient(policy_loss, policy_variables)
  critic_gradients = tape.gradient(critic_loss, critic_variables)
  if self._prior_network is not None:
    prior_variables = self._prior_network.trainable_variables
    prior_gradients = tape.gradient(prior_loss, prior_variables)

  # Delete the tape manually because of the persistent=True flag.
  del tape

  # Apply gradients.
  self._policy_optimizer.apply(policy_gradients, policy_variables)
  self._critic_optimizer.apply(critic_gradients, critic_variables)
  losses = {
      'critic_loss': critic_loss,
      'policy_loss': policy_loss,
  }

  if self._prior_network is not None:
    self._prior_optimizer.apply(prior_gradients, prior_variables)
    losses['prior_loss'] = prior_loss

  # Losses to track.
  return losses
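# Rough, standalone sketch of the Retrace(lambda) target recursion that
# `continuous_retrace_ops.retrace_from_importance_weights` computes from the
# quantities prepared above (Munos et al., 2016). The tensors below are toy
# values, and the exact truncation and indexing conventions of the library op
# may differ slightly from this simplified backward loop.
import tensorflow as tf

T, B = 3, 2
log_rhos = tf.zeros([T, B])              # log importance weights.
discounts = 0.99 * tf.ones([T, B])       # discount * env discounts.
rewards = tf.random.normal([T, B])
q_values = tf.random.normal([T, B])      # target Q(x_t, a_t).
values = tf.random.normal([T, B])        # value targets V(x_t).
bootstrap_value = tf.random.normal([B])  # V(x_T).
lambda_ = 1.0

cs = lambda_ * tf.minimum(1.0, tf.exp(log_rhos))  # truncated trace coefficients.
values_t_plus_1 = tf.concat([values[1:], bootstrap_value[None]], axis=0)
deltas = rewards + discounts * values_t_plus_1 - q_values

# Accumulate the temporal-difference errors backwards in time, scaled by the
# discount and the truncated traces, then add back Q to form the targets.
acc = tf.zeros([B])
qs = [None] * T
for t in reversed(range(T)):
  acc = deltas[t] + discounts[t] * cs[t] * acc
  qs[t] = q_values[t] + acc
qs = tf.stop_gradient(tf.stack(qs))      # regression targets for the critic.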
def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]:
  # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...].
  sample = tf2_utils.batch_to_sequence(sample)
  observations = sample.observation
  actions = sample.action
  rewards = sample.reward
  discounts = sample.discount

  dtype = rewards.dtype

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(self._discount, dtype=discounts.dtype)

  # Loss cumulants across time. These cannot be python mutable objects.
  critic_loss = 0.
  policy_loss = 0.

  # Each transition induces a policy loss, which we then weight using
  # the `policy_loss_coef_t`; shape [B], see https://arxiv.org/abs/2006.15134.
  # `policy_loss_coef` is a scalar average of these coefficients across
  # the batch and sequence length dimensions.
  policy_loss_coef = 0.

  per_device_batch_size = actions.shape[1]

  # Initialize recurrent states.
  critic_state = self._critic_network.initial_state(per_device_batch_size)
  target_critic_state = critic_state
  policy_state = self._policy_network.initial_state(per_device_batch_size)
  target_policy_state = policy_state

  with tf.GradientTape(persistent=True) as tape:
    for t in range(1, self._sequence_length):
      o_tm1 = tree.map_structure(operator.itemgetter(t - 1), observations)
      a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions)
      r_t = tree.map_structure(operator.itemgetter(t - 1), rewards)
      d_t = tree.map_structure(operator.itemgetter(t - 1), discounts)
      o_t = tree.map_structure(operator.itemgetter(t), observations)

      if t != 1:
        # By only updating the target critic state here we are forcing
        # the target critic to ignore observations[0]. Otherwise, the
        # target_critic will be unrolled for one more timestep than critic.
        # The smaller the sequence length, the more problematic this is: if
        # you use RNN on sequences of length 2, you would expect the code to
        # never use recurrent connections. But if you don't skip updating the
        # target_critic_state on observation[0] here, it won't be the case.
        _, target_critic_state = self._target_critic_network(
            o_tm1, a_tm1, target_critic_state)

      # ========================= Critic learning ============================
      q_tm1, next_critic_state = self._critic_network(o_tm1, a_tm1,
                                                      critic_state)
      target_action_distribution, target_policy_state = (
          self._target_policy_network(o_t, target_policy_state))

      sampled_actions_t = target_action_distribution.sample(
          self._num_action_samples_td_learning)
      # [N, B, ...]
      tiled_o_t = tf2_utils.tile_nested(
          o_t, self._num_action_samples_td_learning)
      tiled_target_critic_state = tf2_utils.tile_nested(
          target_critic_state, self._num_action_samples_td_learning)

      # Compute the target critic's Q-value of the sampled actions.
      sampled_q_t, _ = snt.BatchApply(self._target_critic_network)(
          tiled_o_t, sampled_actions_t, tiled_target_critic_state)

      # Compute average logits by first reshaping them to [N, B, A] and then
      # normalizing them across atoms.
      new_shape = [self._num_action_samples_td_learning, r_t.shape[0], -1]
      sampled_logits = tf.reshape(sampled_q_t.logits, new_shape)
      sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
      averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

      # Construct the expected distributional value for bootstrapping.
      q_t = networks.DiscreteValuedDistribution(
          values=sampled_q_t.values, logits=averaged_logits)
      critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t, q_t)
      critic_loss_t = tf.reduce_mean(critic_loss_t)

      # ========================= Actor learning =============================
      action_distribution_tm1, policy_state = self._policy_network(
          o_tm1, policy_state)
      q_tm1_mean = q_tm1.mean()

      # Compute the estimate of the value function based on
      # self._num_action_samples_policy_weight samples from the policy.
      tiled_o_tm1 = tf2_utils.tile_nested(
          o_tm1, self._num_action_samples_policy_weight)
      tiled_critic_state = tf2_utils.tile_nested(
          critic_state, self._num_action_samples_policy_weight)
      action_tm1 = action_distribution_tm1.sample(
          self._num_action_samples_policy_weight)
      tiled_z_tm1, _ = snt.BatchApply(self._critic_network)(
          tiled_o_tm1, action_tm1, tiled_critic_state)
      tiled_v_tm1 = tf.reshape(tiled_z_tm1.mean(),
                               [self._num_action_samples_policy_weight, -1])

      # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the
      # final estimate of the value function.
      if self._baseline_reduce_function == 'mean':
        v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0)
      elif self._baseline_reduce_function == 'max':
        v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0)
      elif self._baseline_reduce_function == 'min':
        v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0)

      # Assert that action_distribution_tm1 is a batch of multivariate
      # distributions (in contrast to e.g. a [batch, action_size] collection
      # of 1d distributions).
      assert len(action_distribution_tm1.batch_shape) == 1
      policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1)

      advantage = q_tm1_mean - v_tm1_estimate
      if self._policy_improvement_modes == 'exp':
        policy_loss_coef_t = tf.math.minimum(
            tf.math.exp(advantage / self._beta), self._ratio_upper_bound)
      elif self._policy_improvement_modes == 'binary':
        policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype)
      elif self._policy_improvement_modes == 'all':
        # Regress against all actions (effectively pure BC).
        policy_loss_coef_t = 1.
      policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t)

      policy_loss_batch *= policy_loss_coef_t
      policy_loss_t = tf.reduce_mean(policy_loss_batch)

      critic_state = next_critic_state

      critic_loss += critic_loss_t
      policy_loss += policy_loss_t
      policy_loss_coef += tf.reduce_mean(policy_loss_coef_t)  # For logging.

  # Divide by sequence length to get mean losses.
  critic_loss /= tf.cast(self._sequence_length, dtype=dtype)
  policy_loss /= tf.cast(self._sequence_length, dtype=dtype)
  policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype)

  # Compute gradients.
  critic_gradients = tape.gradient(
      critic_loss, self._critic_network.trainable_variables)
  policy_gradients = tape.gradient(
      policy_loss, self._policy_network.trainable_variables)

  # Delete the tape manually because of the persistent=True flag.
  del tape

  # Sync gradients across GPUs or TPUs.
  ctx = tf.distribute.get_replica_context()
  critic_gradients = ctx.all_reduce('mean', critic_gradients)
  policy_gradients = ctx.all_reduce('mean', policy_gradients)

  # Maybe clip gradients.
  if self._clipping:
    policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
    critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

  # Apply gradients.
  self._critic_optimizer.apply(critic_gradients,
                               self._critic_network.trainable_variables)
  self._policy_optimizer.apply(policy_gradients,
                               self._policy_network.trainable_variables)

  source_variables = (self._critic_network.variables +
                      self._policy_network.variables)
  target_variables = (self._target_critic_network.variables +
                      self._target_policy_network.variables)

  # Make online -> target network update ops.
  if tf.math.mod(self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(source_variables, target_variables):
      dest.assign(src)
  self._num_steps.assign_add(1)

  return {
      'critic_loss': critic_loss,
      'policy_loss': policy_loss,
      'policy_loss_coef': policy_loss_coef,
  }
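# Standalone sketch of the CRR policy-loss weighting modes used above
# (https://arxiv.org/abs/2006.15134), applied to toy advantages. `beta` and
# `ratio_upper_bound` mirror `self._beta` / `self._ratio_upper_bound`, but the
# concrete numbers here are illustrative assumptions.
import tensorflow as tf

advantage = tf.constant([2.0, -0.5, 0.1])   # Q(s, a) - V_estimate(s), [B].
beta = 1.0
ratio_upper_bound = 20.0

# 'exp' mode: exponentiated advantages, clipped from above.
exp_coef = tf.minimum(tf.exp(advantage / beta), ratio_upper_bound)
# 'binary' mode: indicator of positive advantage.
binary_coef = tf.cast(advantage > 0, tf.float32)

# Either coefficient re-weights the behaviour-cloning term -log pi(a|s);
# gradients do not flow through the coefficient itself.
log_prob = tf.constant([-1.0, -2.0, -0.3])
policy_loss = -tf.stop_gradient(exp_coef) * log_prob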