def __call__(self, inputs: types.NestedTensor) -> tf.Tensor:
    """Samples one action per batch element from a Boltzmann distribution.

    For every batch element, `self._num_action_samples` candidate actions are
    produced by the actor network on tiled copies of the inputs. Each candidate
    is scored by the critic network, and one candidate is resampled from a
    categorical distribution with logits `q / self._beta`.

    Args:
      inputs: nested input tensors whose leaves share a leading batch
        dimension B.

    Returns:
      The selected actions, with leading dimension B.
    """
    # Inputs are of size [B, ...]. Here we tile them to be of shape [N, B, ...].
    tiled_inputs = tf2_utils.tile_nested(inputs, self._num_action_samples)
    shape = tf.shape(tree.flatten(tiled_inputs)[0])
    n, b = shape[0], shape[1]
    tf.debugging.assert_equal(
        n, self._num_action_samples,
        'Internal Error. Unexpected tiled_inputs shape.')
    # Only the shape of this tensor is used: it tells split_leading_dim how to
    # restore the [N, B] leading dimensions.
    dummy_zeros_n_b = tf.zeros((n, b))

    # Reshape to [N * B, ...] so the networks see a single flat batch.
    merge = lambda x: snt.merge_leading_dims(x, 2)
    tiled_inputs = tree.map_structure(merge, tiled_inputs)

    tiled_actions = self._actor_network(tiled_inputs)

    # Compute Q-values and the resulting tempered logits, arranged as [B, N]
    # so that the categorical distribution is over the N candidates.
    q = self._critic_network(tiled_inputs, tiled_actions)
    boltzmann_logits = q / self._beta
    boltzmann_logits = snt.split_leading_dim(boltzmann_logits, dummy_zeros_n_b, 2)
    boltzmann_logits = tf.transpose(boltzmann_logits, perm=(1, 0))  # [B, N]

    # Resample one candidate index per batch element, shape [B].
    action_idx = tfp.distributions.Categorical(
        logits=boltzmann_logits).sample()

    # Candidate actions are laid out as [N, B, ...]. Gather entry
    # (action_idx[i], i) for every batch element i directly. This avoids
    # materialising a transposed [B, N, ...] copy of all candidates and does
    # not need the action rank to be statically known.
    tiled_actions = snt.split_leading_dim(tiled_actions, dummy_zeros_n_b, 2)
    gather_indices = tf.stack((action_idx, tf.range(b)), axis=1)  # [B, 2]

    # [B, ...]
    action_sample = tf.gather_nd(tiled_actions, gather_indices)
    return action_sample
