def call(self, trajectory: Trajectory) -> Trajectory:
    """
    Process the experience passed in to update the components used to calculate the final
    metric value.

    :param trajectory: Experience from the agent rolling out in the environment.
    :return: The unchanged input trajectory (as per the standard use of TensorFlow metrics).
    """
    time_step = TimeStep(trajectory.step_type, trajectory.reward, trajectory.discount,
                         trajectory.observation)
    action_dist = self._policy.distribution(time_step).action

    # If the action distribution is in fact a tuple of distributions (one per resource set),
    # index into it to obtain the underlying distribution from which probabilities can be
    # read. This is only the case when there are multiple resource sets.
    for i in self._action_indices[:-1]:
        action_dist = action_dist[i]
    action_probs = action_dist.probs_parameter()

    # Zero out batch indices where a new episode is starting.
    self._probability_accumulator.assign(
        tf.where(trajectory.is_first(), tf.zeros_like(self._probability_accumulator),
                 self._probability_accumulator))
    self._count_accumulator.assign(
        tf.where(trajectory.is_first(), tf.zeros_like(self._count_accumulator),
                 self._count_accumulator))

    # Update the accumulators with the probability and count increments for this step.
    self._probability_accumulator.assign_add(action_probs[..., 0, self._action_indices[-1]])
    self._count_accumulator.assign_add(tf.ones_like(self._count_accumulator))

    # Add the final per-episode mean probability to the buffer at the end of each episode.
    last_episode_indices = tf.squeeze(tf.where(trajectory.is_last()), axis=-1)
    for idx in last_episode_indices:
        self._buffer.add(self._probability_accumulator[idx] / self._count_accumulator[idx])

    return trajectory
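# A minimal usage sketch, not from the source: it assumes the method above lives on a
# TF-Agents TFStepMetric subclass (given the hypothetical name ActionProbabilityMetric
# here, with a hypothetical constructor signature) and that `tf_env` and `policy` are an
# existing TFEnvironment / TFPolicy pair. The driver invokes `call` on each batched
# trajectory step via the observers list, which is what drives the accumulators above.
from tf_agents.drivers import dynamic_episode_driver

metric = ActionProbabilityMetric(policy=policy, action_indices=(0, 3))
driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, policy, observers=[metric], num_episodes=10)
driver.run()
print(metric.result())  # Mean per-episode probability of the indexed action.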
def call(self, trajectory: Trajectory) -> Trajectory: """ Process the experience passed in to update the metric value (or the components required to calculate the final value). :param trajectory: Experience from the agent rolling out in the environment. :return: The unchanged input trajectory (as per the standard use of TensorFlow Metrics). """ start_of_episode_indices = tf.squeeze(tf.where(trajectory.is_first()), axis=-1) mask = tf.ones(shape=(self._batch_size, ), dtype=self._dtype) for idx in start_of_episode_indices: mask -= tf.eye(self._batch_size)[idx] # Reset the accumulators at the end of each episode. self._num_valid_timesteps.assign(self._num_valid_timesteps * mask) self._activity_accumulator.assign(self._activity_accumulator * mask) # Find the number of time steps satisfying the filter condition. # The reshape is to ensure compatibility with the variable below in the case of no batch # dimension. valid_timesteps = tf.reshape( tf.reduce_sum(tf.cast(self.filter_condition(trajectory), self._dtype), axis=-1), self._num_valid_timesteps.shape) # Track the number of time steps which meet the qualifying condition. self._num_valid_timesteps.assign_add(valid_timesteps, name="increment_valid_timesteps") # Update accumulator with activity counts where both the filtering and activity condition # are satisfied. Again the reshape is to ensure compatibility with the accumulator # variable in the case where there is no batch dimension. bool_values = tf.logical_and(self.filter_condition(trajectory), self.activity_condition(trajectory)) to_add = tf.reshape( tf.reduce_sum(tf.cast(bool_values, self._dtype), axis=-1), self._activity_accumulator.shape) self._activity_accumulator.assign_add(to_add) # Add values to buffer at the end of the episode by first finding where the trajectories end # and then using the resulting indices to update the correct buffer locations. # At the same time build up a mask of values to use for resetting the accumulators. end_of_episode_indices = tf.squeeze( tf.where(trajectory.step_type == 2), axis=-1) for idx in end_of_episode_indices: self._activity_buffer.add(self._activity_accumulator[idx]) self._qualifying_timesteps_buffer.add( self._num_valid_timesteps[idx]) # Return the original trajectory data as is standard for TFStepMetrics. return trajectory