  def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]:
    # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...].
    sample = tf2_utils.batch_to_sequence(sample)
    observations = sample.observation
    actions = sample.action
    rewards = sample.reward
    discounts = sample.discount

    dtype = rewards.dtype

    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(self._discount, dtype=discounts.dtype)

    # Loss cumulants across time. These cannot be python mutable objects.
    critic_loss = 0.
    policy_loss = 0.

    # Each transition induces a policy loss, which we then weight using
    # the `policy_loss_coef_t`; shape [B], see https://arxiv.org/abs/2006.15134.
    # `policy_loss_coef` is a scalar average of these coefficients across
    # the batch and sequence length dimensions.
    policy_loss_coef = 0.

    per_device_batch_size = actions.shape[1]

    # Initialize recurrent states.
    critic_state = self._critic_network.initial_state(per_device_batch_size)
    target_critic_state = critic_state
    policy_state = self._policy_network.initial_state(per_device_batch_size)
    target_policy_state = policy_state

    with tf.GradientTape(persistent=True) as tape:
      for t in range(1, self._sequence_length):
        o_tm1 = tree.map_structure(operator.itemgetter(t - 1), observations)
        a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions)
        r_t = tree.map_structure(operator.itemgetter(t - 1), rewards)
        d_t = tree.map_structure(operator.itemgetter(t - 1), discounts)
        o_t = tree.map_structure(operator.itemgetter(t), observations)

        if t != 1:
          # By only updating the target critic state here we are forcing
          # the target critic to ignore observations[0]. Otherwise, the
          # target_critic will be unrolled for one more timestep than critic.
          # The smaller the sequence length, the more problematic this is: if
          # you use RNN on sequences of length 2, you would expect the code to
          # never use recurrent connections. But if you don't skip updating the
          # target_critic_state on observation[0] here, it won't be the case.
          _, target_critic_state = self._target_critic_network(
              o_tm1, a_tm1, target_critic_state)

        # ========================= Critic learning ============================
        q_tm1, next_critic_state = self._critic_network(o_tm1, a_tm1,
                                                        critic_state)
        target_action_distribution, target_policy_state = (
            self._target_policy_network(o_t, target_policy_state))

        sampled_actions_t = target_action_distribution.sample(
            self._num_action_samples_td_learning)  # [N, B, ...]
        tiled_o_t = tf2_utils.tile_nested(
            o_t, self._num_action_samples_td_learning)
        tiled_target_critic_state = tf2_utils.tile_nested(
            target_critic_state, self._num_action_samples_td_learning)

        # Compute the target critic's Q-value of the sampled actions.
        sampled_q_t, _ = snt.BatchApply(self._target_critic_network)(
            tiled_o_t, sampled_actions_t, tiled_target_critic_state)

        # Compute average logits by first reshaping them to [N, B, A] and then
        # normalizing them across atoms.
        new_shape = [self._num_action_samples_td_learning, r_t.shape[0], -1]
        sampled_logits = tf.reshape(sampled_q_t.logits, new_shape)
        sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
        averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

        # Construct the expected distributional value for bootstrapping.
        q_t = networks.DiscreteValuedDistribution(
            values=sampled_q_t.values, logits=averaged_logits)

        critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t, q_t)
        critic_loss_t = tf.reduce_mean(critic_loss_t)

        # ========================= Actor learning =============================
        action_distribution_tm1, policy_state = self._policy_network(
            o_tm1, policy_state)
        q_tm1_mean = q_tm1.mean()

        # Compute the estimate of the value function based on
        # self._num_action_samples_policy_weight samples from the policy.
        tiled_o_tm1 = tf2_utils.tile_nested(
            o_tm1, self._num_action_samples_policy_weight)
        tiled_critic_state = tf2_utils.tile_nested(
            critic_state, self._num_action_samples_policy_weight)
        action_tm1 = action_distribution_tm1.sample(
            self._num_action_samples_policy_weight)
        tiled_z_tm1, _ = snt.BatchApply(self._critic_network)(
            tiled_o_tm1, action_tm1, tiled_critic_state)
        tiled_v_tm1 = tf.reshape(tiled_z_tm1.mean(),
                                 [self._num_action_samples_policy_weight, -1])

        # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the
        # final estimate of the value function.
        if self._baseline_reduce_function == 'mean':
          v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0)
        elif self._baseline_reduce_function == 'max':
          v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0)
        elif self._baseline_reduce_function == 'min':
          v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0)

        # Assert that action_distribution_tm1 is a batch of multivariate
        # distributions (in contrast to e.g. a [batch, action_size] collection
        # of 1d distributions).
        assert len(action_distribution_tm1.batch_shape) == 1
        policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1)

        advantage = q_tm1_mean - v_tm1_estimate
        if self._policy_improvement_modes == 'exp':
          policy_loss_coef_t = tf.math.minimum(
              tf.math.exp(advantage / self._beta), self._ratio_upper_bound)
        elif self._policy_improvement_modes == 'binary':
          policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype)
        elif self._policy_improvement_modes == 'all':
          # Regress against all actions (effectively pure BC).
          policy_loss_coef_t = 1.
        policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t)

        policy_loss_batch *= policy_loss_coef_t
        policy_loss_t = tf.reduce_mean(policy_loss_batch)

        critic_state = next_critic_state

        critic_loss += critic_loss_t
        policy_loss += policy_loss_t
        policy_loss_coef += tf.reduce_mean(policy_loss_coef_t)  # For logging.

    # Divide by sequence length to get mean losses.
    critic_loss /= tf.cast(self._sequence_length, dtype=dtype)
    policy_loss /= tf.cast(self._sequence_length, dtype=dtype)
    policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype)

    # Compute gradients.
    critic_gradients = tape.gradient(critic_loss,
                                     self._critic_network.trainable_variables)
    policy_gradients = tape.gradient(policy_loss,
                                     self._policy_network.trainable_variables)

    # Delete the tape manually because of the persistent=True flag.
    del tape

    # Sync gradients across GPUs or TPUs.
    ctx = tf.distribute.get_replica_context()
    critic_gradients = ctx.all_reduce('mean', critic_gradients)
    policy_gradients = ctx.all_reduce('mean', policy_gradients)

    # Maybe clip gradients.
    if self._clipping:
      policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
      critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

    # Apply gradients.
    self._critic_optimizer.apply(critic_gradients,
                                 self._critic_network.trainable_variables)
    self._policy_optimizer.apply(policy_gradients,
                                 self._policy_network.trainable_variables)

    source_variables = (
        self._critic_network.variables + self._policy_network.variables)
    target_variables = (
        self._target_critic_network.variables +
        self._target_policy_network.variables)

    # Make online -> target network update ops.
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
      for src, dest in zip(source_variables, target_variables):
        dest.assign(src)
    self._num_steps.assign_add(1)

    return {
        'critic_loss': critic_loss,
        'policy_loss': policy_loss,
        'policy_loss_coef': policy_loss_coef,
    }

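# --- Illustrative sketch (not part of the learner above) ---------------------
# A minimal, standalone version of the CRR policy-loss coefficient computed
# inside `_step`, written in plain TensorFlow 2. The argument names
# (`advantage`, `mode`, `beta`, `ratio_upper_bound`) merely mirror the learner
# attributes used above; treat this as a sketch for clarity, not a library API.
import tensorflow as tf


def crr_policy_loss_coef(advantage: tf.Tensor,
                         mode: str = 'exp',
                         beta: float = 1.0,
                         ratio_upper_bound: float = 20.0) -> tf.Tensor:
  """Per-transition weight applied to the behaviour-cloning (policy) loss."""
  if mode == 'exp':
    # Exponentiated advantage, clipped from above for stability.
    coef = tf.minimum(tf.exp(advantage / beta), ratio_upper_bound)
  elif mode == 'binary':
    # Only imitate actions that beat the value baseline.
    coef = tf.cast(advantage > 0, advantage.dtype)
  elif mode == 'all':
    # Pure behaviour cloning: every transition gets weight 1.
    coef = tf.ones_like(advantage)
  else:
    raise ValueError(f'Unknown policy improvement mode: {mode}')
  # No gradient should flow through the weighting term.
  return tf.stop_gradient(coef)
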
  def _step(self) -> Dict[str, tf.Tensor]:
    # Update target networks.
    online_variables = [
        *self._critic_network.variables,
        *self._policy_network.variables,
    ]
    if self._prior_network is not None:
      online_variables += [*self._prior_network.variables]
    online_variables = tuple(online_variables)

    target_variables = [
        *self._target_critic_network.variables,
        *self._target_policy_network.variables,
    ]
    if self._prior_network is not None:
      target_variables += [*self._target_prior_network.variables]
    target_variables = tuple(target_variables)

    # Make online -> target network update ops.
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
      for src, dest in zip(online_variables, target_variables):
        dest.assign(src)
    self._num_steps.assign_add(1)

    # Get data from replay and flip to `[T, B, ...]`.
    sample: reverb.ReplaySample = next(self._iterator)
    data = tf2_utils.batch_to_sequence(sample.data)
    observations, actions, rewards, discounts, extra = (data.observation,
                                                        data.action,
                                                        data.reward,
                                                        data.discount,
                                                        data.extras)

    online_target_pi_q = svg0_utils.OnlineTargetPiQ(
        online_pi=self._policy_network,
        online_q=self._critic_network,
        target_pi=self._target_policy_network,
        target_q=self._target_critic_network,
        num_samples=self._num_action_samples,
        online_prior=self._prior_network,
        target_prior=self._target_prior_network,
    )

    with tf.GradientTape(persistent=True) as tape:
      step_outputs = svg0_utils.static_rnn(
          core=online_target_pi_q,
          inputs=(observations, actions),
          unroll_length=rewards.shape[0])

      # Flip target samples to have shape [S, T+1, B, ...] where 'S' is the
      # number of action samples taken.
      target_pi_samples = tf2_utils.batch_to_sequence(
          step_outputs.target_samples)

      # Tile observations to have shape [S, T+1, B, ...].
      tiled_observations = tf2_utils.tile_nested(observations,
                                                 self._num_action_samples)

      # Finally compute target Q values on the new action samples.
      # Shape: [S, T+1, B, 1].
      target_q_target_pi_samples = snt.BatchApply(
          self._target_critic_network, 3)(tiled_observations,
                                          target_pi_samples)

      # Compute the value estimate by averaging over the action dimension.
      # Shape: [T+1, B, 1].
      target_v_target_pi = tf.reduce_mean(target_q_target_pi_samples, axis=0)

      # Split the target V's into the target for learning
      # `value_function_target` and the bootstrap value. Shape: [T, B].
      value_function_target = tf.squeeze(target_v_target_pi[:-1], axis=-1)
      # Shape: [B].
      bootstrap_value = tf.squeeze(target_v_target_pi[-1], axis=-1)

      # When learning with a prior, add entropy terms to value targets.
      if self._prior_network is not None:
        value_function_target -= self._distillation_cost * tf.stop_gradient(
            step_outputs.analytic_kl_to_target[:-1])
        bootstrap_value -= self._distillation_cost * tf.stop_gradient(
            step_outputs.analytic_kl_to_target[-1])

      # Get target log probs and behavior log probs from rollout.
      # Shape: [T+1, B].
      target_log_probs_behavior_actions = (
          step_outputs.target_log_probs_behavior_actions)
      behavior_log_probs = extra['log_prob']
      # Calculate importance weights. Shape: [T+1, B].
      rhos = tf.exp(target_log_probs_behavior_actions - behavior_log_probs)

      # Filter the importance weights to mask out episode restarts. Ignore the
      # last action and consider the step type of the next step for masking.
      # Shape: [T, B].
      episode_start_mask = tf2_utils.batch_to_sequence(
          sample.data.start_of_episode)[1:]
      rhos = svg0_utils.mask_out_restarting(rhos[:-1], episode_start_mask)

      # rhos = rhos[:-1]
      # Compute the log importance weights with a small value added for
      # stability.
      # Shape: [T, B].
      log_rhos = tf.math.log(rhos + _MIN_LOG_VAL)

      # Retrieve the target and online Q values and throw away the last action.
      # Shape: [T, B].
      target_q_values = tf.squeeze(step_outputs.target_q[:-1], -1)
      online_q_values = tf.squeeze(step_outputs.online_q[:-1], -1)

      # Flip online samples to have shape [S, T+1, B, ...] where 'S' is the
      # number of action samples taken.
      online_pi_samples = tf2_utils.batch_to_sequence(
          step_outputs.online_samples)
      target_q_online_pi_samples = snt.BatchApply(
          self._target_critic_network, 3)(tiled_observations,
                                          online_pi_samples)
      expected_q = tf.reduce_mean(
          tf.squeeze(target_q_online_pi_samples, -1), axis=0)

      # Flip online_log_probs to be of shape [S, T+1, B] and then compute
      # entropy by averaging over num samples. Final shape: [T+1, B].
      online_log_probs = tf2_utils.batch_to_sequence(
          step_outputs.online_log_probs)
      sample_based_entropy = tf.reduce_mean(-online_log_probs, axis=0)

      retrace_outputs = continuous_retrace_ops.retrace_from_importance_weights(
          log_rhos=log_rhos,
          discounts=self._discount * discounts[:-1],
          rewards=rewards[:-1],
          q_values=target_q_values,
          values=value_function_target,
          bootstrap_value=bootstrap_value,
          lambda_=self._lambda,
      )

      # Critic loss. Shape: [T, B].
      critic_loss = 0.5 * tf.math.squared_difference(
          tf.stop_gradient(retrace_outputs.qs), online_q_values)

      # Policy loss: SVG0 with sample-based entropy. Shape: [T, B].
      policy_loss = -(expected_q +
                      self._entropy_regularizer_cost * sample_based_entropy)
      policy_loss = policy_loss[:-1]

      if self._prior_network is not None:
        # When training the prior, also add the per-timestep KL cost.
        policy_loss += (self._distillation_cost *
                        step_outputs.analytic_kl_to_target[:-1])

      # Ensure episode restarts are masked out when computing the losses.
      critic_loss = svg0_utils.mask_out_restarting(critic_loss,
                                                   episode_start_mask)
      critic_loss = tf.reduce_mean(critic_loss)

      policy_loss = svg0_utils.mask_out_restarting(policy_loss,
                                                   episode_start_mask)
      policy_loss = tf.reduce_mean(policy_loss)

      if self._prior_network is not None:
        prior_loss = step_outputs.analytic_kl_divergence[:-1]
        prior_loss = svg0_utils.mask_out_restarting(prior_loss,
                                                    episode_start_mask)
        prior_loss = tf.reduce_mean(prior_loss)

    # Get trainable variables.
    policy_variables = self._policy_network.trainable_variables
    critic_variables = self._critic_network.trainable_variables

    # Compute gradients.
    policy_gradients = tape.gradient(policy_loss, policy_variables)
    critic_gradients = tape.gradient(critic_loss, critic_variables)
    if self._prior_network is not None:
      prior_variables = self._prior_network.trainable_variables
      prior_gradients = tape.gradient(prior_loss, prior_variables)

    # Delete the tape manually because of the persistent=True flag.
    del tape

    # Apply gradients.
    self._policy_optimizer.apply(policy_gradients, policy_variables)
    self._critic_optimizer.apply(critic_gradients, critic_variables)

    losses = {
        'critic_loss': critic_loss,
        'policy_loss': policy_loss,
    }
    if self._prior_network is not None:
      self._prior_optimizer.apply(prior_gradients, prior_variables)
      losses['prior_loss'] = prior_loss

    # Losses to track.
    return losses

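# --- Illustrative sketch (not part of the learner above) ---------------------
# How the stabilized log importance weights above are formed: exponentiate the
# log-probability difference, then re-take the log with a small constant added
# so that near-zero ratios do not produce -inf. `min_log_val` here stands in
# for the module-level `_MIN_LOG_VAL` constant, whose value is defined
# elsewhere; the names below are illustrative only.
import tensorflow as tf


def stabilized_log_rhos(target_log_probs: tf.Tensor,
                        behavior_log_probs: tf.Tensor,
                        min_log_val: float = 1e-8) -> tf.Tensor:
  """Returns log(pi_target(a|s) / pi_behavior(a|s)) with a stability floor."""
  rhos = tf.exp(target_log_probs - behavior_log_probs)
  return tf.math.log(rhos + min_log_val)
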
  def _step(self) -> Dict[str, tf.Tensor]:
    # Draw a batch of data from replay.
    sample: reverb.ReplaySample = next(self._iterator)

    data = tf2_utils.batch_to_sequence(sample.data)
    observations, actions, rewards, discounts, extra = (data.observation,
                                                        data.action,
                                                        data.reward,
                                                        data.discount,
                                                        data.extras)
    unused_sequence_length, batch_size = actions.shape

    # Get initial state for the LSTM, either from replay or simply use zeros.
    if self._store_lstm_state:
      core_state = tree.map_structure(lambda x: x[0], extra['core_state'])
    else:
      core_state = self._network.initial_state(batch_size)
    target_core_state = tree.map_structure(tf.identity, core_state)

    # Before training, optionally unroll the LSTM for a fixed warmup period.
    burn_in_obs = tree.map_structure(lambda x: x[:self._burn_in_length],
                                     observations)
    _, core_state = self._burn_in(burn_in_obs, core_state)
    _, target_core_state = self._burn_in(burn_in_obs, target_core_state)

    # Don't train on the warmup period.
    observations, actions, rewards, discounts, extra = tree.map_structure(
        lambda x: x[self._burn_in_length:],
        (observations, actions, rewards, discounts, extra))

    with tf.GradientTape() as tape:
      # Unroll the online and target Q-networks on the sequences.
      q_values, _ = self._network.unroll(observations, core_state,
                                         self._sequence_length)
      target_q_values, _ = self._target_network.unroll(observations,
                                                       target_core_state,
                                                       self._sequence_length)

      # Compute the target policy distribution (greedy).
      greedy_actions = tf.argmax(q_values, output_type=tf.int32, axis=-1)
      target_policy_probs = tf.one_hot(
          greedy_actions, depth=self._num_actions, dtype=q_values.dtype)

      # Compute the transformed n-step loss.
      rewards = tree.map_structure(lambda x: x[:-1], rewards)
      discounts = tree.map_structure(lambda x: x[:-1], discounts)
      loss, extra = losses.transformed_n_step_loss(
          qs=q_values,
          targnet_qs=target_q_values,
          actions=actions,
          rewards=rewards,
          pcontinues=discounts * self._discount,
          target_policy_probs=target_policy_probs,
          bootstrap_n=self._n_step,
      )

      # Calculate importance weights and use them to scale the loss.
      sample_info = sample.info
      keys, probs = sample_info.key, sample_info.probability
      probs = tf2_utils.batch_to_sequence(probs)
      importance_weights = 1. / (self._max_replay_size * probs)  # [T, B]
      importance_weights **= self._importance_sampling_exponent
      importance_weights /= tf.reduce_max(importance_weights)
      loss *= tf.cast(importance_weights, tf.float32)  # [T, B]
      loss = tf.reduce_mean(loss)  # []

    # Apply gradients via optimizer.
    gradients = tape.gradient(loss, self._network.trainable_variables)
    # Clip and apply gradients.
    if self._clip_grad_norm is not None:
      gradients, _ = tf.clip_by_global_norm(gradients, self._clip_grad_norm)
    self._optimizer.apply(gradients, self._network.trainable_variables)

    # Periodically update the target network.
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
      for src, dest in zip(self._network.variables,
                           self._target_network.variables):
        dest.assign(src)
    self._num_steps.assign_add(1)

    if self._reverb_client:
      # Compute updated priorities.
      priorities = compute_priority(extra.errors, self._max_priority_weight)
      # Compute priorities and add an op to update them on the reverb side.
      self._reverb_client.update_priorities(
          table=adders.DEFAULT_PRIORITY_TABLE,
          keys=keys[:, 0],
          priorities=tf.cast(priorities, tf.float64))

    return {'loss': loss}

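# --- Illustrative sketch (not part of the learner above) ---------------------
# The prioritized-replay correction used above, written as a standalone
# function: the weight of each sample is 1 / (N * p_i), raised to the
# importance-sampling exponent and normalized by the maximum weight, as in the
# learner's `_step`. Argument names are illustrative only.
import tensorflow as tf


def prioritized_importance_weights(sample_probs: tf.Tensor,
                                   replay_capacity: int,
                                   exponent: float) -> tf.Tensor:
  """Computes normalized importance weights for prioritized replay samples."""
  weights = 1.0 / (replay_capacity * sample_probs)  # [T, B]
  weights = weights ** exponent
  # Normalize so the largest weight is 1, which only ever scales the loss down.
  return weights / tf.reduce_max(weights)
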
    def _forward(self, inputs: Any) -> None:
        """Trainer forward pass.

        Args:
            inputs (Any): input data from the data table (transitions)
        """
        # TODO: Update this forward function to work like MAD4PG.
        data = inputs.data

        # Note (dries): The unused variable is start_of_episodes.
        observations, actions, rewards, discounts, _, extras = (
            data.observations,
            data.actions,
            data.rewards,
            data.discounts,
            data.start_of_episode,
            data.extras,
        )

        # Get initial state for the LSTM from replay and
        # extract the first state in the sequence.
        core_state = tree.map_structure(
            lambda s: s[:, 0, :], extras["core_states"]
        )
        target_core_state = tree.map_structure(tf.identity, core_state)

        # TODO (dries): Take out all the data points that do not need to be
        # processed here at the start, so they do not have to be processed
        # later on. This saves processing time.

        self.policy_losses: Dict[str, tf.Tensor] = {}
        self.critic_losses: Dict[str, tf.Tensor] = {}

        # Do forward passes through the networks and calculate the losses.
        with tf.GradientTape(persistent=True) as tape:
            # Note (dries): We are assuming that only the policy network
            # is recurrent and not the observation network.
            obs_trans, target_obs_trans = self._transform_observations(
                observations
            )

            target_actions = self._target_policy_actions(
                target_obs_trans, target_core_state
            )

            for agent in self._agents:
                agent_key = self.agent_net_keys[agent]

                # Get critic feed.
                (
                    obs_trans_feed,
                    target_obs_trans_feed,
                    action_feed,
                    target_actions_feed,
                ) = self._get_critic_feed(
                    obs_trans=obs_trans,
                    target_obs_trans=target_obs_trans,
                    actions=actions,
                    target_actions=target_actions,
                    extras=extras,
                    agent=agent,
                )

                # Critic learning.
                # Remove the last sequence step for the normal network.
                obs_comb, dims = train_utils.combine_dim(obs_trans_feed)
                act_comb, _ = train_utils.combine_dim(action_feed)
                q_values = self._critic_networks[agent_key](obs_comb, act_comb)
                q_values.set_dimensions(dims)

                # Remove the first sequence step for the target.
                obs_comb, _ = train_utils.combine_dim(target_obs_trans_feed)
                act_comb, _ = train_utils.combine_dim(target_actions_feed)
                target_q_values = self._target_critic_networks[agent_key](
                    obs_comb, act_comb
                )
                target_q_values.set_dimensions(dims)

                # Cast the additional discount to match
                # the environment discount dtype.
                agent_discount = discounts[agent]
                discount = tf.cast(self._discount, dtype=agent_discount.dtype)

                # Critic loss.
                critic_loss = recurrent_n_step_critic_loss(
                    q_values,
                    target_q_values,
                    rewards[agent],
                    discount * agent_discount,
                    bootstrap_n=self._bootstrap_n,
                    loss_fn=losses.categorical,
                )
                self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0)

                # Actor learning.
                obs_agent_feed = target_obs_trans[agent]
                # TODO (dries): Why is there an extra tuple?
                agent_core_state = core_state[agent][0]
                transposed_obs = tf2_utils.batch_to_sequence(obs_agent_feed)
                outputs, updated_states = snt.static_unroll(
                    self._policy_networks[agent_key],
                    transposed_obs,
                    agent_core_state,
                )

                dpg_actions = tf2_utils.batch_to_sequence(outputs)

                # Note (dries): This is done so that losses.dpg can verify,
                # using the gradient tape, that there is a gradient
                # relationship between dpg_q_values and dpg_actions_comb.
                dpg_actions_comb, dim = train_utils.combine_dim(dpg_actions)

                # Note (dries): This seemingly useless line is important!
                # Don't remove it. See the note above.
                dpg_actions = train_utils.extract_dim(dpg_actions_comb, dim)

                # Get dpg actions.
                dpg_actions_feed = self._get_dpg_feed(
                    target_actions, dpg_actions, agent
                )

                # Get dpg Q values.
                obs_comb, _ = train_utils.combine_dim(target_obs_trans_feed)
                act_comb, _ = train_utils.combine_dim(dpg_actions_feed)
                dpg_z_values = self._critic_networks[agent_key](
                    obs_comb, act_comb
                )
                dpg_q_values = dpg_z_values.mean()

                # Actor loss. If clipping is true, use dqda clipping and clip
                # the norm.
                dqda_clipping = (
                    1.0 if self._max_gradient_norm is not None else None
                )
                clip_norm = True if self._max_gradient_norm is not None else False

                policy_loss = losses.dpg(
                    dpg_q_values,
                    dpg_actions_comb,
                    tape=tape,
                    dqda_clipping=dqda_clipping,
                    clip_norm=clip_norm,
                )
                self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0)

        self.tape = tape

  def _step(self) -> Dict[str, tf.Tensor]:
    """Does an SGD step on a batch of sequences."""

    # Retrieve a batch of data from replay.
    inputs: reverb.ReplaySample = next(self._iterator)
    data = tf2_utils.batch_to_sequence(inputs.data)
    observations, actions, rewards, discounts, extra = (data.observation,
                                                        data.action,
                                                        data.reward,
                                                        data.discount,
                                                        data.extras)
    core_state = tree.map_structure(lambda s: s[0], extra['core_state'])

    actions = actions[:-1]  # [T-1]
    rewards = rewards[:-1]  # [T-1]
    discounts = discounts[:-1]  # [T-1]

    with tf.GradientTape() as tape:
      # Unroll current policy over observations.
      (logits, values), _ = snt.static_unroll(self._network, observations,
                                              core_state)

      # Compute importance sampling weights: current policy / behavior policy.
      behaviour_logits = extra['logits']
      pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1])
      pi_target = tfd.Categorical(logits=logits[:-1])
      log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(actions)

      # Optionally clip rewards.
      rewards = tf.clip_by_value(rewards,
                                 tf.cast(-self._max_abs_reward, rewards.dtype),
                                 tf.cast(self._max_abs_reward, rewards.dtype))

      # Critic loss.
      vtrace_returns = trfl.vtrace_from_importance_weights(
          log_rhos=tf.cast(log_rhos, tf.float32),
          discounts=tf.cast(self._discount * discounts, tf.float32),
          rewards=tf.cast(rewards, tf.float32),
          values=tf.cast(values[:-1], tf.float32),
          bootstrap_value=values[-1],
      )
      critic_loss = tf.square(vtrace_returns.vs - values[:-1])

      # Policy-gradient loss.
      policy_gradient_loss = trfl.policy_gradient(
          policies=pi_target,
          actions=actions,
          action_values=vtrace_returns.pg_advantages,
      )

      # Entropy regulariser.
      entropy_loss = trfl.policy_entropy_loss(pi_target).loss

      # Combine weighted sum of actor & critic losses.
      loss = tf.reduce_mean(policy_gradient_loss +
                            self._baseline_cost * critic_loss +
                            self._entropy_cost * entropy_loss)

    # Compute gradients and optionally apply clipping.
    gradients = tape.gradient(loss, self._network.trainable_variables)
    gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm)
    self._optimizer.apply(gradients, self._network.trainable_variables)

    metrics = {
        'loss': loss,
        'critic_loss': tf.reduce_mean(critic_loss),
        'entropy_loss': tf.reduce_mean(entropy_loss),
        'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss),
    }

    return metrics

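# --- Illustrative sketch (not part of the learner above) ---------------------
# The V-trace importance log-ratios above, in isolation: both the behaviour
# and the target policies are categorical distributions built from logits, and
# the log-ratio is the difference of log-probabilities of the actions actually
# taken. This standalone version assumes [T, B, A] logits and [T, B] integer
# actions; the function name is illustrative only.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def vtrace_log_rhos(target_logits: tf.Tensor,
                    behaviour_logits: tf.Tensor,
                    actions: tf.Tensor) -> tf.Tensor:
  """Returns log(pi_target(a|s) / pi_behaviour(a|s)) per timestep."""
  pi_target = tfd.Categorical(logits=target_logits)
  pi_behaviour = tfd.Categorical(logits=behaviour_logits)
  return pi_target.log_prob(actions) - pi_behaviour.log_prob(actions)
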
    def _forward(self, inputs: Any) -> None:
        data = tree.map_structure(
            lambda v: tf.expand_dims(v, axis=0) if len(v.shape) <= 1 else v,
            inputs.data,
        )
        data = tf2_utils.batch_to_sequence(data)

        observations, actions, rewards, discounts, _, extra = data

        core_state = tree.map_structure(
            lambda s: s[:, 0, :], inputs.data.extras["core_states"]
        )
        core_message = tree.map_structure(
            lambda s: s[:, 0, :], inputs.data.extras["core_messages"]
        )
        T = actions[self._agents[0]].shape[0]

        # Use the fact that the end of an episode always has a reward to
        # find episode lengths. This is used to mask the loss.
        ep_end = tf.argmax(tf.math.abs(rewards[self._agents[0]]), axis=0)

        with tf.GradientTape(persistent=True) as tape:
            q_network_losses: Dict[str, NestedArray] = {
                agent: {"q_value_loss": tf.zeros(())} for agent in self._agents
            }

            state = {agent: core_state[agent][0] for agent in self._agents}
            target_state = {
                agent: core_state[agent][0] for agent in self._agents
            }

            message = {agent: core_message[agent][0] for agent in self._agents}
            target_message = {
                agent: core_message[agent][0] for agent in self._agents
            }

            # _target_q_networks must be 1 step ahead.
            target_channel = self._communication_module.process_messages(
                target_message
            )
            for agent in self._agents:
                agent_key = self.agent_net_keys[agent]
                (q_targ, m), s = self._target_q_networks[agent_key](
                    observations[agent].observation[0],
                    target_state[agent],
                    target_channel[agent],
                )
                target_state[agent] = s
                target_message[agent] = m

            for t in range(1, T, 1):
                channel = self._communication_module.process_messages(message)
                target_channel = self._communication_module.process_messages(
                    target_message
                )

                for agent in self._agents:
                    agent_key = self.agent_net_keys[agent]

                    # Cast the additional discount
                    # to match the environment discount dtype.
                    discount = tf.cast(
                        self._discount, dtype=discounts[agent][0].dtype
                    )

                    (q_targ, m), s = self._target_q_networks[agent_key](
                        observations[agent].observation[t],
                        target_state[agent],
                        target_channel[agent],
                    )

                    target_state[agent] = s
                    target_message[agent] = tf.math.multiply(
                        m, observations[agent].observation[t][:, :1]
                    )

                    (q, m), s = self._q_networks[agent_key](
                        observations[agent].observation[t - 1],
                        state[agent],
                        channel[agent],
                    )

                    state[agent] = s
                    message[agent] = tf.math.multiply(
                        m, observations[agent].observation[t - 1][:, :1]
                    )

                    # Mask target.
                    q_targ = tf.concat(
                        [
                            [q_targ[i]]
                            if t <= ep_end[i]
                            else [tf.zeros_like(q_targ[i])]
                            for i in range(q_targ.shape[0])
                        ],
                        axis=0,
                    )

                    loss, _ = trfl.qlearning(
                        q,
                        actions[agent][t - 1],
                        rewards[agent][t - 1],
                        discount * discounts[agent][t],
                        q_targ,
                    )

                    # Index loss (mask ended episodes).
                    if not tf.reduce_any(t - 1 <= ep_end):
                        continue

                    loss = tf.reduce_mean(loss[t - 1 <= ep_end])
                    # loss = tf.reduce_mean(loss)
                    q_network_losses[agent]["q_value_loss"] += loss

        self._q_network_losses = q_network_losses
        self.tape = tape

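# --- Illustrative sketch (not part of the trainer above) ----------------------
# The episode-length trick used above, shown standalone: because (per the
# comment above) the end of an episode always carries the reward,
# argmax(|rewards|, axis=0) gives the index of the final step for each batch
# element, and later timesteps can be masked out of the loss. Shapes assumed:
# rewards [T, B]; the function name is illustrative only.
import tensorflow as tf


def episode_mask_from_rewards(rewards: tf.Tensor) -> tf.Tensor:
    """Returns a [T, B] mask that is 1 up to (and including) the episode end."""
    ep_end = tf.argmax(tf.math.abs(rewards), axis=0, output_type=tf.int32)  # [B]
    timesteps = tf.range(tf.shape(rewards)[0])[:, None]  # [T, 1]
    return tf.cast(timesteps <= ep_end[None, :], rewards.dtype)
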
    def _forward(self, inputs: Any) -> None:
        """Trainer forward pass.

        Args:
            inputs (Any): input data from the data table (transitions)
        """
        # Convert to sequence data.
        data = tf2_utils.batch_to_sequence(inputs.data)

        # Unpack input data as follows:
        observations, actions, rewards, discounts, extras = (
            data.observations,
            data.actions,
            data.rewards,
            data.discounts,
            data.extras,
        )

        # Transform observations using the observation networks.
        observations_trans = self._transform_observations(observations)

        # Get log_probs.
        log_probs = extras["log_probs"]

        # Store losses.
        policy_losses: Dict[str, Any] = {}
        critic_losses: Dict[str, Any] = {}

        with tf.GradientTape(persistent=True) as tape:
            for agent in self._agents:
                action, reward, discount, behaviour_log_prob = (
                    actions[agent],
                    rewards[agent],
                    discounts[agent],
                    log_probs[agent],
                )

                actor_observation = observations_trans[agent]
                critic_observation = self._get_critic_feed(
                    observations_trans, agent
                )

                # Chop off the final timestep for bootstrapping the value.
                reward = reward[:-1]
                discount = discount[:-1]

                # Get agent networks.
                agent_key = agent.split("_")[0] if self._shared_weights else agent
                policy_network = self._policy_networks[agent_key]
                critic_network = self._critic_networks[agent_key]

                # Reshape inputs.
                dims = actor_observation.shape[:2]
                actor_observation = snt.merge_leading_dims(
                    actor_observation, num_dims=2
                )
                critic_observation = snt.merge_leading_dims(
                    critic_observation, num_dims=2
                )

                policy = policy_network(actor_observation)
                values = critic_network(critic_observation)

                # Reshape the outputs.
                policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy")
                values = tf.reshape(values, dims, name="value")

                # Values along the sequence T.
                bootstrap_value = values[-1]
                state_values = values[:-1]

                # Generalized return estimation.
                td_loss, td_lambda_extra = trfl.td_lambda(
                    state_values=state_values,
                    rewards=reward,
                    pcontinues=discount,
                    bootstrap_value=bootstrap_value,
                    lambda_=self._lambda_gae,
                    name="CriticLoss",
                )

                # Do not use the loss provided by td_lambda, as it sums the
                # losses over the sequence length rather than averaging them.
                critic_loss = self._baseline_cost * tf.reduce_mean(
                    tf.square(td_lambda_extra.temporal_differences),
                    name="CriticLoss",
                )

                # Compute importance sampling weights:
                # current policy / behavior policy.
                log_rhos = policy.log_prob(action) - behaviour_log_prob
                importance_ratio = tf.exp(log_rhos)[:-1]
                clipped_importance_ratio = tf.clip_by_value(
                    importance_ratio,
                    1.0 - self._clipping_epsilon,
                    1.0 + self._clipping_epsilon,
                )

                # Generalized advantage estimation.
                gae = tf.stop_gradient(td_lambda_extra.temporal_differences)
                mean, variance = tf.nn.moments(gae, axes=[0, 1], keepdims=True)
                normalized_gae = (gae - mean) / tf.sqrt(variance)

                policy_gradient_loss = tf.reduce_mean(
                    -tf.minimum(
                        tf.multiply(importance_ratio, normalized_gae),
                        tf.multiply(clipped_importance_ratio, normalized_gae),
                    ),
                    name="PolicyGradientLoss",
                )

                # Entropy regularization. Only implemented for categorical
                # distributions.
                try:
                    policy_entropy = tf.reduce_mean(policy.entropy())
                except NotImplementedError:
                    policy_entropy = tf.convert_to_tensor(0.0)

                entropy_loss = -self._entropy_cost * policy_entropy

                # Combine the weighted sum of the actor loss and the entropy
                # regularization.
                policy_loss = policy_gradient_loss + entropy_loss

                policy_losses[agent] = policy_loss
                critic_losses[agent] = critic_loss

        self.policy_losses = policy_losses
        self.critic_losses = critic_losses
        self.tape = tape

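# --- Illustrative sketch (not part of the trainer above) ----------------------
# The clipped surrogate objective computed per agent above, as a standalone
# function over a [T, B] importance ratio and a [T, B] (already normalized)
# advantage estimate. Argument names are illustrative; `clipping_epsilon`
# mirrors the trainer's `self._clipping_epsilon`.
import tensorflow as tf


def ppo_clipped_surrogate_loss(importance_ratio: tf.Tensor,
                               advantages: tf.Tensor,
                               clipping_epsilon: float = 0.2) -> tf.Tensor:
    """Returns the scalar negative clipped-surrogate objective."""
    clipped_ratio = tf.clip_by_value(
        importance_ratio, 1.0 - clipping_epsilon, 1.0 + clipping_epsilon
    )
    # Take the elementwise minimum of the unclipped and clipped objectives,
    # then negate and average so the result can be minimized directly.
    objective = tf.minimum(
        importance_ratio * advantages, clipped_ratio * advantages
    )
    return -tf.reduce_mean(objective)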