def _step(self) -> types.NestedTensor:
    """Performs one distributional MPO learner step.

    Periodically copies the online networks into the target networks. It then
    pulls a batch of transitions from `self._iterator` and fits the
    distributional critic with `losses.categorical`. The bootstrap target is the
    mixture of the target critic's distributions over `self._num_samples`
    target-policy actions. Finally it updates the policy and the MPO dual
    variables with `self._policy_loss_module`.

    Returns:
      A dict with 'critic_loss', 'policy_loss' and the statistics returned by
      the policy loss module.
    """
    # Update target network.
    online_policy_variables = self._policy_network.variables
    target_policy_variables = self._target_policy_network.variables
    online_critic_variables = (
        *self._observation_network.variables,
        *self._critic_network.variables,
    )
    target_critic_variables = (
        *self._target_observation_network.variables,
        *self._target_critic_network.variables,
    )

    # Make online policy -> target policy network update ops.
    if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0:
        for src, dest in zip(online_policy_variables, target_policy_variables):
            dest.assign(src)
    # Make online critic -> target critic network update ops.
    if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0:
        for src, dest in zip(online_critic_variables, target_critic_variables):
            dest.assign(src)

    self._num_steps.assign_add(1)

    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    inputs = next(self._iterator)
    o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

    # Use the dynamic batch size. The static `r_t.shape[0]` is None whenever
    # the dataset's batch dimension is not fixed, and tf.reshape rejects None.
    # When the dimension is static, TF shape inference still recovers it.
    batch_size = tf.shape(r_t)[0]

    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(self._discount, dtype=d_t.dtype)

    with tf.GradientTape(persistent=True) as tape:
        # Maybe transform the observation before feeding into policy and critic.
        # Transforming the observations this way at the start of the learning
        # step effectively means that the policy and critic share observation
        # network weights.
        o_tm1 = self._observation_network(o_tm1)
        # This stop_gradient prevents gradients to propagate into the target
        # observation network. In addition, since the online policy network is
        # evaluated at o_t, this also means the policy loss does not influence
        # the observation network training.
        o_t = tf.stop_gradient(self._target_observation_network(o_t))

        # Get online and target action distributions from policy networks.
        online_action_distribution = self._policy_network(o_t)
        target_action_distribution = self._target_policy_network(o_t)

        # Sample actions to evaluate policy; of size [N, B, ...].
        sampled_actions = target_action_distribution.sample(self._num_samples)

        # Tile embedded observations to feed into the target critic network.
        # Note: this is more efficient than tiling before the embedding layer.
        tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples)  # [N, B, ...]

        # Compute target-estimated distributional value of sampled actions at o_t.
        sampled_q_t_distributions = self._target_critic_network(
            # Merge batch dimensions; to shape [N*B, ...].
            snt.merge_leading_dims(tiled_o_t, num_dims=2),
            snt.merge_leading_dims(sampled_actions, num_dims=2))

        # Average the N categorical distributions. Each sample's logits are
        # normalised to log-probabilities, then logsumexp'd over N. The result
        # is log(N * mean probability); the constant log(N) cancels in the
        # softmax, so these logits describe the equal-weight mixture.
        new_shape = [self._num_samples, batch_size, -1]  # [N, B, A]
        sampled_logits = tf.reshape(sampled_q_t_distributions.logits, new_shape)
        sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
        averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

        # Construct the expected distributional value for bootstrapping.
        q_t_distribution = networks.DiscreteValuedDistribution(
            values=sampled_q_t_distributions.values, logits=averaged_logits)

        # Compute online critic value distribution of a_tm1 in state o_tm1.
        q_tm1_distribution = self._critic_network(o_tm1, a_tm1)

        # Compute critic distributional loss.
        critic_loss = losses.categorical(q_tm1_distribution, r_t, discount * d_t,
                                         q_t_distribution)
        critic_loss = tf.reduce_mean(critic_loss)

        # Compute Q-values of sampled actions and reshape to [N, B].
        sampled_q_values = sampled_q_t_distributions.mean()
        sampled_q_values = tf.reshape(sampled_q_values, (self._num_samples, -1))

        # Compute MPO policy loss.
        policy_loss, policy_stats = self._policy_loss_module(
            online_action_distribution=online_action_distribution,
            target_action_distribution=target_action_distribution,
            actions=sampled_actions,
            q_values=sampled_q_values)

    # For clarity, explicitly define which variables are trained by which loss.
    critic_trainable_variables = (
        # In this agent, the critic loss trains the observation network.
        self._observation_network.trainable_variables +
        self._critic_network.trainable_variables)
    policy_trainable_variables = self._policy_network.trainable_variables
    # The following are the MPO dual variables, stored in the loss module.
    dual_trainable_variables = self._policy_loss_module.trainable_variables

    # Compute gradients.
    critic_gradients = tape.gradient(critic_loss, critic_trainable_variables)
    policy_gradients, dual_gradients = tape.gradient(
        policy_loss, (policy_trainable_variables, dual_trainable_variables))

    # Delete the tape manually because of the persistent=True flag.
    del tape

    # Maybe clip gradients.
    if self._clipping:
        policy_gradients = tuple(
            tf.clip_by_global_norm(policy_gradients, 40.)[0])
        critic_gradients = tuple(
            tf.clip_by_global_norm(critic_gradients, 40.)[0])

    # Apply gradients.
    self._critic_optimizer.apply(critic_gradients, critic_trainable_variables)
    self._policy_optimizer.apply(policy_gradients, policy_trainable_variables)
    self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

    # Losses to track.
    fetches = {
        'critic_loss': critic_loss,
        'policy_loss': policy_loss,
    }
    fetches.update(policy_stats)  # Log MPO stats.

    return fetches
