def unroll(
    self,
    inputs: types.NestedTensor,
    state: base.State,
    sequence_length: int,
) -> Tuple[types.NestedTensor, base.State]:
    return snt.static_unroll(self, inputs, state, sequence_length)
def unroll(
    self,
    inputs: observation_action_reward.OAR,
    state: snt.LSTMState,
    sequence_length: int,
) -> Tuple[QValues, snt.LSTMState]:
    """Efficient unroll that applies embeddings, MLP, & convnet in one pass."""
    embeddings = snt.BatchApply(self._embed)(inputs)  # [T, B, D+A+1]
    embeddings, new_state = snt.static_unroll(self._core, embeddings, state,
                                              sequence_length)
    action_values = snt.BatchApply(self._head)(embeddings)
    return action_values, new_state
def unroll(self, inputs, state, sequence_length):
    return snt.static_unroll(self._net, inputs, state, sequence_length)
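# Illustrative usage sketch of the unroll interface above (not from the source).
# It assumes a recurrent core exposing `initial_state`, as Sonnet recurrent
# cores do, and time-major inputs shaped [T, B, ...]; the names `network`,
# `sequence_inputs`, `batch_size` and `sequence_length` are hypothetical, and
# the same imports as the surrounding snippets are assumed.
initial_state = network.initial_state(batch_size)
outputs, final_state = network.unroll(sequence_inputs, initial_state,
                                      sequence_length)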
def _forward(self, inputs: Any) -> None:
    """Trainer forward pass.

    Args:
        inputs (Any): input data from the data table (transitions).
    """
    # TODO: Update this forward function to work like MAD4PG.
    data = inputs.data

    # Note (dries): The unused variable is start_of_episodes.
    observations, actions, rewards, discounts, _, extras = (
        data.observations,
        data.actions,
        data.rewards,
        data.discounts,
        data.start_of_episode,
        data.extras,
    )

    # Get initial state for the LSTM from replay and
    # extract the first state in the sequence.
    core_state = tree.map_structure(lambda s: s[:, 0, :], extras["core_states"])
    target_core_state = tree.map_structure(tf.identity, core_state)

    # TODO (dries): Take out all the data points that do not need to be
    # processed here at the start, so that it does not have to be done
    # later on; this saves processing time.

    self.policy_losses: Dict[str, tf.Tensor] = {}
    self.critic_losses: Dict[str, tf.Tensor] = {}

    # Do forward passes through the networks and calculate the losses.
    with tf.GradientTape(persistent=True) as tape:
        # Note (dries): We are assuming that only the policy network
        # is recurrent and not the observation network.
        obs_trans, target_obs_trans = self._transform_observations(observations)

        target_actions = self._target_policy_actions(target_obs_trans,
                                                     target_core_state)

        for agent in self._agents:
            agent_key = self.agent_net_keys[agent]

            # Get critic feed.
            (
                obs_trans_feed,
                target_obs_trans_feed,
                action_feed,
                target_actions_feed,
            ) = self._get_critic_feed(
                obs_trans=obs_trans,
                target_obs_trans=target_obs_trans,
                actions=actions,
                target_actions=target_actions,
                extras=extras,
                agent=agent,
            )

            # Critic learning.
            # Remove the last sequence step for the normal network.
            obs_comb, dims = train_utils.combine_dim(obs_trans_feed)
            act_comb, _ = train_utils.combine_dim(action_feed)
            q_values = self._critic_networks[agent_key](obs_comb, act_comb)
            q_values.set_dimensions(dims)

            # Remove the first sequence step for the target.
            obs_comb, _ = train_utils.combine_dim(target_obs_trans_feed)
            act_comb, _ = train_utils.combine_dim(target_actions_feed)
            target_q_values = self._target_critic_networks[agent_key](
                obs_comb, act_comb)
            target_q_values.set_dimensions(dims)

            # Cast the additional discount to match
            # the environment discount dtype.
            agent_discount = discounts[agent]
            discount = tf.cast(self._discount, dtype=agent_discount.dtype)

            # Critic loss.
            critic_loss = recurrent_n_step_critic_loss(
                q_values,
                target_q_values,
                rewards[agent],
                discount * agent_discount,
                bootstrap_n=self._bootstrap_n,
                loss_fn=losses.categorical,
            )
            self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0)

            # Actor learning.
            obs_agent_feed = target_obs_trans[agent]
            # TODO (dries): Why is there an extra tuple?
            agent_core_state = core_state[agent][0]
            transposed_obs = tf2_utils.batch_to_sequence(obs_agent_feed)
            outputs, updated_states = snt.static_unroll(
                self._policy_networks[agent_key],
                transposed_obs,
                agent_core_state,
            )

            dpg_actions = tf2_utils.batch_to_sequence(outputs)

            # Note (dries): This is done so that losses.dpg can verify,
            # using the gradient tape, that there is a gradient relationship
            # between dpg_q_values and dpg_actions_comb.
            dpg_actions_comb, dim = train_utils.combine_dim(dpg_actions)

            # Note (dries): This seemingly useless line is important!
            # Don't remove it. See above note.
            dpg_actions = train_utils.extract_dim(dpg_actions_comb, dim)

            # Get dpg actions.
            dpg_actions_feed = self._get_dpg_feed(target_actions, dpg_actions,
                                                  agent)

            # Get dpg Q values.
            obs_comb, _ = train_utils.combine_dim(target_obs_trans_feed)
            act_comb, _ = train_utils.combine_dim(dpg_actions_feed)
            dpg_z_values = self._critic_networks[agent_key](obs_comb, act_comb)
            dpg_q_values = dpg_z_values.mean()

            # Actor loss. If clipping is true, use dqda clipping and clip the norm.
            dqda_clipping = 1.0 if self._max_gradient_norm is not None else None
            clip_norm = True if self._max_gradient_norm is not None else False

            policy_loss = losses.dpg(
                dpg_q_values,
                dpg_actions_comb,
                tape=tape,
                dqda_clipping=dqda_clipping,
                clip_norm=clip_norm,
            )
            self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0)

    self.tape = tape
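# Minimal sketch of the shape helpers assumed by the trainer above
# (train_utils.combine_dim / train_utils.extract_dim); the real train_utils
# implementations may differ in signature and edge-case handling. The same
# imports as the surrounding snippets (tensorflow as tf, typing.Tuple) are
# assumed.
def combine_dim(tensor: tf.Tensor) -> Tuple[tf.Tensor, tf.TensorShape]:
    # Merge the leading [T, B] dims into a single [T * B] batch dim so a
    # network can be applied in one pass, and return the original dims so the
    # result can later be reshaped back.
    dims = tensor.shape[:2]
    return tf.reshape(tensor, [-1] + tensor.shape[2:].as_list()), dims


def extract_dim(tensor: tf.Tensor, dims: tf.TensorShape) -> tf.Tensor:
    # Inverse of combine_dim: split the merged batch dim back into [T, B].
    return tf.reshape(tensor, dims.as_list() + tensor.shape[1:].as_list())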
def _step(self) -> Dict[str, tf.Tensor]:
    """Does an SGD step on a batch of sequences."""
    # Retrieve a batch of data from replay.
    inputs: reverb.ReplaySample = next(self._iterator)
    data = tf2_utils.batch_to_sequence(inputs.data)
    observations, actions, rewards, discounts, extra = (data.observation,
                                                        data.action,
                                                        data.reward,
                                                        data.discount,
                                                        data.extras)
    core_state = tree.map_structure(lambda s: s[0], extra['core_state'])

    actions = actions[:-1]      # [T-1]
    rewards = rewards[:-1]      # [T-1]
    discounts = discounts[:-1]  # [T-1]

    with tf.GradientTape() as tape:
        # Unroll current policy over observations.
        (logits, values), _ = snt.static_unroll(self._network, observations,
                                                core_state)

        # Compute importance sampling weights: current policy / behaviour policy.
        behaviour_logits = extra['logits']
        pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1])
        pi_target = tfd.Categorical(logits=logits[:-1])
        log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(actions)

        # Optionally clip rewards.
        rewards = tf.clip_by_value(rewards,
                                   tf.cast(-self._max_abs_reward, rewards.dtype),
                                   tf.cast(self._max_abs_reward, rewards.dtype))

        # Critic loss.
        vtrace_returns = trfl.vtrace_from_importance_weights(
            log_rhos=tf.cast(log_rhos, tf.float32),
            discounts=tf.cast(self._discount * discounts, tf.float32),
            rewards=tf.cast(rewards, tf.float32),
            values=tf.cast(values[:-1], tf.float32),
            bootstrap_value=values[-1],
        )
        critic_loss = tf.square(vtrace_returns.vs - values[:-1])

        # Policy-gradient loss.
        policy_gradient_loss = trfl.policy_gradient(
            policies=pi_target,
            actions=actions,
            action_values=vtrace_returns.pg_advantages,
        )

        # Entropy regulariser.
        entropy_loss = trfl.policy_entropy_loss(pi_target).loss

        # Combine weighted sum of actor & critic losses.
        loss = tf.reduce_mean(policy_gradient_loss +
                              self._baseline_cost * critic_loss +
                              self._entropy_cost * entropy_loss)

    # Compute gradients and optionally apply clipping.
    gradients = tape.gradient(loss, self._network.trainable_variables)
    gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm)
    self._optimizer.apply(gradients, self._network.trainable_variables)

    metrics = {
        'loss': loss,
        'critic_loss': tf.reduce_mean(critic_loss),
        'entropy_loss': tf.reduce_mean(entropy_loss),
        'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss),
    }

    return metrics
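# Assumed layout convention for the step above (illustrative sketch, not the
# library implementation): replay samples arrive batch-major, [B, T, ...], and
# tf2_utils.batch_to_sequence is assumed to transpose them to time-major,
# [T, B, ...], which is the layout snt.static_unroll expects.
def batch_to_sequence_sketch(x: tf.Tensor) -> tf.Tensor:
    # Swap the leading batch and time axes: [B, T, ...] -> [T, B, ...].
    perm = [1, 0] + list(range(2, x.shape.rank))
    return tf.transpose(x, perm)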
def _step(
    self, data: Step
) -> Tuple[Dict[str, tf.Tensor], List[tf.Tensor], tf.Tensor]:
    """Does an SGD step on a batch of sequences."""
    observations, actions, rewards, discounts, _, extra = data
    core_state = tree.map_structure(lambda s: s[0], extra['core_state'])

    actions = actions[:-1]      # [T-1]
    rewards = rewards[:-1]      # [T-1]
    discounts = discounts[:-1]  # [T-1]

    # Workaround for NO_OP actions.
    # In some environments, passing NO_OP (-1) actions would lead to a crash.
    # These actions (at episode boundaries) should be ignored anyway,
    # so we replace NO_OP actions with a valid action index (0).
    actions = (tf.zeros_like(actions) * tf.cast(actions == -1, tf.int32) +
               actions * tf.cast(actions != -1, tf.int32))

    with tf.GradientTape() as tape:
        # Unroll current policy over observations.
        (logits, values), _ = snt.static_unroll(self._network, observations,
                                                core_state)

        # TODO: mask the policy here as well.
        # Masked policy.
        # masked_eligibility = observations['mask'] * observations['eligibility']
        # out_logits = logits - logits.min(axis=-1, keepdims=True)
        # out_logits[masked_eligibility == 0] = -np.infty
        # out_logits -= out_logits.max(axis=-1, keepdims=True)
        pi = tfd.Categorical(logits=logits[:-1])

        # Optionally clip rewards.
        rewards = tf.clip_by_value(rewards,
                                   tf.cast(-self._max_abs_reward, rewards.dtype),
                                   tf.cast(self._max_abs_reward, rewards.dtype))

        # Compute returns (optionally, GAE).
        discounted_returns = trfl.generalized_lambda_returns(
            rewards=tf.cast(rewards, tf.float32),
            pcontinues=tf.cast(self._discount * discounts, tf.float32),
            values=tf.cast(values[:-1], tf.float32),
            bootstrap_value=tf.cast(values[-1], tf.float32),
            lambda_=self._gae_lambda)
        advantages = discounted_returns - values[:-1]

        # Compute actor & critic losses.
        critic_loss = tf.square(advantages)

        # policy_gradient_loss = trfl.policy_gradient(
        #     policies=pi,
        #     actions=actions,
        #     action_values=advantages,
        # )
        # TODO: Remove later. This block inlines the body of
        # trfl.policy_gradient (commented out above).
        action_values = advantages
        policy_vars = None
        policy_vars = list(policy_vars) if policy_vars else list()
        with tf1.name_scope(values=policy_vars + [actions, action_values],
                            name="policy_gradient"):
            actions = tf1.stop_gradient(actions)
            action_values = tf1.stop_gradient(action_values)
            log_prob_actions = pi.log_prob(actions)
            # Prevent accidental broadcasting if possible at construction time.
            action_values.get_shape().assert_is_compatible_with(
                log_prob_actions.get_shape())
            policy_gradient_loss = -tf1.multiply(log_prob_actions, action_values)

        # entropy_loss = trfl.policy_entropy_loss(pi).loss
        # TODO: Remove later.
        entropy_info = trfl.policy_entropy_loss(pi)
        entropy_loss = entropy_info.loss

        loss = tf.reduce_mean(policy_gradient_loss +
                              self._baseline_cost * critic_loss +
                              self._entropy_cost * entropy_loss)

    # Compute gradients and optionally apply clipping.
    gradients = tape.gradient(loss, self._network.trainable_variables)
    gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm)
    self._optimizer.apply(gradients, self._network.trainable_variables)

    metrics = {
        'loss': loss,
        'critic_loss': tf.reduce_mean(critic_loss),
        'entropy_loss': tf.reduce_mean(entropy_loss),
        'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss),
        'log_probs': tf.reduce_mean(log_prob_actions),
        'learning_rate': self._learning_rate,
        'advantages': tf.reduce_mean(advantages),
        'discounted_returns': tf.reduce_mean(discounted_returns),
        'entropy': tf.reduce_mean(pi.entropy()),
    }

    return metrics, gradients, logits
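# Hypothetical sketch of the masking TODO in the step above, written with
# TensorFlow ops rather than the commented-out NumPy-style indexing. The names
# observations['mask'] and observations['eligibility'] are taken from the
# commented code; this is an assumption, not the source's implementation.
def mask_logits(logits: tf.Tensor, masked_eligibility: tf.Tensor) -> tf.Tensor:
    # Assign a very large negative logit to ineligible actions so the
    # Categorical distribution places (numerically) zero probability on them.
    neg_inf = tf.fill(tf.shape(logits), logits.dtype.min)
    return tf.where(masked_eligibility > 0, logits, neg_inf)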
def _step(self, data: Step) -> Dict[str, tf.Tensor]:
    """Does an SGD step on a batch of sequences."""
    observations, actions, rewards, discounts, _, extra = data
    core_state = tree.map_structure(lambda s: s[0], extra['core_state'])

    actions = actions[:-1]      # [T-1]
    rewards = rewards[:-1]      # [T-1]
    discounts = discounts[:-1]  # [T-1]

    # Workaround for NO_OP actions.
    # In some environments, passing NO_OP (-1) actions would lead to a crash.
    # These actions (at episode boundaries) should be ignored anyway,
    # so we replace NO_OP actions with a valid action index (0).
    actions = (tf.zeros_like(actions) * tf.cast(actions == -1, tf.int32) +
               actions * tf.cast(actions != -1, tf.int32))

    with tf.GradientTape() as tape:
        # Unroll current policy over observations.
        (logits, values), _ = snt.static_unroll(self._network, observations,
                                                core_state)
        pi = tfd.Categorical(logits=logits[:-1])

        # Optionally clip rewards.
        rewards = tf.clip_by_value(rewards,
                                   tf.cast(-self._max_abs_reward, rewards.dtype),
                                   tf.cast(self._max_abs_reward, rewards.dtype))

        # Compute actor & critic losses.
        discounted_returns = trfl.generalized_lambda_returns(
            rewards=tf.cast(rewards, tf.float32),
            pcontinues=tf.cast(self._discount * discounts, tf.float32),
            values=tf.cast(values[:-1], tf.float32),
            bootstrap_value=tf.cast(values[-1], tf.float32))
        advantages = discounted_returns - values[:-1]
        critic_loss = tf.square(advantages)
        policy_gradient_loss = trfl.policy_gradient(
            policies=pi,
            actions=actions,
            action_values=advantages)
        entropy_loss = trfl.policy_entropy_loss(pi).loss

        loss = tf.reduce_mean(policy_gradient_loss +
                              self._baseline_cost * critic_loss +
                              self._entropy_cost * entropy_loss)

    # Compute gradients and optionally apply clipping.
    gradients = tape.gradient(loss, self._network.trainable_variables)
    gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm)
    self._optimizer.apply(gradients, self._network.trainable_variables)

    metrics = {
        'loss': loss,
        'critic_loss': tf.reduce_mean(critic_loss),
        'entropy_loss': tf.reduce_mean(entropy_loss),
        'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss),
    }

    return metrics
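# Sketch of what trfl.generalized_lambda_returns computes when lambda_ is left
# at its default of 1.0, as in the call above: plain discounted returns
# bootstrapped from the final value estimate, G_t = r_t + pcontinue_t * G_{t+1}
# with G_T = bootstrap_value. Illustrative only; with 0 < lambda_ < 1 the
# returns are instead mixed with the learned values, TD(lambda)-style, as in
# the GAE variant of the previous step function.
def discounted_returns_sketch(rewards, pcontinues, bootstrap_value):
    # rewards, pcontinues: time-major sequences of length T-1.
    acc = bootstrap_value
    returns = []
    for r, p in zip(reversed(list(rewards)), reversed(list(pcontinues))):
        acc = r + p * acc
        returns.append(acc)
    return list(reversed(returns))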