def _eval_constraint_and_regs(self, dataset: dataset_lib.OffpolicyDataset,
                              target_policy: tf_policy.TFPolicy):
  """Get the residual term and the primal and dual regularizers during eval.

  Args:
    dataset: The dataset to sample experience from.
    target_policy: The policy whose value we want to estimate.

  Returns:
    The residual term (weighted by zeta), primal, and dual reg values.
  """
  experience = dataset.get_all_steps(num_steps=2)
  env_step = tf.nest.map_structure(lambda t: t[:, 0, ...], experience)
  next_env_step = tf.nest.map_structure(lambda t: t[:, 1, ...], experience)

  nu_values, _, _ = self._sample_value(self._nu_network, env_step)
  next_nu_values, _, _ = self._sample_average_value(
      self._nu_network, next_env_step, target_policy)
  zeta_values, neg_kl, _ = self._sample_value(self._zeta_network, env_step)

  discounts = self._gamma * env_step.discount
  bellman_residuals = (
      common_lib.reverse_broadcast(discounts, nu_values) * next_nu_values -
      nu_values - self._norm_regularizer * self._lam)

  # Always include reward during eval.
  bellman_residuals += self._reward_fn(env_step)
  constraint = tf.reduce_mean(zeta_values * bellman_residuals)

  f_nu = tf.reduce_mean(self._f_fn(nu_values))
  f_zeta = tf.reduce_mean(self._f_fn(zeta_values))

  return constraint, f_nu, f_zeta, tf.reduce_mean(neg_kl)
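# A minimal, self-contained sketch (not part of the class above) of the
# quantity that `constraint` averages, using plain tensors in place of the
# sampled network values. The name `toy_constraint` and its argument list are
# hypothetical; it only assumes per-sample scalar values of the same shape.
import tensorflow as tf


def toy_constraint(nu, next_nu, zeta, reward, discount, gamma, lam,
                   norm_regularizer):
  """Returns E[zeta * (gamma * discount * next_nu + r - nu - norm_reg * lam)]."""
  bellman_residuals = gamma * discount * next_nu - nu - norm_regularizer * lam
  bellman_residuals += reward
  return tf.reduce_mean(zeta * bellman_residuals)


# Example with a batch of two transitions:
# toy_constraint(nu=tf.constant([1.0, 2.0]), next_nu=tf.constant([1.5, 1.0]),
#                zeta=tf.constant([0.5, 2.0]), reward=tf.constant([0.0, 1.0]),
#                discount=tf.constant([1.0, 1.0]), gamma=0.99, lam=0.0,
#                norm_regularizer=0.0)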
def get_fullbatch_average(dataset: OffpolicyDataset,
                          limit: Optional[int] = None,
                          by_steps: bool = True,
                          truncate_episode_at: Optional[int] = None,
                          reward_fn: Optional[Callable] = None,
                          weight_fn: Optional[Callable] = None,
                          gamma: Union[float, tf.Tensor] = 1.0
                         ) -> Union[float, tf.Tensor]:
  """Computes the average reward over the full dataset.

  Args:
    dataset: The dataset to sample experience from.
    limit: If specified, the maximum number of steps/episodes to take from the
      dataset.
    by_steps: Whether to compute the average over steps (default) or over
      episodes.
    truncate_episode_at: If sampling by episodes, where to truncate episodes
      from the environment, if at all.
    reward_fn: A function that takes in an EnvStep and returns the reward for
      that step. If not specified, defaults to just EnvStep.reward. When
      sampling by episode, valid_steps is also passed into reward_fn.
    weight_fn: A function that takes in an EnvStep and returns a weight for
      that step. If not specified, defaults to gamma ** step_num. When
      sampling by episode, valid_steps is also passed into weight_fn.
    gamma: The discount factor to use for the default reward/weight functions.

  Returns:
    An estimate of the average reward.
  """
  if reward_fn is None:
    if by_steps:
      reward_fn = _default_by_steps_reward_fn
    else:
      reward_fn = lambda *args: _default_by_episodes_reward_fn(
          *args, gamma=gamma)

  if weight_fn is None:
    if by_steps:
      weight_fn = lambda *args: _default_by_steps_weight_fn(*args, gamma=gamma)
    else:
      weight_fn = _default_by_episodes_weight_fn

  if by_steps:
    steps = dataset.get_all_steps(limit=limit)
    rewards = reward_fn(steps)
    weights = weight_fn(steps)
  else:
    episodes, valid_steps = dataset.get_all_episodes(
        truncate_episode_at=truncate_episode_at, limit=limit)
    rewards = reward_fn(episodes, valid_steps)
    weights = weight_fn(episodes, valid_steps)

  # Align shapes before computing the weighted average.
  rewards = common_lib.reverse_broadcast(rewards, weights)
  weights = common_lib.reverse_broadcast(weights, rewards)
  if tf.rank(weights) < 2:
    return (tf.reduce_sum(rewards * weights, axis=0) /
            tf.reduce_sum(weights, axis=0))
  return (tf.linalg.matmul(weights, rewards) /
          tf.reduce_sum(tf.math.reduce_mean(weights, axis=0)))
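# The docstring above says the default per-step weight is gamma ** step_num.
# Below is a hedged, self-contained sketch of such a weight function; the
# library's actual _default_by_steps_weight_fn may differ in detail, and
# `env_step.step_num` is assumed here to be the step's index within its
# episode.
import tensorflow as tf


def _sketch_by_steps_weight_fn(env_step, gamma=1.0):
  """Returns gamma ** step_num for each sampled step."""
  return tf.pow(tf.cast(gamma, tf.float32),
                tf.cast(env_step.step_num, tf.float32))


# Hypothetical usage of the estimator above, assuming `dataset` is an
# already-constructed OffpolicyDataset; the keyword values are illustrative:
# behavior_value = get_fullbatch_average(dataset, by_steps=True, gamma=0.99)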