def _step(self) -> types.NestedTensor:
    """Runs one multi-objective MPO learner update.

    Target networks are hard-synced on their update periods. A batch is then
    read from the replay iterator and one reward is computed per entry of
    `self._reward_objectives`. The critic(s) are trained through
    `_compute_distributional_critic_loss` (when the target critic returns a
    list) or `_compute_critic_loss`. The Q-values from each
    `self._qvalue_objectives` entry are appended, and the policy and dual
    variables are trained with `self._policy_loss_module`.

    Returns:
      Dict holding 'critic_loss', 'policy_loss', the policy loss module's
      statistics, and one '<objective name>_reward' mean per reward objective.
    """
    online_policy = self._policy_network.variables
    target_policy = self._target_policy_network.variables
    online_critic = (*self._observation_network.variables,
                     *self._critic_network.variables)
    target_critic = (*self._target_observation_network.variables,
                     *self._target_critic_network.variables)

    # Copy online weights into the matching target networks on their periods.
    if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0:
        for source, destination in zip(online_policy, target_policy):
            destination.assign(source)
    if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0:
        for source, destination in zip(online_critic, target_critic):
            destination.assign(source)
    self._num_steps.assign_add(1)

    # Replay batch; no extras are inserted into Reverb for this agent.
    o_tm1, a_tm1, r_t, d_t, o_t = next(self._iterator).data

    num_samples = self._num_samples
    with tf.GradientTape(persistent=True) as tape:
        # Online embedding of o_tm1. The observation network is shared by the
        # policy and critic, and only the critic loss trains it.
        o_tm1 = self._observation_network(o_tm1)
        # Target embedding of o_t. stop_gradient keeps gradients out of the
        # target observation network. Because the online policy is evaluated
        # at o_t, it also keeps the policy loss from training the observation
        # network.
        o_t = tf.stop_gradient(self._target_observation_network(o_t))

        online_action_distribution = self._policy_network(o_t)
        target_action_distribution = self._target_policy_network(o_t)

        # N target-policy actions per state, [N, B, ...], plus embeddings
        # tiled to match (tiling after embedding is cheaper than before).
        sampled_actions = target_action_distribution.sample(num_samples)
        tiled_o_t = tf2_utils.tile_tensor(o_t, num_samples)

        # Target critic on the flattened [N * B, ...] batch, and online critic
        # at (o_tm1, a_tm1).
        sampled_q_t_all = self._target_critic_network(
            snt.merge_leading_dims(tiled_o_t, num_dims=2),
            snt.merge_leading_dims(sampled_actions, num_dims=2))
        q_tm1_all = self._critic_network(o_tm1, a_tm1)

        # One reward per reward objective, stacked on the last axis.
        reward_stats = {}
        objective_rewards = []
        for objective in self._reward_objectives:
            reward = objective.reward_fn(o_tm1, a_tm1, r_t)
            reward_stats['{}_reward'.format(
                objective.name)] = tf.reduce_mean(reward)
            objective_rewards.append(reward)
        r_t_all = tf.stack(objective_rewards, axis=-1)
        r_t_all.get_shape().assert_has_rank(2)  # [B, C]

        # A list of outputs from the target critic marks distributional heads.
        if isinstance(sampled_q_t_all, list):
            critic_loss, sampled_q_t = _compute_distributional_critic_loss(
                sampled_q_t_all, q_tm1_all, r_t_all, d_t, self._discount,
                num_samples)
        else:
            critic_loss, sampled_q_t = _compute_critic_loss(
                sampled_q_t_all, q_tm1_all, r_t_all, d_t, self._discount,
                num_samples, self._num_critic_heads)

        # Append one Q-value channel per objective that defines qvalue_fn.
        extra_q_values = [
            tf.expand_dims(
                tf.stop_gradient(
                    objective.qvalue_fn(sampled_actions, sampled_q_t)),
                axis=-1)
            for objective in self._qvalue_objectives
        ]
        sampled_q_t_k = tf.concat([sampled_q_t] + extra_q_values,
                                  axis=-1)  # [N, B, K]

        policy_loss, policy_stats = self._policy_loss_module(
            online_action_distribution=online_action_distribution,
            target_action_distribution=target_action_distribution,
            actions=sampled_actions,
            q_values=sampled_q_t_k)

    # The critic loss trains the observation network as well as the critic.
    critic_trainable_variables = (
        self._observation_network.trainable_variables
        + self._critic_network.trainable_variables)
    policy_trainable_variables = self._policy_network.trainable_variables
    # MPO dual variables live inside the loss module.
    dual_trainable_variables = self._policy_loss_module.trainable_variables

    critic_gradients = tape.gradient(critic_loss, critic_trainable_variables)
    policy_gradients, dual_gradients = tape.gradient(
        policy_loss, (policy_trainable_variables, dual_trainable_variables))
    # Persistent tapes hold resources until explicitly released.
    del tape

    if self._clipping:
        clipped_policy, _ = tf.clip_by_global_norm(policy_gradients, 40.)
        policy_gradients = tuple(clipped_policy)
        clipped_critic, _ = tf.clip_by_global_norm(critic_gradients, 40.)
        critic_gradients = tuple(clipped_critic)

    self._critic_optimizer.apply(critic_gradients, critic_trainable_variables)
    self._policy_optimizer.apply(policy_gradients, policy_trainable_variables)
    self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

    fetches = dict(critic_loss=critic_loss, policy_loss=policy_loss)
    fetches.update(policy_stats)
    fetches.update(reward_stats)
    return fetches
