Example #1
    def train_loss(self, initial_env_step, env_step, next_env_step, policy):
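        """Computes the nu and zeta losses for a batch of transitions."""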
        nu_values = self._get_value(self._nu_network, env_step)
        initial_nu_values = self._get_average_value(self._nu_network,
                                                    initial_env_step, policy)
        next_nu_values = self._get_average_value(self._nu_network,
                                                 next_env_step, policy)
        zeta_values = self._get_value(self._zeta_network, env_step)

        discounts = self._gamma * next_env_step.discount
        policy_ratio = 1.0
        if not self._solve_for_state_action_ratio:
            tfagents_step = dataset_lib.convert_to_tfagents_timestep(env_step)
            policy_log_probabilities = policy.distribution(
                tfagents_step).action.log_prob(env_step.action)
            policy_ratio = tf.exp(policy_log_probabilities -
                                  env_step.get_log_probability())

        bellman_residuals = (nu_values - common_lib.reverse_broadcast(
            discounts * policy_ratio, nu_values) * next_nu_values)

        zeta_loss = self._fstar_fn(
            zeta_values) - bellman_residuals * zeta_values
        if self._primal_form:
            nu_loss = (self._f_fn(bellman_residuals) -
                       (1 - self._gamma) * initial_nu_values)
        else:
            nu_loss = -zeta_loss - (1 - self._gamma) * initial_nu_values

        return nu_loss, zeta_loss
Example #2
  def _get_average_value(self, network, env_step, policy):
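    """Computes the value of `network` at env_step, averaged over actions
    drawn from (or enumerated under) `policy` when solving for the
    state-action ratio."""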
    if self._solve_for_state_action_ratio:
      tfagents_step = dataset_lib.convert_to_tfagents_timestep(env_step)
      if self._categorical_action and self._num_samples is None:
        action_weights = policy.distribution(
            tfagents_step).action.probs_parameter()
        action_dtype = self._dataset_spec.action.dtype
        batch_size = tf.shape(action_weights)[0]
        num_actions = tf.shape(action_weights)[-1]
        actions = (  # Broadcast actions
            tf.ones([batch_size, 1], dtype=action_dtype) *
            tf.range(num_actions, dtype=action_dtype)[None, :])
      else:
        batch_size = tf.shape(env_step.observation)[0]
        num_actions = self._num_samples
        action_weights = tf.ones([batch_size, num_actions]) / num_actions
        actions = tf.stack(
            [policy.action(tfagents_step).action for _ in range(num_actions)],
            axis=1)

      flat_actions = tf.reshape(actions, [batch_size * num_actions] +
                                actions.shape[2:].as_list())
      flat_observations = tf.reshape(
          tf.tile(env_step.observation[:, None, ...],
                  [1, num_actions] + [1] * len(env_step.observation.shape[1:])),
          [batch_size * num_actions] + env_step.observation.shape[1:].as_list())

      flat_values, _ = network((flat_observations, flat_actions))
      values = tf.reshape(flat_values, [batch_size, num_actions] +
                          flat_values.shape[1:].as_list())
      return tf.reduce_sum(
          values * common_lib.reverse_broadcast(action_weights, values), axis=1)
    else:
      return network((env_step.observation,))[0]
Example #3
        def weight_fn(env_step):
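            """Computes per-step importance weights from tabular zeta values."""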
            if self._step_encoding is not None:
                zeta = 0.
                for step_num in range(self._max_trajectory_length):
                    index = self._get_index(env_step.observation,
                                            env_step.action, step_num)
                    zeta += self._gamma**step_num * self._zeta[index]
                zeta *= (1 - self._gamma) / (1 - self._gamma**
                                             (self._num_steps - 1))
            else:
                index = self._get_index(env_step.observation, env_step.action,
                                        env_step.step_num)
                zeta = self._zeta[index]
                zeta = tf.where(
                    env_step.step_num >= self._max_trajectory_length,
                    tf.zeros_like(zeta), zeta)

            policy_ratio = 1.0
            if not self._solve_for_state_action_ratio:
                tfagents_timestep = dataset_lib.convert_to_tfagents_timestep(
                    env_step)
                target_log_probabilities = target_policy.distribution(
                    tfagents_timestep).action.log_prob(env_step.action)
                policy_ratio = tf.exp(target_log_probabilities -
                                      env_step.get_log_probability())

            return tf.cast(zeta, dtype=tf.float32) * tf.cast(policy_ratio,
                                                             dtype=tf.float32)
Example #4
        def reward_fn(env_step, valid_steps, qvalues=self._point_qvalues):
            """Computes average initial Q-values of episodes."""
            # env_step is an episode, and we just want the first step.
            if tf.rank(valid_steps) == 1:
                first_step = tf.nest.map_structure(lambda t: t[0, ...],
                                                   env_step)
            else:
                first_step = tf.nest.map_structure(lambda t: t[:, 0, ...],
                                                   env_step)

            if self._solve_for_state_action_value:
                indices = self._get_index(
                    first_step.observation[:, None],
                    np.arange(self._num_actions)[None, :])
                initial_qvalues = tf.cast(tf.gather(qvalues, indices),
                                          tf.float32)

                tfagents_first_step = dataset_lib.convert_to_tfagents_timestep(
                    first_step)
                initial_target_probs = target_policy.distribution(
                    tfagents_first_step).action.probs_parameter()
                value = tf.reduce_sum(initial_qvalues * initial_target_probs,
                                      axis=-1)
            else:
                indices = self._get_index(first_step.observation,
                                          first_step.action)
                value = tf.cast(tf.gather(qvalues, indices), tf.float32)

            return value
Example #5
def weight_fn(env_step):
  zeta = self._get_value(self._zeta_network, env_step)
  policy_ratio = 1.0
  if not self._solve_for_state_action_ratio:
    tfagents_timestep = dataset_lib.convert_to_tfagents_timestep(env_step)
    target_log_probabilities = target_policy.distribution(
        tfagents_timestep).action.log_prob(env_step.action)
    policy_ratio = tf.exp(target_log_probabilities -
                          env_step.get_log_probability())
  return zeta * common_lib.reverse_broadcast(policy_ratio, zeta)
Example #6
        def weight_fn(env_step):
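            """Looks up tabular zeta and corrects it by the policy ratio."""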
            index = self._get_index(env_step.observation, env_step.action)
            zeta = self._zeta[index]

            policy_ratio = 1.0
            if not self._solve_for_state_action_ratio:
                tfagents_timestep = dataset_lib.convert_to_tfagents_timestep(
                    env_step)
                target_log_probabilities = target_policy.distribution(
                    tfagents_timestep).action.log_prob(env_step.action)
                policy_ratio = tf.exp(target_log_probabilities -
                                      env_step.get_log_probability())

            return tf.cast(zeta * policy_ratio, tf.float32)
Example #7
        def weight_fn(env_step):
            index = self._get_index(env_step.observation, env_step.action)
            zeta = tf.gather(self._zeta,
                             tf.tile(index[None, :], [num_samples, 1]),
                             batch_dims=1)
            policy_ratio = 1.0
            if not self._solve_for_state_action_ratio:
                tfagents_timestep = dataset_lib.convert_to_tfagents_timestep(
                    env_step)
                target_log_probabilities = target_policy.distribution(
                    tfagents_timestep).action.log_prob(env_step.action)
                policy_ratio = tf.exp(target_log_probabilities -
                                      env_step.get_log_probability())

            return tf.cast(zeta * policy_ratio, tf.float32)
Example #8
    def train_loss(self, initial_env_step, env_step, next_env_step, policy):
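        """Computes the nu, zeta, and lambda losses for a batch of transitions."""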
        nu_values, _, eps = self._sample_value(self._nu_network, env_step)
        initial_nu_values, _, _ = self._sample_average_value(
            self._nu_network, initial_env_step, policy)
        next_nu_values, _, _ = self._sample_average_value(
            self._nu_network, next_env_step, policy)

        zeta_values, zeta_neg_kl, _ = self._sample_value(
            self._zeta_network, env_step, eps)

        discounts = self._gamma * env_step.discount
        policy_ratio = 1.0
        if not self._solve_for_state_action_ratio:
            tfagents_step = dataset_lib.convert_to_tfagents_timestep(env_step)
            policy_log_probabilities = policy.distribution(
                tfagents_step).action.log_prob(env_step.action)
            policy_ratio = tf.exp(policy_log_probabilities -
                                  env_step.get_log_probability())

        bellman_residuals = (
            common_lib.reverse_broadcast(discounts * policy_ratio, nu_values) *
            next_nu_values - nu_values - self._norm_regularizer * self._lam)
        if not self._zero_reward:
            bellman_residuals += policy_ratio * self._reward_fn(env_step)

        zeta_loss = -zeta_values * bellman_residuals
        nu_loss = (1 - self._gamma) * initial_nu_values
        lam_loss = self._norm_regularizer * self._lam
        if self._primal_form:
            nu_loss += self._fstar_fn(bellman_residuals)
            lam_loss = lam_loss + self._fstar_fn(bellman_residuals)
        else:
            nu_loss += zeta_values * bellman_residuals
            lam_loss = lam_loss - self._norm_regularizer * zeta_values * self._lam

        nu_loss += self._primal_regularizer * self._f_fn(nu_values)
        zeta_loss += self._dual_regularizer * self._f_fn(zeta_values)
        zeta_loss -= self._kl_regularizer * tf.reduce_mean(zeta_neg_kl)

        if self._weight_by_gamma:
            weights = self._gamma**tf.cast(env_step.step_num, tf.float32)[:,
                                                                          None]
            weights /= 1e-6 + tf.reduce_mean(weights)
            nu_loss *= weights
            zeta_loss *= weights

        return nu_loss, zeta_loss, lam_loss
Example #9
    def _get_nu_loss(self, initial_env_step, env_step, next_env_step, policy):
        """Get nu_loss for both upper and lower confidence intervals."""
        nu_index = self._get_index(env_step.observation, env_step.action)
        nu_values = tf.gather(self._nu, nu_index)

        initial_nu_values = self._get_average_value(self._nu, initial_env_step,
                                                    policy)
        next_nu_values = self._get_average_value(self._nu, next_env_step,
                                                 policy)

        rewards = self._reward_fn(env_step)

        discounts = self._gamma * env_step.discount
        policy_ratio = 1.0

        if not self._solve_for_state_action_ratio:
            tfagents_step = dataset_lib.convert_to_tfagents_timestep(env_step)
            policy_log_probabilities = policy.distribution(
                tfagents_step).action.log_prob(env_step.action)
            policy_ratio = tf.exp(policy_log_probabilities -
                                  env_step.get_log_probability())

        bellman_residuals = (
            -nu_values + common_lib.reverse_broadcast(
                rewards, tf.convert_to_tensor(nu_values)) +
            common_lib.reverse_broadcast(discounts * policy_ratio,
                                         tf.convert_to_tensor(nu_values)) *
            next_nu_values)
        bellman_residuals *= self._algae_alpha_sign

        init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                        self._algae_alpha_sign)

        nu_loss = (tf.math.abs(self._algae_alpha) * tf.math.square(
            bellman_residuals / tf.math.abs(self._algae_alpha)) / 2.0 +
                   init_nu_loss)

        if self._weight_by_gamma:
            weights = tf.expand_dims(self._gamma**tf.cast(
                env_step.step_num, tf.float32),
                                     axis=1)
            weights /= 1e-6 + tf.reduce_mean(weights)
            nu_loss *= weights

        return nu_loss
Example #10
  def train_loss(self, initial_env_step, env_step, next_env_step, policy):
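    """Computes the nu and zeta losses for a batch of transitions."""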
    nu_values = self._get_value(self._nu_network, env_step)
    initial_nu_values = self._get_average_value(self._nu_network,
                                                initial_env_step, policy)
    next_nu_values = self._get_average_value(self._nu_network, next_env_step,
                                             policy)
    zeta_values = self._get_value(self._zeta_network, env_step)
    rewards = self._reward_fn(env_step)

    discounts = self._gamma * env_step.discount
    policy_ratio = 1.0
    if not self._solve_for_state_action_ratio:
      tfagents_step = dataset_lib.convert_to_tfagents_timestep(env_step)
      policy_log_probabilities = policy.distribution(
          tfagents_step).action.log_prob(env_step.action)
      policy_ratio = tf.exp(policy_log_probabilities -
                            env_step.get_log_probability())

    bellman_residuals = (
        -nu_values + common_lib.reverse_broadcast(rewards, nu_values) +
        common_lib.reverse_broadcast(discounts * policy_ratio, nu_values) *
        next_nu_values)
    bellman_residuals *= self._algae_alpha_sign
    #print(initial_nu_values, bellman_residuals)

    zeta_loss = (
        self._algae_alpha_abs * self._fstar_fn(zeta_values) -
        bellman_residuals * zeta_values)

    init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                    self._algae_alpha_sign)
    if self._primal_form:
      nu_loss = (
          self._algae_alpha_abs *
          self._f_fn(bellman_residuals / self._algae_alpha_abs) + init_nu_loss)
    else:
      nu_loss = -zeta_loss + init_nu_loss

    if self._weight_by_gamma:
      weights = self._gamma**tf.cast(env_step.step_num, tf.float32)[:, None]
      weights /= 1e-6 + tf.reduce_mean(weights)
      nu_loss *= weights
      zeta_loss *= weights
    return nu_loss, zeta_loss
Example #11
    def _get_average_value(self, network, env_step, policy):
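        """Averages `network` values over the policy's action distribution."""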
        if self._q_network is None:
            return tf.zeros_like(env_step.reward)

        tfagents_step = dataset_lib.convert_to_tfagents_timestep(env_step)
        if self._categorical_action and self._num_samples is None:
            action_weights = policy.distribution(
                tfagents_step).action.probs_parameter()
            action_dtype = self._dataset_spec.action.dtype
            batch_size = tf.shape(action_weights)[0]
            num_actions = tf.shape(action_weights)[-1]
            actions = (  # Broadcast actions
                tf.ones([batch_size, 1], dtype=action_dtype) *
                tf.range(num_actions, dtype=action_dtype)[None, :])
        else:
            batch_size = tf.shape(env_step.observation)[0]
            num_actions = self._num_samples
            action_weights = tf.ones([batch_size, num_actions]) / num_actions
            actions = tf.stack([
                policy.action(tfagents_step).action for _ in range(num_actions)
            ],
                               axis=1)

        flat_actions = tf.reshape(
            actions,
            tf.concat([[batch_size * num_actions],
                       tf.shape(actions)[2:]],
                      axis=0))
        flat_observations = tf.reshape(
            tf.tile(env_step.observation[:, None, ...], [1, num_actions] +
                    [1] * len(env_step.observation.shape[1:])),
            tf.concat([[batch_size * num_actions],
                       tf.shape(env_step.observation)[1:]],
                      axis=0))
        flat_values, _ = network((flat_observations, flat_actions))

        values = tf.reshape(
            flat_values,
            tf.concat([[batch_size, num_actions],
                       tf.shape(flat_values)[1:]],
                      axis=0))
        return tf.reduce_sum(values * action_weights, axis=1)
Example #12
    def train_loss(self, env_step, rewards, next_env_step, policy, gamma):
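        """Computes squared TD errors against bootstrapped target values."""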
        values = self._get_value(env_step)
        discounts = gamma * next_env_step.discount
        target_values = self._get_target_value(next_env_step, policy)
        #target_values = tf.reduce_min(target_values, axis=-1, keepdims=True)

        if self._num_qvalues is not None and tf.rank(discounts) == 1:
            discounts = discounts[:, None]
        td_targets = rewards + discounts * tf.stop_gradient(target_values)

        policy_ratio = 1.0
        if not self._solve_for_state_action_value:
            tfagents_step = dataset_lib.convert_to_tfagents_timestep(env_step)
            policy_log_probabilities = policy.distribution(
                tfagents_step).action.log_prob(env_step.action)
            policy_ratio = tf.exp(policy_log_probabilities -
                                  env_step.get_log_probability())

        td_errors = policy_ratio * td_targets - values
        return tf.square(td_errors)
Example #13
    def solve(self,
              dataset: dataset_lib.OffpolicyDataset,
              target_policy: tf_policy.TFPolicy,
              regularizer: float = 1e-8):
        """Solves for density ratios and then approximates target policy value.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.
      regularizer: A small constant to add to matrices before inverting them or
        to floats before taking square root.

    Returns:
      Estimated average per-step reward of the target policy.
    """
        td_residuals = np.zeros([self._dimension, self._dimension])
        total_weights = np.zeros([self._dimension])
        initial_weights = np.zeros([self._dimension])

        episodes, valid_steps = dataset.get_all_episodes(limit=None)
        tfagents_episodes = dataset_lib.convert_to_tfagents_timestep(episodes)

        for episode_num in range(tf.shape(valid_steps)[0]):
            # Precompute probabilities for this episode.
            this_episode = tf.nest.map_structure(lambda t: t[episode_num],
                                                 episodes)
            first_step = tf.nest.map_structure(lambda t: t[0], this_episode)
            this_tfagents_episode = dataset_lib.convert_to_tfagents_timestep(
                this_episode)
            episode_target_log_probabilities = target_policy.distribution(
                this_tfagents_episode).action.log_prob(this_episode.action)
            episode_target_probs = target_policy.distribution(
                this_tfagents_episode).action.probs_parameter()

            for step_num in range(tf.shape(valid_steps)[1] - 1):
                this_step = tf.nest.map_structure(
                    lambda t: t[episode_num, step_num], episodes)
                next_step = tf.nest.map_structure(
                    lambda t: t[episode_num, step_num + 1], episodes)
                if this_step.is_last() or not valid_steps[episode_num,
                                                          step_num]:
                    continue

                weight = 1.0
                nu_index = self._get_index(this_step.observation,
                                           this_step.action)
                td_residuals[nu_index, nu_index] += weight
                total_weights[nu_index] += weight

                policy_ratio = 1.0
                if not self._solve_for_state_action_ratio:
                    policy_ratio = tf.exp(
                        episode_target_log_probabilities[step_num] -
                        this_step.get_log_probability())

                # Need to weight next nu by importance weight.
                next_weight = (weight if self._solve_for_state_action_ratio
                               else policy_ratio * weight)
                next_probs = episode_target_probs[step_num + 1]
                for next_action, next_prob in enumerate(next_probs):
                    next_nu_index = self._get_index(next_step.observation,
                                                    next_action)
                    td_residuals[next_nu_index,
                                 nu_index] += (-next_prob * self._gamma *
                                               next_weight)

                initial_probs = episode_target_probs[0]
                for initial_action, initial_prob in enumerate(initial_probs):
                    initial_nu_index = self._get_index(first_step.observation,
                                                       initial_action)
                    initial_weights[initial_nu_index] += weight * initial_prob

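        # Normalize the TD residuals, then solve the regularized least-squares
        # system for nu and recover zeta from the scaled residuals.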
        td_residuals /= np.sqrt(regularizer + total_weights)[None, :]
        td_errors = np.dot(td_residuals, td_residuals.T)
        self._nu = np.linalg.solve(
            td_errors + regularizer * np.eye(self._dimension),
            (1 - self._gamma) * initial_weights)
        self._zeta = np.dot(
            self._nu, td_residuals) / np.sqrt(regularizer + total_weights)
        return self.estimate_average_reward(dataset, target_policy)
Example #14
    def get_is_weighted_reward_samples(self,
                                       dataset: dataset_lib.OffpolicyDataset,
                                       target_policy: tf_policy.TFPolicy,
                                       episode_limit: Optional[int] = None,
                                       eps: Optional[float] = 1e-8):
        """Get the IS weighted reweard samples."""
        episodes, valid_steps = dataset.get_all_episodes(limit=episode_limit)
        total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
        num_episodes = tf.shape(valid_steps)[0]
        num_samples = num_episodes * total_num_steps_per_episode

        init_env_step = tf.nest.map_structure(lambda t: t[:, 0, ...], episodes)
        env_step = tf.nest.map_structure(
            lambda t: tf.squeeze(
                tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                           [num_samples, -1])), episodes)
        next_env_step = tf.nest.map_structure(
            lambda t: tf.squeeze(
                tf.reshape(t[:, 1:1 + total_num_steps_per_episode, ...],
                           [num_samples, -1])), episodes)
        tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

        gamma_weights = tf.reshape(
            tf.pow(self._gamma, tf.cast(env_step.step_num, tf.float32)),
            [num_episodes, total_num_steps_per_episode])

        rewards = (-self._get_q_value(env_step) + self._reward_fn(env_step) +
                   self._gamma * next_env_step.discount *
                   self._get_v_value(next_env_step, target_policy))
        rewards = tf.reshape(rewards,
                             [num_episodes, total_num_steps_per_episode])

        init_values = self._get_v_value(init_env_step, target_policy)
        init_offset = (1 - self._gamma) * init_values

        target_log_probabilities = target_policy.distribution(
            tfagents_env_step).action.log_prob(env_step.action)
        if tf.rank(target_log_probabilities) > 1:
            target_log_probabilities = tf.reduce_sum(target_log_probabilities,
                                                     -1)
        if self._policy_network is not None:
            baseline_policy_log_probability = self._get_log_prob(
                self._policy_network, env_step)
            if tf.rank(baseline_policy_log_probability) > 1:
                baseline_policy_log_probability = tf.reduce_sum(
                    baseline_policy_log_probability, -1)
            policy_log_ratios = tf.reshape(
                tf.maximum(
                    -1.0 / eps, target_log_probabilities -
                    baseline_policy_log_probability),
                [num_episodes, total_num_steps_per_episode])
        else:
            policy_log_ratios = tf.reshape(
                tf.maximum(
                    -1.0 / eps,
                    target_log_probabilities - env_step.get_log_probability()),
                [num_episodes, total_num_steps_per_episode])
        valid_steps_in = valid_steps[:, 0:total_num_steps_per_episode]
        mask = tf.cast(
            tf.logical_and(valid_steps_in, episodes.discount[:, :-1] > 0.),
            tf.float32)

        masked_rewards = tf.where(mask > 0, rewards, tf.zeros_like(rewards))
        clipped_policy_log_ratios = mask * self.clip_log_factor(
            policy_log_ratios)

        if self._mode in ['trajectory-wise', 'weighted-trajectory-wise']:
            trajectory_avg_rewards = tf.reduce_sum(
                masked_rewards * gamma_weights, axis=1) / tf.reduce_sum(
                    gamma_weights, axis=1)
            trajectory_log_ratios = tf.reduce_sum(clipped_policy_log_ratios,
                                                  axis=1)
            if self._mode == 'trajectory-wise':
                trajectory_avg_rewards *= tf.exp(trajectory_log_ratios)
                return init_offset + trajectory_avg_rewards
            else:
                offset = tf.reduce_max(trajectory_log_ratios)
                normalized_clipped_ratios = tf.exp(trajectory_log_ratios -
                                                   offset)
                normalized_clipped_ratios /= tf.maximum(
                    eps, tf.reduce_mean(normalized_clipped_ratios))
                trajectory_avg_rewards *= normalized_clipped_ratios
                return init_offset + trajectory_avg_rewards

        elif self._mode in ['step-wise', 'weighted-step-wise']:
            trajectory_log_ratios = mask * tf.cumsum(policy_log_ratios, axis=1)
            if self._mode == 'step-wise':
                trajectory_avg_rewards = tf.reduce_sum(
                    masked_rewards * gamma_weights *
                    tf.exp(trajectory_log_ratios),
                    axis=1) / tf.reduce_sum(gamma_weights, axis=1)
                return init_offset + trajectory_avg_rewards
            else:
                # Average over data, for each time step.
                offset = tf.reduce_max(trajectory_log_ratios,
                                       axis=0)  # TODO: Handle mask.
                normalized_imp_weights = tf.exp(trajectory_log_ratios - offset)
                normalized_imp_weights /= tf.maximum(
                    eps,
                    tf.reduce_sum(mask * normalized_imp_weights, axis=0) /
                    tf.maximum(eps, tf.reduce_sum(mask, axis=0)))[None, :]

                trajectory_avg_rewards = tf.reduce_sum(
                    masked_rewards * gamma_weights * normalized_imp_weights,
                    axis=1) / tf.reduce_sum(gamma_weights, axis=1)
                return init_offset + trajectory_avg_rewards
        else:
            raise ValueError('Estimator is not implemented!')
Example #15
    def prepare_dataset(self, dataset: dataset_lib.OffpolicyDataset,
                        target_policy: tf_policy.TFPolicy):
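        """Performs pre-computations on the dataset to make solving easier."""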
        episodes, valid_steps = dataset.get_all_episodes()
        tfagents_episodes = dataset_lib.convert_to_tfagents_timestep(episodes)

        for episode_num in range(tf.shape(valid_steps)[0]):
            # Precompute probabilities for this episode.
            this_episode = tf.nest.map_structure(lambda t: t[episode_num],
                                                 episodes)
            first_step = tf.nest.map_structure(lambda t: t[0], this_episode)
            this_tfagents_episode = dataset_lib.convert_to_tfagents_timestep(
                this_episode)
            episode_target_log_probabilities = target_policy.distribution(
                this_tfagents_episode).action.log_prob(this_episode.action)
            episode_target_probs = target_policy.distribution(
                this_tfagents_episode).action.probs_parameter()

            for step_num in range(tf.shape(valid_steps)[1] - 1):
                this_step = tf.nest.map_structure(
                    lambda t: t[episode_num, step_num], episodes)
                next_step = tf.nest.map_structure(
                    lambda t: t[episode_num, step_num + 1], episodes)
                if this_step.is_last() or not valid_steps[episode_num,
                                                          step_num]:
                    continue

                weight = 1.0
                nu_index = self._get_index(this_step.observation,
                                           this_step.action)
                self._td_residuals[nu_index, nu_index] += -weight
                self._total_weights[nu_index] += weight

                policy_ratio = 1.0
                if not self._solve_for_state_action_ratio:
                    policy_ratio = tf.exp(
                        episode_target_log_probabilities[step_num] -
                        this_step.get_log_probability())

                # Need to weight next nu by importance weight.
                next_weight = (weight if self._solve_for_state_action_ratio
                               else policy_ratio * weight)
                next_probs = episode_target_probs[step_num + 1]
                for next_action, next_prob in enumerate(next_probs):
                    next_nu_index = self._get_index(next_step.observation,
                                                    next_action)
                    self._td_residuals[next_nu_index,
                                       nu_index] += (next_prob * self._gamma *
                                                     next_weight)

                initial_probs = episode_target_probs[0]
                for initial_action, initial_prob in enumerate(initial_probs):
                    initial_nu_index = self._get_index(first_step.observation,
                                                       initial_action)
                    self._initial_weights[
                        initial_nu_index] += weight * initial_prob

        self._initial_weights = tf.cast(self._initial_weights, tf.float32)
        self._total_weights = tf.cast(self._total_weights, tf.float32)
        self._td_residuals = self._td_residuals / np.sqrt(
            1e-8 + self._total_weights)[None, :]
        self._td_errors = tf.cast(
            np.dot(self._td_residuals, self._td_residuals.T), tf.float32)
        self._td_residuals = tf.cast(self._td_residuals, tf.float32)
Example #16
  def solve(self,
            dataset: dataset_lib.OffpolicyDataset,
            target_policy: tf_policy.TFPolicy,
            regularizer: float = 1e-8):
    """Solves for Q-values and then approximates target policy value.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.
      regularizer: A small constant to add before dividing.

    Returns:
      Estimated average per-step reward of the target policy.
    """
    num_estimates = 1 + int(self._num_qvalues)
    transition_matrix = np.zeros(
        [self._dimension, self._dimension, num_estimates])
    reward_vector = np.zeros(
        [self._dimension, num_estimates, self._num_perturbations])
    total_weights = np.zeros([self._dimension, num_estimates])

    episodes, valid_steps = dataset.get_all_episodes(limit=self._limit_episodes)
    #all_rewards = self._reward_fn(episodes)
    #reward_std = np.ma.MaskedArray(all_rewards, valid_steps).std()
    tfagents_episodes = dataset_lib.convert_to_tfagents_timestep(episodes)

    sample_weights = np.array(valid_steps, dtype=np.int64)
    if not self._bootstrap or self._num_qvalues is None:
      sample_weights = (
          sample_weights[:, :, None] * np.ones([1, 1, num_estimates]))
    else:
      probs = np.reshape(sample_weights, [-1]) / np.sum(sample_weights)
      weights = np.random.multinomial(
          np.sum(sample_weights), probs,
          size=self._num_qvalues).astype(np.float32)
      weights = np.reshape(
          np.transpose(weights),
          list(np.shape(sample_weights)) + [self._num_qvalues])
      sample_weights = np.concatenate([sample_weights[:, :, None], weights],
                                      axis=-1)

    for episode_num in range(tf.shape(valid_steps)[0]):
      # Precompute probabilities for this episode.
      this_episode = tf.nest.map_structure(lambda t: t[episode_num], episodes)
      this_tfagents_episode = dataset_lib.convert_to_tfagents_timestep(
          this_episode)
      episode_target_log_probabilities = target_policy.distribution(
          this_tfagents_episode).action.log_prob(this_episode.action)
      episode_target_probs = target_policy.distribution(
          this_tfagents_episode).action.probs_parameter()

      for step_num in range(tf.shape(valid_steps)[1] - 1):
        this_step = tf.nest.map_structure(lambda t: t[episode_num, step_num],
                                          episodes)
        next_step = tf.nest.map_structure(
            lambda t: t[episode_num, step_num + 1], episodes)
        this_tfagents_step = dataset_lib.convert_to_tfagents_timestep(this_step)
        next_tfagents_step = dataset_lib.convert_to_tfagents_timestep(next_step)
        this_weights = sample_weights[episode_num, step_num, :]
        if this_step.is_last() or not valid_steps[episode_num, step_num]:
          continue

        weight = this_weights
        this_index = self._get_index(this_step.observation, this_step.action)

        reward_vector[this_index, :, :] += np.expand_dims(
            self._reward_fn(this_step) * weight, -1)
        if self._num_qvalues is not None:
          random_noise = np.random.binomial(this_weights[1:].astype('int64'),
                                            0.5)
          reward_vector[this_index, 1:, :] += (
              self._perturbation_scale[None, :] *
              (2 * random_noise - this_weights[1:])[:, None])

        total_weights[this_index] += weight

        policy_ratio = 1.0
        if not self._solve_for_state_action_value:
          policy_ratio = tf.exp(episode_target_log_probabilities[step_num] -
                                this_step.get_log_probability())

        # Need to weight next nu by importance weight.
        next_weight = (
            weight if self._solve_for_state_action_value else policy_ratio *
            weight)
        if next_step.is_absorbing():
          next_index = -1  # Absorbing state.
          transition_matrix[this_index, next_index] += next_weight
        else:
          next_probs = episode_target_probs[step_num + 1]
          for next_action, next_prob in enumerate(next_probs):
            next_index = self._get_index(next_step.observation, next_action)
            transition_matrix[this_index, next_index] += next_prob * next_weight
    print('Done processing data.')

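    # Normalize transition counts into probabilities and average rewards,
    # then solve the linear system(s) for the Q-values.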
    transition_matrix /= (regularizer + total_weights)[:, None, :]
    reward_vector /= (regularizer + total_weights)[:, :, None]
    reward_vector[np.where(np.equal(total_weights,
                                    0.0))] = self._default_reward_value
    reward_vector[-1, :, :] = 0.0  # Terminal absorbing state has 0 reward.

    self._point_qvalues = np.linalg.solve(
        np.eye(self._dimension) - self._gamma * transition_matrix[:, :, 0],
        reward_vector[:, 0])
    if self._num_qvalues is not None:
      self._ensemble_qvalues = np.linalg.solve(
          (np.eye(self._dimension) -
           self._gamma * np.transpose(transition_matrix, [2, 0, 1])),
          np.transpose(reward_vector, [1, 0, 2]))

    return self.estimate_average_reward(dataset, target_policy)
Example #17
    def solve_nu_zeta(self,
                      dataset: dataset_lib.OffpolicyDataset,
                      target_policy: tf_policy.TFPolicy,
                      regularizer: float = 1e-6):
        """Solves for density ratios and then approximates target policy value.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.
      regularizer: A small constant to add to matrices before inverting them or
        to floats before taking square root.

    Returns:
      Estimated average per-step reward of the target policy.
    """

        if not hasattr(self, '_td_mat'):
            # Set up env_steps.
            episodes, valid_steps = dataset.get_all_episodes(
                limit=self._limit_episodes)
            total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
            num_episodes = tf.shape(valid_steps)[0]
            num_samples = num_episodes * total_num_steps_per_episode
            valid_and_not_last = tf.logical_and(valid_steps,
                                                episodes.discount > 0)
            valid_indices = tf.squeeze(
                tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])))

            initial_env_step = tf.nest.map_structure(
                lambda t: tf.squeeze(
                    tf.reshape(
                        tf.repeat(t[:, 0:1, ...],
                                  axis=1,
                                  repeats=total_num_steps_per_episode),
                        [num_samples, -1])), episodes)
            initial_env_step = tf.nest.map_structure(
                lambda t: tf.gather(t, valid_indices), initial_env_step)
            tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
                initial_env_step)

            env_step = tf.nest.map_structure(
                lambda t: tf.squeeze(
                    tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                               [num_samples, -1])), episodes)
            env_step = tf.nest.map_structure(
                lambda t: tf.gather(t, valid_indices), env_step)
            tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(
                env_step)

            next_env_step = tf.nest.map_structure(
                lambda t: tf.squeeze(
                    tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                               [num_samples, -1])), episodes)
            next_env_step = tf.nest.map_structure(
                lambda t: tf.gather(t, valid_indices), next_env_step)
            tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
                next_env_step)

            # get probabilities
            initial_target_probs = target_policy.distribution(
                tfagents_initial_env_step).action.probs_parameter()
            next_target_probs = target_policy.distribution(
                tfagents_next_env_step).action.probs_parameter()

            # First, get the nu_loss and data weights
            #current_nu_loss = self._get_nu_loss(initial_env_step, env_step,
            #                                    next_env_step, target_policy)
            #data_weight, _ = self._get_weights(current_nu_loss)

            # # debug only and to reproduce dual dice result, DELETE
            # data_weight = tf.ones_like(data_weight)

            state_action_count = self._get_state_action_counts(env_step)
            counts = tf.reduce_sum(
                tf.one_hot(state_action_count, self._dimension), 0)
            gamma_sample = tf.pow(self._gamma,
                                  tf.cast(env_step.step_num, tf.float32))

            # # debug only and to reproduce dual dice result, DELETE
            # gamma_sample = tf.ones_like(gamma_sample)

            # now we need to expand_dims to include action space in extra dimensions
            #data_weights = tf.reshape(data_weight, [-1, self._num_limits])
            # both are data sample weights for L2 problem, needs to be normalized later
            #gamma_data_weights = tf.reshape(gamma_sample, [-1, 1]) * data_weights

            initial_states = tf.tile(
                tf.reshape(initial_env_step.observation, [-1, 1]),
                [1, self._num_actions])
            initial_actions = tf.tile(
                tf.reshape(tf.range(self._num_actions), [1, -1]),
                [initial_env_step.observation.shape[0], 1])
            initial_nu_indices = self._get_index(initial_states,
                                                 initial_actions)

            # linear term w.r.t. initial distribution
            #b_vec_2 = tf.stack([
            #    tf.reduce_sum(
            #        tf.reshape(
            #            data_weights[:, itr] / tf.reduce_sum(data_weights[:, itr]),
            #            [-1, 1]) * tf.reduce_sum(
            #                tf.one_hot(initial_nu_indices, self._dimension) *
            #                (1 - self._gamma) *
            #                tf.expand_dims(initial_target_probs, axis=-1),
            #                axis=1),
            #        axis=0) for itr in range(self._num_limits)
            #],
            #                   axis=0)

            next_states = tf.tile(
                tf.reshape(next_env_step.observation, [-1, 1]),
                [1, self._num_actions])
            next_actions = tf.tile(
                tf.reshape(tf.range(self._num_actions), [1, -1]),
                [next_env_step.observation.shape[0], 1])
            next_nu_indices = self._get_index(next_states, next_actions)
            next_nu_indices = tf.where(
                tf.expand_dims(next_env_step.is_absorbing(), -1),
                -1 * tf.ones_like(next_nu_indices), next_nu_indices)

            nu_indices = self._get_index(env_step.observation, env_step.action)

            target_log_probabilities = target_policy.distribution(
                tfagents_env_step).action.log_prob(env_step.action)
            if not self._solve_for_state_action_ratio:
                policy_ratio = tf.exp(target_log_probabilities -
                                      env_step.get_log_probability())
            else:
                policy_ratio = tf.ones([
                    target_log_probabilities.shape[0],
                ])
            policy_ratios = tf.tile(tf.reshape(policy_ratio, [-1, 1]),
                                    [1, self._num_actions])

            # the tabular feature vector
            a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
                self._gamma *
                tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
                tf.one_hot(next_nu_indices, self._dimension),
                axis=1)

            # linear term w.r.t. reward
            #b_vec_1 = tf.stack([
            #    tf.reduce_sum(
            #        tf.reshape(
            #            (gamma_data_weights[:, itr] /
            #             tf.reduce_sum(gamma_data_weights[:, itr])) * self._reward_fn(env_step), #/
            #            #tf.cast(state_action_count, tf.float32),
            #            [-1, 1]) * a_vec,
            #        axis=0) for itr in range(self._num_limits)
            #],
            #                   axis=0)
            # quadratic term of feature
            # Get weighted outer product by using einsum to save computing resource!
            #a_mat = tf.stack([
            #    tf.einsum(
            #        'ai, a, aj -> ij', a_vec,
            #        #1.0 / tf.cast(state_action_count, tf.float32),
            #        gamma_data_weights[:, itr] /
            #        tf.reduce_sum(gamma_data_weights[:, itr]),
            #        a_vec)
            #    for itr in range(self._num_limits)
            #],
            #                 axis=0)

            td_mat = tf.einsum('ai, a, aj -> ij',
                               tf.one_hot(nu_indices, self._dimension),
                               1.0 / tf.cast(state_action_count, tf.float32),
                               a_vec)

            weighted_rewards = policy_ratio * self._reward_fn(env_step)

            bias = tf.reduce_sum(
                tf.one_hot(nu_indices, self._dimension) *
                tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
                tf.cast(state_action_count, tf.float32)[:, None],
                axis=0)

            # Initialize
            self._nu = np.ones_like(self._nu) * bias[:, None]
            self._nu2 = np.ones_like(self._nu2) * bias[:, None]

            self._a_vec = a_vec
            self._td_mat = td_mat
            self._bias = bias
            self._weighted_rewards = weighted_rewards
            self._state_action_count = state_action_count
            self._nu_indices = nu_indices
            self._initial_nu_indices = initial_nu_indices
            self._initial_target_probs = initial_target_probs
            self._gamma_sample = gamma_sample
            self._gamma_sample = tf.ones_like(gamma_sample)

        saddle_bellman_residuals = (tf.matmul(self._a_vec, self._nu) -
                                    self._weighted_rewards[:, None])
        saddle_bellman_residuals *= -1 * self._algae_alpha_sign
        saddle_zetas = tf.gather(self._zeta, self._nu_indices)
        saddle_initial_nu_values = tf.reduce_sum(  # Average over actions.
            self._initial_target_probs[:, :, None] *
            tf.gather(self._nu, self._initial_nu_indices),
            axis=1)
        saddle_init_nu_loss = ((1 - self._gamma) * saddle_initial_nu_values *
                               self._algae_alpha_sign)

        saddle_bellman_residuals2 = (tf.matmul(self._a_vec, self._nu2) -
                                     self._weighted_rewards[:, None])
        saddle_bellman_residuals2 *= 1 * self._algae_alpha_sign
        saddle_zetas2 = tf.gather(self._zeta2, self._nu_indices)
        saddle_initial_nu_values2 = tf.reduce_sum(  # Average over actions.
            self._initial_target_probs[:, :, None] *
            tf.gather(self._nu2, self._initial_nu_indices),
            axis=1)
        saddle_init_nu_loss2 = ((1 - self._gamma) * saddle_initial_nu_values2 *
                                -1 * self._algae_alpha_sign)

        saddle_loss = 0.5 * (
            saddle_init_nu_loss + saddle_bellman_residuals * saddle_zetas +
            -tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas) +
            -saddle_init_nu_loss2 + -saddle_bellman_residuals2 * saddle_zetas2
            + tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas2))
        # Binary search to find best alpha.
        left = tf.constant([-8., -8.])
        right = tf.constant([32., 32.])
        for _ in range(16):
            mid = 0.5 * (left + right)
            self._alpha.assign(mid)
            weights, log_weights = self._get_weights(
                saddle_loss * self._gamma_sample[:, None])

            divergence = self._compute_divergence(weights, log_weights)
            divergence_violation = divergence - self._two_sided_limit
            left = tf.where(divergence_violation > 0., mid, left)
            right = tf.where(divergence_violation > 0., right, mid)
        self._alpha.assign(0.5 * (left + right))
        weights, log_weights = self._get_weights(saddle_loss *
                                                 self._gamma_sample[:, None])

        gamma_data_weights = tf.stop_gradient(weights *
                                              self._gamma_sample[:, None])
        #print(tf.concat([gamma_data_weights, saddle_loss], axis=-1))
        avg_saddle_loss = (
            tf.reduce_sum(gamma_data_weights * saddle_loss, axis=0) /
            tf.reduce_sum(gamma_data_weights, axis=0))

        weighted_state_action_count = tf.reduce_sum(
            tf.one_hot(self._nu_indices, self._dimension)[:, :, None] *
            weights[:, None, :],
            axis=0)
        weighted_state_action_count = tf.gather(weighted_state_action_count,
                                                self._nu_indices)
        my_td_mat = tf.einsum(
            'ai, ab, ab, aj -> bij',
            tf.one_hot(self._nu_indices, self._dimension),
            #1.0 / tf.cast(self._state_action_count, tf.float32),
            1.0 / weighted_state_action_count,
            weights,
            self._a_vec)
        my_bias = tf.reduce_sum(
            tf.transpose(weights)[:, :, None] *
            tf.one_hot(self._nu_indices, self._dimension)[None, :, :] *
            tf.reshape(self._weighted_rewards, [1, -1, 1]) *
            #1.0 / tf.cast(self._state_action_count, tf.float32)[None, :, None],
            1.0 / tf.transpose(weighted_state_action_count)[:, :, None],
            axis=1)

        #print('hello', saddle_initial_nu_values[:1], saddle_zetas[:3],
        #      self._nu[:2], my_bias[:, :2], saddle_loss[:4])

        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch([self._nu, self._nu2, self._alpha])
            bellman_residuals = tf.matmul(
                my_td_mat,
                tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
            bellman_residuals = tf.transpose(tf.squeeze(bellman_residuals, -1))
            bellman_residuals = tf.gather(bellman_residuals, self._nu_indices)
            initial_nu_values = tf.reduce_sum(  # Average over actions.
                self._initial_target_probs[:, :, None] *
                tf.gather(self._nu, self._initial_nu_indices),
                axis=1)

            bellman_residuals *= self._algae_alpha_sign

            init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                            self._algae_alpha_sign)

            nu_loss = (tf.math.square(bellman_residuals) / 2.0 +
                       tf.math.abs(self._algae_alpha) * init_nu_loss)

            loss = (gamma_data_weights * nu_loss /
                    tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

            bellman_residuals2 = tf.matmul(
                my_td_mat,
                tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
            bellman_residuals2 = tf.transpose(
                tf.squeeze(bellman_residuals2, -1))
            bellman_residuals2 = tf.gather(bellman_residuals2,
                                           self._nu_indices)
            initial_nu_values2 = tf.reduce_sum(  # Average over actions.
                self._initial_target_probs[:, :, None] *
                tf.gather(self._nu2, self._initial_nu_indices),
                axis=1)

            bellman_residuals2 *= -1 * self._algae_alpha_sign

            init_nu_loss2 = ((1 - self._gamma) * initial_nu_values2 * -1 *
                             self._algae_alpha_sign)

            nu_loss2 = (tf.math.square(bellman_residuals2) / 2.0 +
                        tf.math.abs(self._algae_alpha) * init_nu_loss2)

            loss2 = (gamma_data_weights * nu_loss2 /
                     tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

            divergence = self._compute_divergence(weights, log_weights)
            divergence_violation = divergence - self._two_sided_limit

            alpha_loss = (-tf.exp(self._alpha) *
                          tf.stop_gradient(divergence_violation))

            extra_loss = tf.reduce_sum(tf.math.square(self._nu[-1, :]))
            extra_loss2 = tf.reduce_sum(tf.math.square(self._nu2[-1, :]))
            nu_grad = tape.gradient(loss + extra_loss, [self._nu])[0]
            nu_grad2 = tape.gradient(loss2 + extra_loss2, [self._nu2])[0]
        avg_loss = tf.reduce_sum(0.5 * (loss - loss2) /
                                 tf.math.abs(self._algae_alpha),
                                 axis=0)
        nu_jacob = tape.jacobian(nu_grad, [self._nu])[0]
        nu_hess = tf.stack(
            [nu_jacob[:, i, :, i] for i in range(self._num_limits)], axis=0)

        nu_jacob2 = tape.jacobian(nu_grad2, [self._nu2])[0]
        nu_hess2 = tf.stack(
            [nu_jacob2[:, i, :, i] for i in range(self._num_limits)], axis=0)

        for idx, div in enumerate(divergence):
            tf.summary.scalar('divergence%d' % idx, div)

        #alpha_grads = tape.gradient(alpha_loss, [self._alpha])
        #alpha_grad_op = self._alpha_optimizer.apply_gradients(
        #    zip(alpha_grads, [self._alpha]))
        #self._alpha.assign(tf.minimum(8., tf.maximum(-8., self._alpha)))

        #print(self._alpha, tf.concat([weights, nu_loss], -1))
        #regularizer = 0.1
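        # Damped Newton updates: solve (Hessian + regularizer * I) x = -gradient
        # and move nu (and nu2 below) by a step of 0.1 * x.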
        nu_transformed = tf.transpose(
            tf.squeeze(
                tf.linalg.solve(
                    nu_hess + regularizer * tf.eye(self._dimension),
                    tf.expand_dims(-tf.transpose(nu_grad), axis=-1))))
        self._nu = self._nu + 0.1 * nu_transformed
        nu_transformed2 = tf.transpose(
            tf.squeeze(
                tf.linalg.solve(
                    nu_hess2 + regularizer * tf.eye(self._dimension),
                    tf.expand_dims(-tf.transpose(nu_grad2), axis=-1))))
        self._nu2 = self._nu2 + 0.1 * nu_transformed2

        print(avg_loss * self._algae_alpha_sign,
              avg_saddle_loss * self._algae_alpha_sign, self._nu[:2],
              divergence)
        #print(init_nu_loss[:8], init_nu_loss[-8:])
        #print(bellman_residuals[:8])
        #print(self._nu[:3], self._zeta[:3])

        zetas = tf.matmul(my_td_mat,
                          tf.transpose(self._nu)[:, :, None]) - my_bias[:, :,
                                                                        None]
        zetas = tf.transpose(tf.squeeze(zetas, -1))
        zetas *= -self._algae_alpha_sign
        zetas /= tf.math.abs(self._algae_alpha)
        self._zeta = self._zeta + 0.1 * (zetas - self._zeta)

        zetas2 = tf.matmul(my_td_mat,
                           tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :,
                                                                          None]
        zetas2 = tf.transpose(tf.squeeze(zetas2, -1))
        zetas2 *= 1 * self._algae_alpha_sign
        zetas2 /= tf.math.abs(self._algae_alpha)
        self._zeta2 = self._zeta2 + 0.1 * (zetas2 - self._zeta2)

        #self._zeta = (
        #    tf.einsum('ij,ja-> ia', self._td_mat, self._nu) -
        #    tf.transpose(my_bias))
        #self._zeta *= -tf.reshape(self._algae_alpha_sign, [1, self._num_limits])
        #self._zeta /= tf.math.abs(self._algae_alpha)
        return [
            avg_saddle_loss * self._algae_alpha_sign,
            avg_loss * self._algae_alpha_sign, divergence
        ]
Example #18
  def prepare_dataset(self, dataset: dataset_lib.OffpolicyDataset,
                      target_policy: tf_policy.TFPolicy):
    """Performs pre-computations on dataset to make solving easier."""
    episodes, valid_steps = dataset.get_all_episodes(limit=self._limit_episodes)
    total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
    num_episodes = tf.shape(valid_steps)[0]
    num_samples = num_episodes * total_num_steps_per_episode
    valid_and_not_last = tf.logical_and(valid_steps, episodes.discount > 0)
    valid_indices = tf.squeeze(
        tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])))

    # Flatten all tensors so that each data sample is a tuple of
    # (initial_env_step, env_step, next_env_step).
    initial_env_step = tf.nest.map_structure(
        lambda t: tf.squeeze(
            tf.reshape(
                tf.repeat(
                    t[:, 0:1, ...], axis=1, repeats=total_num_steps_per_episode
                ), [num_samples, -1])), episodes)
    initial_env_step = tf.nest.map_structure(
        lambda t: tf.gather(t, valid_indices), initial_env_step)
    tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
        initial_env_step)

    env_step = tf.nest.map_structure(
        lambda t: tf.squeeze(
            tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                       [num_samples, -1])), episodes)
    env_step = tf.nest.map_structure(lambda t: tf.gather(t, valid_indices),
                                     env_step)
    tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

    next_env_step = tf.nest.map_structure(
        lambda t: tf.squeeze(
            tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                       [num_samples, -1])), episodes)
    next_env_step = tf.nest.map_structure(lambda t: tf.gather(t, valid_indices),
                                          next_env_step)
    tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
        next_env_step)

    # Get target probabilities for initial and next steps.
    initial_target_probs = target_policy.distribution(
        tfagents_initial_env_step).action.probs_parameter()
    next_target_probs = target_policy.distribution(
        tfagents_next_env_step).action.probs_parameter()

    # Map states and actions to indices into tabular representation.
    initial_states = tf.tile(
        tf.reshape(initial_env_step.observation, [-1, 1]),
        [1, self._num_actions])
    initial_actions = tf.tile(
        tf.reshape(tf.range(self._num_actions), [1, -1]),
        [initial_env_step.observation.shape[0], 1])
    initial_nu_indices = self._get_index(initial_states, initial_actions)

    next_states = tf.tile(
        tf.reshape(next_env_step.observation, [-1, 1]), [1, self._num_actions])
    next_actions = tf.tile(
        tf.reshape(tf.range(self._num_actions), [1, -1]),
        [next_env_step.observation.shape[0], 1])
    next_nu_indices = self._get_index(next_states, next_actions)
    next_nu_indices = tf.where(
        tf.expand_dims(next_env_step.is_absorbing(), -1),
        -1 * tf.ones_like(next_nu_indices), next_nu_indices)

    nu_indices = self._get_index(env_step.observation, env_step.action)

    target_log_probabilities = target_policy.distribution(
        tfagents_env_step).action.log_prob(env_step.action)
    if not self._solve_for_state_action_ratio:
      policy_ratio = tf.exp(target_log_probabilities -
                            env_step.get_log_probability())
    else:
      policy_ratio = tf.ones([
          target_log_probabilities.shape[0],
      ])
    policy_ratios = tf.tile(
        tf.reshape(policy_ratio, [-1, 1]), [1, self._num_actions])

    # Bellman residual matrix of size [n_data, n_dim].
    a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
        self._gamma *
        tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
        tf.one_hot(next_nu_indices, self._dimension),
        axis=1)

    state_action_count = self._get_state_action_counts(env_step)
    # Bellman residual matrix of size [n_dim, n_dim].
    td_mat = tf.einsum('ai, a, aj -> ij', tf.one_hot(nu_indices,
                                                     self._dimension),
                       1.0 / tf.cast(state_action_count, tf.float32), a_vec)

    # Reward vector of size [n_data].
    weighted_rewards = policy_ratio * self._reward_fn(env_step)

    # Reward vector of size [n_dim].
    bias = tf.reduce_sum(
        tf.one_hot(nu_indices, self._dimension) *
        tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
        tf.cast(state_action_count, tf.float32)[:, None],
        axis=0)

    # Initialize.
    self._nu = np.ones_like(self._nu) * bias[:, None]
    self._nu2 = np.ones_like(self._nu2) * bias[:, None]

    self._a_vec = a_vec
    self._td_mat = td_mat
    self._bias = bias
    self._weighted_rewards = weighted_rewards
    self._state_action_count = state_action_count
    self._nu_indices = nu_indices
    self._initial_nu_indices = initial_nu_indices
    self._initial_target_probs = initial_target_probs