def _step(self) -> types.NestedTensor:
    """Performs one multi-objective MPO learner update.

    Target networks are hard-synced on their update periods. A batch is then
    read from the replay iterator. A reward channel is computed for every
    objective with a `reward_fn`, and the critic is fitted on those channels.
    A distributional critic (list output) uses a categorical loss per
    objective; a multi-head scalar critic uses a TD loss. The policy and MPO
    dual variables are then trained on the per-objective Q-values, plus extra
    channels from objectives with an `objective_fn`.

    Returns:
      Dict holding 'critic_loss', 'policy_loss', the policy loss module's
      statistics, and one '<objective name>_reward' mean per reward objective.
    """
    # Update target network.
    online_policy_variables = self._policy_network.variables
    target_policy_variables = self._target_policy_network.variables
    online_critic_variables = (
        *self._observation_network.variables,
        *self._critic_network.variables,
    )
    target_critic_variables = (
        *self._target_observation_network.variables,
        *self._target_critic_network.variables,
    )

    # Make online policy -> target policy network update ops.
    if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0:
        for src, dest in zip(online_policy_variables, target_policy_variables):
            dest.assign(src)
    # Make online critic -> target critic network update ops.
    if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0:
        for src, dest in zip(online_critic_variables, target_critic_variables):
            dest.assign(src)

    self._num_steps.assign_add(1)

    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    inputs = next(self._iterator)
    o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

    # Dynamic batch size. The static `r_t.shape[0]` is None whenever the
    # dataset's batch dimension is not fixed, and tf.reshape rejects None.
    batch_size = tf.shape(r_t)[0]

    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(self._discount, dtype=d_t.dtype)

    with tf.GradientTape(persistent=True) as tape:
        # Maybe transform the observation before feeding into policy and critic.
        # Transforming the observations this way at the start of the learning
        # step effectively means that the policy and critic share observation
        # network weights.
        o_tm1 = self._observation_network(o_tm1)
        # This stop_gradient prevents gradients to propagate into the target
        # observation network. In addition, since the online policy network is
        # evaluated at o_t, this also means the policy loss does not influence
        # the observation network training.
        o_t = tf.stop_gradient(self._target_observation_network(o_t))

        # Get online and target action distributions from policy networks.
        online_action_distribution = self._policy_network(o_t)
        target_action_distribution = self._target_policy_network(o_t)

        # Sample actions to evaluate policy; of size [N, B, ...].
        sampled_actions = target_action_distribution.sample(self._num_samples)

        # Tile embedded observations to feed into the target critic network.
        # Note: this is more efficient than tiling before the embedding layer.
        tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples)  # [N, B, ...]

        # Compute target-estimated distributional value of sampled actions at o_t.
        sampled_q_t_all = self._target_critic_network(
            # Merge batch dimensions; to shape [N*B, ...].
            snt.merge_leading_dims(tiled_o_t, num_dims=2),
            snt.merge_leading_dims(sampled_actions, num_dims=2))

        # Compute online critic value distribution of a_tm1 in state o_tm1.
        q_tm1_all = self._critic_network(o_tm1, a_tm1)

        # Compute rewards for objectives with defined reward_fn
        reward_stats = {}
        r_t_all = []
        for objective in self._objectives:
            if hasattr(objective, 'reward_fn'):
                r = objective.reward_fn(o_tm1, a_tm1, r_t)
                reward_stats['{}_reward'.format(
                    objective.name)] = tf.reduce_mean(r)
                r_t_all.append(r)
        r_t_all = tf.stack(r_t_all, axis=-1)
        r_t_all.get_shape().assert_has_rank(2)  # [B, C]

        if isinstance(sampled_q_t_all, list):  # Distributional critics
            critic_losses = []
            sampled_q_ts = []
            for idx, (sampled_q_t_distributions, q_tm1_distribution) in enumerate(
                    zip(sampled_q_t_all, q_tm1_all)):
                # Average the N categorical distributions for objective idx.
                # Logsumexp over N of the normalised log-probs is
                # log(N * mean probability); log(N) cancels in the softmax.
                sampled_logits = tf.reshape(
                    sampled_q_t_distributions.logits,
                    [self._num_samples, batch_size, -1])  # [N, B, A]
                sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
                averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

                # Construct the expected distributional value for bootstrapping.
                q_t_distribution = networks.DiscreteValuedDistribution(
                    values=sampled_q_t_distributions.values,
                    logits=averaged_logits)

                # Compute critic distributional loss.
                critic_loss = losses.categorical(q_tm1_distribution,
                                                 r_t_all[:, idx],
                                                 discount * d_t,
                                                 q_t_distribution)
                critic_losses.append(tf.reduce_mean(critic_loss))

                # Compute Q-values of sampled actions and reshape to [N, B].
                sampled_q_ts.append(
                    tf.reshape(sampled_q_t_distributions.mean(),
                               (self._num_samples, -1)))

            critic_loss = tf.reduce_mean(critic_losses)
            sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
        else:
            # Reshape Q-value samples back to original batch dimensions and
            # average them to compute the TD-learning bootstrap target.
            sampled_q_t = tf.reshape(
                sampled_q_t_all,
                (self._num_samples, batch_size,
                 self._num_critic_heads))  # [N, B, C]
            q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

            # Flatten q_t, q_tm1 and the rewards to [B*C]; trfl.td_learning
            # needs rank-1 inputs. Row-major flattening puts (b, c) at b*C + c.
            q_t = tf.reshape(q_t, [-1])  # [B*C]
            q_tm1 = tf.reshape(q_tm1_all, [-1])  # [B*C]
            r_t_flat = tf.reshape(r_t_all, [-1])  # [B*C]

            # Repeat each batch element's discount once per critic head so
            # that entry b*C + c is d_t[b], matching the flattening above.
            # (tf.tile would instead give d_t[(b*C + c) % B], pairing
            # discounts with the wrong batch elements whenever C > 1.)
            d_t_flat = tf.repeat(d_t, self._num_critic_heads)  # [B*C]

            # Critic loss.
            critic_loss = trfl.td_learning(q_tm1, r_t_flat,
                                           discount * d_t_flat, q_t).loss
            critic_loss = tf.reduce_mean(critic_loss)

        # Build per-objective Q-values: one critic channel per objective with a
        # reward_fn, plus the output of each objective's objective_fn.
        sampled_q_idx = 0
        sampled_q_t_k = []
        for objective in self._objectives:
            if hasattr(objective, 'reward_fn'):
                sampled_q_t_k.append(
                    tf.stop_gradient(sampled_q_t[..., sampled_q_idx]))
                sampled_q_idx += 1
            if hasattr(objective, 'objective_fn'):
                sampled_q_t_k.append(
                    tf.stop_gradient(
                        objective.objective_fn(sampled_actions, sampled_q_t)))
        sampled_q_t_k = tf.stack(sampled_q_t_k, axis=-1)  # [N, B, K]

        # Compute MPO policy loss.
        policy_loss, policy_stats = self._policy_loss_module(
            online_action_distribution=online_action_distribution,
            target_action_distribution=target_action_distribution,
            actions=sampled_actions,
            q_values=sampled_q_t_k)

    # For clarity, explicitly define which variables are trained by which loss.
    critic_trainable_variables = (
        # In this agent, the critic loss trains the observation network.
        self._observation_network.trainable_variables +
        self._critic_network.trainable_variables)
    policy_trainable_variables = self._policy_network.trainable_variables
    # The following are the MPO dual variables, stored in the loss module.
    dual_trainable_variables = self._policy_loss_module.trainable_variables

    # Compute gradients.
    critic_gradients = tape.gradient(critic_loss, critic_trainable_variables)
    policy_gradients, dual_gradients = tape.gradient(
        policy_loss, (policy_trainable_variables, dual_trainable_variables))

    # Delete the tape manually because of the persistent=True flag.
    del tape

    # Maybe clip gradients.
    if self._clipping:
        policy_gradients = tuple(
            tf.clip_by_global_norm(policy_gradients, 40.)[0])
        critic_gradients = tuple(
            tf.clip_by_global_norm(critic_gradients, 40.)[0])

    # Apply gradients.
    self._critic_optimizer.apply(critic_gradients, critic_trainable_variables)
    self._policy_optimizer.apply(policy_gradients, policy_trainable_variables)
    self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

    # Losses to track.
    fetches = {
        'critic_loss': critic_loss,
        'policy_loss': policy_loss,
    }
    fetches.update(policy_stats)  # Log MPO stats.
    fetches.update(reward_stats)  # Log reward stats.

    return fetches
def _step(self, inputs: reverb.ReplaySample) -> types.NestedTensor:
    """Runs one learner update on a batch sampled from Reverb.

    The critic is trained by minimising the negative log-likelihood of the
    sampled bootstrap targets `r_t + discount * d_t * z`. The samples `z` are
    drawn from the target critic. When
    `self._policy_evaluation_config.evaluate_stochastic_policy` is set, the
    critic is evaluated at the N sampled target-policy actions; otherwise it is
    evaluated at the target policy's mean action. The policy and the dual
    variables are trained with `self._policy_loss_module`.

    Args:
      inputs: replay sample whose `data` exposes `observation`, `action`,
        `reward`, `discount` and `next_observation`.

    Returns:
      A dict with 'critic_loss', 'policy_loss' and the policy loss module's
      statistics.
    """
    # No extras are stored in Reverb for this agent; unpack the transition.
    batch = inputs.data
    o_tm1 = batch.observation
    a_tm1 = batch.action
    r_t = batch.reward
    d_t = batch.discount
    o_t = batch.next_observation

    # The extra discount has to match the environment discount's dtype.
    discount = tf.cast(self._discount, dtype=d_t.dtype)
    num_samples = self._num_samples

    with tf.GradientTape(persistent=True) as tape:
        # Online embedding of o_tm1. Policy and critic share the observation
        # network, which only the critic loss trains.
        o_tm1 = self._observation_network(o_tm1)
        # Target embedding of o_t. stop_gradient keeps gradients out of the
        # target observation network. Because the online policy is evaluated
        # at o_t, it also keeps the policy loss from training the observation
        # network.
        o_t = tf.stop_gradient(self._target_observation_network(o_t))

        online_action_distribution = self._policy_network(o_t)
        target_action_distribution = self._target_policy_network(o_t)

        # N target-policy actions per state, [N, B, ...], and embeddings tiled
        # to match (tiling after the embedding is cheaper than before).
        sampled_actions = target_action_distribution.sample(num_samples)
        tiled_o_t = tf2_utils.tile_tensor(o_t, num_samples)

        # Target return distributions over the flattened [N * B, ...] batch.
        sampled_q_t_distributions = self._target_critic_network(
            snt.merge_leading_dims(tiled_o_t, num_dims=2),
            snt.merge_leading_dims(sampled_actions, num_dims=2))

        # Online return distribution of (o_tm1, a_tm1); [B, ...].
        q_tm1_distribution = self._critic_network(o_tm1, a_tm1)

        # Pick the return distributions that feed the bootstrap target, and
        # the total number of return samples drawn per state.
        eval_config = self._policy_evaluation_config
        num_value_samples = eval_config.num_value_samples
        if eval_config.evaluate_stochastic_policy:
            bootstrap_distributions = sampled_q_t_distributions
            num_joint_samples = num_samples * num_value_samples
        else:
            bootstrap_distributions = self._target_critic_network(
                o_t, target_action_distribution.mean())
            num_joint_samples = num_value_samples

        # Return samples laid out as [num_joint_samples, B, 1].
        z_samples = tf.reshape(
            bootstrap_distributions.sample(num_value_samples),
            (num_joint_samples, -1, 1))

        # Sampled Bellman targets; reward and discount broadcast as [B, 1].
        target_q = tf.stop_gradient(
            r_t[..., tf.newaxis] + discount * d_t[..., tf.newaxis] * z_samples)

        # Sample-based cross-entropy: mean over samples, then over the batch.
        critic_loss = tf.reduce_mean(
            -tf.reduce_mean(q_tm1_distribution.log_prob(target_q), axis=0))

        # Expected Q-values of the sampled actions, [N, B].
        sampled_q_values = tf.reshape(
            sampled_q_t_distributions.mean(), (num_samples, -1))

        policy_loss, policy_stats = self._policy_loss_module(
            online_action_distribution=online_action_distribution,
            target_action_distribution=target_action_distribution,
            actions=sampled_actions,
            q_values=sampled_q_values)
        policy_loss = tf.reduce_mean(policy_loss)

    # The critic loss trains the observation network as well as the critic.
    critic_trainable_variables = (
        self._observation_network.trainable_variables
        + self._critic_network.trainable_variables)
    policy_trainable_variables = self._policy_network.trainable_variables
    # MPO dual variables live inside the loss module.
    dual_trainable_variables = self._policy_loss_module.trainable_variables

    critic_gradients = tape.gradient(critic_loss, critic_trainable_variables)
    policy_gradients, dual_gradients = tape.gradient(
        policy_loss, (policy_trainable_variables, dual_trainable_variables))
    # Persistent tapes hold resources until explicitly released.
    del tape

    if self._clipping:
        clipped_policy, _ = tf.clip_by_global_norm(policy_gradients, 40.)
        policy_gradients = tuple(clipped_policy)
        clipped_critic, _ = tf.clip_by_global_norm(critic_gradients, 40.)
        critic_gradients = tuple(clipped_critic)

    self._critic_optimizer.apply(critic_gradients, critic_trainable_variables)
    self._policy_optimizer.apply(policy_gradients, policy_trainable_variables)
    self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

    fetches = dict(critic_loss=critic_loss, policy_loss=policy_loss)
    fetches.update(policy_stats)
    return fetches
def _forward(self, inputs: Any) -> None:
    """Trainer forward pass.

    For every agent, runs the policy and critic networks over the whole
    sequence. It computes a TD(lambda) critic loss and a clipped-surrogate
    policy loss with entropy regularisation. Everything is recorded on a
    persistent gradient tape. The per-agent losses and the tape are stored as
    `self.policy_losses`, `self.critic_losses` and `self.tape`.

    Args:
        inputs (Any): input data from the data table (transitions).
    """
    # Keeps advantage normalisation finite when every advantage in the batch
    # is identical (zero variance). Without it the division is 0/0 and NaNs
    # propagate into the policy loss.
    advantage_norm_epsilon = 1e-8

    # Convert to sequence data
    data = tf2_utils.batch_to_sequence(inputs.data)

    # Unpack input data as follows:
    observations, actions, rewards, discounts, extras = (
        data.observations,
        data.actions,
        data.rewards,
        data.discounts,
        data.extras,
    )

    # transform observation using observation networks
    observations_trans = self._transform_observations(observations)

    # Get log_probs.
    log_probs = extras["log_probs"]

    # Store losses.
    policy_losses: Dict[str, Any] = {}
    critic_losses: Dict[str, Any] = {}

    with tf.GradientTape(persistent=True) as tape:
        for agent in self._agents:
            action, reward, discount, behaviour_log_prob = (
                actions[agent],
                rewards[agent],
                discounts[agent],
                log_probs[agent],
            )

            actor_observation = observations_trans[agent]
            critic_observation = self._get_critic_feed(observations_trans, agent)

            # Chop off final timestep for bootstrapping value
            reward = reward[:-1]
            discount = discount[:-1]

            # Get agent network (shared weights are keyed by the agent type
            # prefix before the first "_").
            agent_key = agent.split("_")[0] if self._shared_weights else agent
            policy_network = self._policy_networks[agent_key]
            critic_network = self._critic_networks[agent_key]

            # Merge the two leading dimensions (presumably [T, B]; the first
            # axis is the sequence, see values[-1] below) for the networks.
            dims = actor_observation.shape[:2]
            actor_observation = snt.merge_leading_dims(
                actor_observation, num_dims=2)
            critic_observation = snt.merge_leading_dims(
                critic_observation, num_dims=2)

            policy = policy_network(actor_observation)
            values = critic_network(critic_observation)

            # Reshape the outputs.
            policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy")
            values = tf.reshape(values, dims, name="value")

            # Values along the sequence T.
            bootstrap_value = values[-1]
            state_values = values[:-1]

            # Generalized Return Estimation
            td_loss, td_lambda_extra = trfl.td_lambda(
                state_values=state_values,
                rewards=reward,
                pcontinues=discount,
                bootstrap_value=bootstrap_value,
                lambda_=self._lambda_gae,
                name="CriticLoss",
            )

            # Do not use the loss provided by td_lambda as they sum the losses
            # over the sequence length rather than averaging them.
            critic_loss = self._baseline_cost * tf.reduce_mean(
                tf.square(td_lambda_extra.temporal_differences),
                name="CriticLoss")

            # Compute importance sampling weights: current policy / behavior
            # policy. The final timestep is dropped to align with the TD terms.
            log_rhos = policy.log_prob(action) - behaviour_log_prob
            importance_ratio = tf.exp(log_rhos)[:-1]
            clipped_importance_ratio = tf.clip_by_value(
                importance_ratio,
                1.0 - self._clipping_epsilon,
                1.0 + self._clipping_epsilon,
            )

            # Generalized Advantage Estimation, normalised over both leading
            # axes. The epsilon guards against a zero standard deviation.
            gae = tf.stop_gradient(td_lambda_extra.temporal_differences)
            mean, variance = tf.nn.moments(gae, axes=[0, 1], keepdims=True)
            normalized_gae = (gae - mean) / (
                tf.sqrt(variance) + advantage_norm_epsilon)

            # Clipped surrogate objective: pessimistic minimum of the
            # unclipped and clipped importance-weighted advantages.
            policy_gradient_loss = tf.reduce_mean(
                -tf.minimum(
                    tf.multiply(importance_ratio, normalized_gae),
                    tf.multiply(clipped_importance_ratio, normalized_gae),
                ),
                name="PolicyGradientLoss",
            )

            # Entropy regularization; falls back to zero for distributions
            # whose entropy() raises NotImplementedError.
            try:
                policy_entropy = tf.reduce_mean(policy.entropy())
            except NotImplementedError:
                policy_entropy = tf.convert_to_tensor(0.0)

            entropy_loss = -self._entropy_cost * policy_entropy

            # Combine weighted sum of actor & entropy regularization.
            policy_loss = policy_gradient_loss + entropy_loss

            policy_losses[agent] = policy_loss
            critic_losses[agent] = critic_loss

    self.policy_losses = policy_losses
    self.critic_losses = critic_losses
    self.tape = tape
def _step(self) -> types.Nest:
    """Runs one MPO learner update with a scalar (non-distributional) critic.

    Target networks are hard-synced on their update periods and a batch of
    transitions is read from the replay iterator. The critic is fitted with a
    one-step TD loss, bootstrapped from the mean target Q-value over
    `self._num_samples` target-policy actions. The policy and MPO dual
    variables are then updated with `self._policy_loss_module`.

    Returns:
      Dict of 'critic_loss', 'policy_loss' and the policy loss module's stats.
    """
    online_policy_vars = self._policy_network.variables
    target_policy_vars = self._target_policy_network.variables
    online_critic_vars = (*self._observation_network.variables,
                          *self._critic_network.variables)
    target_critic_vars = (*self._target_observation_network.variables,
                          *self._target_critic_network.variables)

    # Hard-copy online -> target weights on their respective periods.
    if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0:
        for online_var, target_var in zip(online_policy_vars,
                                          target_policy_vars):
            target_var.assign(online_var)
    if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0:
        for online_var, target_var in zip(online_critic_vars,
                                          target_critic_vars):
            target_var.assign(online_var)

    # Count learner steps for the periodic target updates above.
    self._num_steps.assign_add(1)

    # Next replay batch; no extras are inserted into Reverb for this agent.
    transitions: types.Transition = next(self._iterator).data

    # The extra discount must share the environment discount's dtype.
    discount = tf.cast(self._discount, dtype=transitions.discount.dtype)
    num_samples = self._num_samples

    with tf.GradientTape(persistent=True) as tape:
        # Online embedding of o_tm1. The observation network is shared by the
        # policy and critic, and only the critic loss trains it.
        o_tm1 = self._observation_network(transitions.observation)
        # Target embedding of o_t. stop_gradient blocks gradients into the
        # target observation network. Because the policy is evaluated at o_t,
        # it also keeps the policy loss from training the observation network.
        o_t = tf.stop_gradient(
            self._target_observation_network(transitions.next_observation))

        online_action_distribution = self._policy_network(o_t)
        target_action_distribution = self._target_policy_network(o_t)

        # N target-policy actions per state and matching tiled embeddings,
        # both shaped [N, B, ...].
        sampled_actions = target_action_distribution.sample(num_samples)
        tiled_o_t = tf2_utils.tile_tensor(o_t, num_samples)

        # Target critic on the flattened [N * B, ...] batch, restored to [N, B].
        sampled_q_t = tf.reshape(
            self._target_critic_network(
                snt.merge_leading_dims(tiled_o_t, num_dims=2),
                snt.merge_leading_dims(sampled_actions, num_dims=2)),
            (num_samples, -1))
        # Bootstrap value: mean over the N sampled actions, [B].
        q_t = tf.reduce_mean(sampled_q_t, axis=0)

        # Online critic at (o_tm1, a_tm1): [B, 1] squeezed to the rank-1 [B]
        # that trfl.td_learning expects.
        q_tm1 = tf.squeeze(
            self._critic_network(o_tm1, transitions.action), axis=-1)

        critic_loss = tf.reduce_mean(
            trfl.td_learning(q_tm1, transitions.reward,
                             discount * transitions.discount, q_t).loss)

        policy_loss, policy_stats = self._policy_loss_module(
            online_action_distribution=online_action_distribution,
            target_action_distribution=target_action_distribution,
            actions=sampled_actions,
            q_values=sampled_q_t)

    # The critic loss trains the observation network as well as the critic.
    critic_trainable_variables = (
        self._observation_network.trainable_variables
        + self._critic_network.trainable_variables)
    policy_trainable_variables = self._policy_network.trainable_variables
    # MPO dual variables live inside the loss module.
    dual_trainable_variables = self._policy_loss_module.trainable_variables

    critic_gradients = tape.gradient(critic_loss, critic_trainable_variables)
    policy_gradients, dual_gradients = tape.gradient(
        policy_loss, (policy_trainable_variables, dual_trainable_variables))
    # Persistent tapes hold resources until explicitly released.
    del tape

    if self._clipping:
        clipped_policy, _ = tf.clip_by_global_norm(policy_gradients, 40.)
        policy_gradients = tuple(clipped_policy)
        clipped_critic, _ = tf.clip_by_global_norm(critic_gradients, 40.)
        critic_gradients = tuple(clipped_critic)

    self._critic_optimizer.apply(critic_gradients, critic_trainable_variables)
    self._policy_optimizer.apply(policy_gradients, policy_trainable_variables)
    self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

    fetches = dict(critic_loss=critic_loss, policy_loss=policy_loss)
    fetches.update(policy_stats)
    return fetches
def combine_dim(tensor: tf.Tensor) -> "tuple[tf.Tensor, tf.TensorShape]":
    """Merges the first two dimensions of `tensor` into one.

    Args:
      tensor: the tensor whose two leading dimensions are merged with
        `snt.merge_leading_dims`.

    Returns:
      A pair `(merged, dims)`. `merged` is the result of
      `snt.merge_leading_dims(tensor, num_dims=2)`. `dims` is the static
      leading shape `tensor.shape[:2]`, which callers can use to restore the
      original leading dimensions.
    """
    # Capture the static leading shape before merging.
    dims = tensor.shape[:2]
    return snt.merge_leading_dims(tensor, num_dims=2), dims