Example #1
import typing

from tf_agents.environments.tf_py_environment import TFPyEnvironment
from tf_agents.policies import tf_policy
from tf_agents.replay_buffers.replay_buffer import ReplayBuffer
from tf_agents.trajectories import trajectory

def step(
    environment: TFPyEnvironment, policy: tf_policy.TFPolicy, replay_buffer: ReplayBuffer
) -> typing.Tuple[float, bool]:
    """Advance `environment` by one step under `policy` and record the transition."""
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    replay_buffer.add_batch(traj)
    # Convert the batched tensors to Python scalars (assumes a batch size of 1).
    return next_time_step.reward.numpy()[0], bool(next_time_step.is_last())
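
A minimal usage sketch, assuming a Gym CartPole environment, a random policy, and a TF-Agents
uniform replay buffer (the setup below is illustrative, not taken from the original):

from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer

env = tf_py_environment.TFPyEnvironment(suite_gym.load("CartPole-v0"))
policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(), env.action_spec())
buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    policy.trajectory_spec, batch_size=env.batch_size, max_length=1000)

env.reset()
episode_reward, done = 0.0, False
while not done:
    # Each call records one transition and reports the step reward and termination flag.
    reward, done = step(env, policy, buffer)
    episode_reward += reward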
Example #2
def decorate_policy_with_particles(policy: TFPolicy,
                                   number_of_particles: int) -> TFPolicy:
    """
    Decorate a policy's `action` method so that the action for each element of the batch is
    duplicated over a set of particles.

    :param policy: An instance of `tf_policy.TFPolicy` representing the agent's current policy.
    :param number_of_particles: Number of Monte Carlo rollouts of each action trajectory.

    :return: A decorated policy.
    """
    assert isinstance(
        policy,
        (CrossEntropyMethodPolicy, RandomTFPolicy
         )), "Particles can only be used with state-unconditioned policies."

    def _wrapper(
        action_method: Callable[[TimeStep, NestedTensor, Optional[Seed]],
                                PolicyStep]):
        def action_method_wrapper(
                time_step: TimeStep,
                policy_state: NestedTensor = (),
                seed: Optional[Seed] = None) -> PolicyStep:
            """
            The incoming `time_step` has a batch size of `population_size * number_of_particles`.
            This function reduces the batch size of `time_step` to be equal to `population_size`
            only. It does not matter which observations are retained because the policy must be
            state-unconditioned.

            The reduced time step is passed to the policy, and then each action is duplicated
            `number_of_particles` times to create a batch of
            `population_size * number_of_particles` actions.
            """
            reduced_time_step = split_nested_tensors(time_step,
                                                     policy.time_step_spec,
                                                     number_of_particles)[0]

            policy_step = action_method(reduced_time_step, policy_state, seed)
            actions = policy_step.action

            tiled_actions = tile_batch(actions, number_of_particles)

            return policy_step.replace(action=tiled_actions)

        return action_method_wrapper

    policy.action = _wrapper(policy.action)
    return policy
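
The core of the decoration is the tiling step. A standalone sketch of that idea in plain
TensorFlow, with toy shapes (this assumes `tile_batch` behaves like `tf.repeat` along the
batch axis, which matches its usual TF-Agents/seq2seq semantics):

import tensorflow as tf

population_size, number_of_particles, action_dim = 3, 4, 2
actions = tf.random.uniform([population_size, action_dim])
# Repeat each action `number_of_particles` times so the action batch matches an
# environment batch of size population_size * number_of_particles.
tiled_actions = tf.repeat(actions, repeats=number_of_particles, axis=0)
assert tiled_actions.shape == (population_size * number_of_particles, action_dim)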
Example #3
def policy_evaluation(
    environment: TFEnvironment,
    policy: TFPolicy,
    num_episodes: int = 1,
    max_buffer_capacity: int = 200,
    use_function: bool = True,
) -> Trajectory:
    """
    Evaluate `policy` on the `environment`.

    :param environment: A `TFEnvironment` instance.
    :param policy: A `TFPolicy` instance used to step the environment.
    :param num_episodes: Number of episodes to compute the metrics over.
    :param max_buffer_capacity: Maximum capacity of the replay buffer.
    :param use_function: Option to enable use of `tf.function` when collecting the trajectory.
    :return: The recorded `Trajectory`.
    """

    time_step = environment.reset()
    policy_state = policy.get_initial_state(environment.batch_size)

    buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        policy.trajectory_spec,
        batch_size=environment.batch_size,
        max_length=max_buffer_capacity,
    )

    driver = dynamic_episode_driver.DynamicEpisodeDriver(
        environment,
        policy,
        observers=[buffer.add_batch],
        num_episodes=num_episodes)
    if use_function:
        common.function(driver.run)(time_step, policy_state)
    else:
        driver.run(time_step, policy_state)

    return buffer.gather_all()
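
A usage sketch, assuming a Gym CartPole environment and a random policy (the names below are
illustrative):

import tensorflow as tf
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.policies import random_tf_policy

env = tf_py_environment.TFPyEnvironment(suite_gym.load("CartPole-v0"))
policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(), env.action_spec())
recorded = policy_evaluation(env, policy, num_episodes=2, use_function=True)
# `recorded` is a batched Trajectory, e.g. the total undiscounted reward collected:
total_reward = tf.reduce_sum(recorded.reward)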
Example #4
def sample_uniformly_distributed_observations_and_get_actions(
        policy: TFPolicy, number_of_samples: int):
    """
    Sample observations from a uniform distribution over the space of observations, and then get
    corresponding actions from the policy.

    :param policy: A policy, instance of `TFPolicy`, from which observations and actions are
                   sampled.
    :param number_of_samples: Number of observation action pairs that will be sampled.

    :return: A `dict` with keys 'observations' and 'actions'.
    """
    observation_distribution = create_uniform_distribution_from_spec(
        policy.time_step_spec.observation)

    observations = observation_distribution.sample((number_of_samples, ))
    rewards = tf.zeros((number_of_samples, ), dtype=tf.float32)

    time_step = transition(observations, rewards)

    actions = policy.action(time_step).action

    return {"observations": observations, "actions": actions}
Example #5
    def __init__(self,
                 policy: tf_policy.TFPolicy,
                 batch_size: Optional[int] = None,
                 use_nest_path_signatures: bool = True,
                 seed: Optional[types.Seed] = None,
                 train_step: Optional[tf.Variable] = None,
                 input_fn_and_spec: Optional[InputFnAndSpecType] = None,
                 metadata: Optional[Dict[Text, tf.Variable]] = None):
        """Initialize PolicySaver for  TF policy `policy`.

    Args:
      policy: A TF Policy.
      batch_size: The number of batch entries the policy will process at a time.
        This must be either `None` (unknown batch size) or a python integer.
      use_nest_path_signatures: SavedModel spec signatures will be created based
        on the structure of the specs. Otherwise, all specs must have unique
        names.
      seed: Random seed for the `policy.action` call, if any (this should
        usually be `None`, except for testing).
      train_step: Variable holding the train step for the policy. The value
        saved will be set at the time `saver.save` is called. If not provided,
        train_step defaults to -1. Note that since the train step must be a
        variable, it is not safe to create it directly in TF1, so in that case
        this is a required parameter.
      input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where input_fn is a
        function that takes inputs according to tensor_spec and converts them to
        the `(time_step, policy_state)` tuple that is used as the input to the
        action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the input
        for the action signature. When `input_fn_and_spec is None`, the action
        signature takes as input `(time_step, policy_state)`.
      metadata: A dictionary of `tf.Variables` to be saved along with the
        policy.

    Raises:
      TypeError: If `policy` is not an instance of TFPolicy.
      TypeError: If `metadata` is not a dictionary of tf.Variables.
      ValueError: If use_nest_path_signatures is not used and any of the
        following `policy` specs are missing names, or the names collide:
        `policy.time_step_spec`, `policy.action_spec`,
        `policy.policy_state_spec`, `policy.info_spec`.
      ValueError: If `batch_size` is not either `None` or a python integer > 0.
    """
        if not isinstance(policy, tf_policy.TFPolicy):
            raise TypeError('policy is not a TFPolicy.  Saw: %s' %
                            type(policy))
        if (batch_size is not None
                and (not isinstance(batch_size, int) or batch_size < 1)):
            raise ValueError(
                'Expected batch_size == None or python int > 0, saw: %s' %
                (batch_size, ))

        action_fn_input_spec = (policy.time_step_spec,
                                policy.policy_state_spec)
        if use_nest_path_signatures:
            action_fn_input_spec = _rename_spec_with_nest_paths(
                action_fn_input_spec)
        else:
            _check_spec(action_fn_input_spec)

        # Make a shallow copy as we'll be making some changes in-place.
        saved_policy = tf.Module()
        saved_policy.collect_data_spec = copy.copy(policy.collect_data_spec)
        saved_policy.policy_state_spec = copy.copy(policy.policy_state_spec)

        if train_step is None:
            if not common.has_eager_been_enabled():
                raise ValueError('train_step is required in TF1 and must be a '
                                 '`tf.Variable`: %s' % train_step)
            train_step = tf.Variable(
                -1,
                trainable=False,
                dtype=tf.int64,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                shape=())
        elif not isinstance(train_step, tf.Variable):
            raise ValueError('train_step must be a TensorFlow variable: %s' %
                             train_step)

        # We will need the train step for the Checkpoint object.
        self._train_step = train_step
        saved_policy.train_step = self._train_step

        self._metadata = metadata or {}
        for key, value in self._metadata.items():
            if not isinstance(key, str):
                raise TypeError('Keys of metadata must be strings: %s' % key)
            if not isinstance(value, tf.Variable):
                raise TypeError('Values of metadata must be tf.Variable: %s' %
                                value)
        saved_policy.metadata = self._metadata

        if batch_size is None:
            get_initial_state_fn = policy.get_initial_state
            get_initial_state_input_specs = (tf.TensorSpec(
                dtype=tf.int32, shape=(), name='batch_size'), )
        else:
            get_initial_state_fn = functools.partial(policy.get_initial_state,
                                                     batch_size=batch_size)
            get_initial_state_input_specs = ()

        get_initial_state_fn = common.function()(get_initial_state_fn)

        original_action_fn = policy.action

        if seed is not None:

            def action_fn(time_step, policy_state):
                time_step = cast(ts.TimeStep, time_step)
                return original_action_fn(time_step, policy_state, seed=seed)
        else:
            action_fn = original_action_fn

        def distribution_fn(time_step, policy_state):
            """Wrapper for policy.distribution() in the SavedModel."""
            try:
                time_step = cast(ts.TimeStep, time_step)
                outs = policy.distribution(time_step=time_step,
                                           policy_state=policy_state)
                return tf.nest.map_structure(_composite_distribution, outs)
            except (TypeError, NotImplementedError) as e:
                # TODO(b/156526399): Move this to just the policy.distribution() call
                # once tfp.experimental.as_composite() properly handles LinearOperator*
                # components as well as TransformedDistributions.
                logging.warning(
                    'WARNING: Could not serialize policy.distribution() for policy '
                    '"%s". Calling saved_model.distribution() will raise the following '
                    'assertion error: %s', policy, e)

                @common.function()
                def _raise():
                    tf.Assert(False, [str(e)])
                    return ()

                outs = _raise()

        # We call get_concrete_function() for its side effect: to ensure the proper
        # ConcreteFunction is stored in the SavedModel.
        get_initial_state_fn.get_concrete_function(
            *get_initial_state_input_specs)

        train_step_fn = common.function(
            lambda: saved_policy.train_step).get_concrete_function()
        get_metadata_fn = common.function(
            lambda: saved_policy.metadata).get_concrete_function()

        batched_time_step_spec = tf.nest.map_structure(
            lambda spec: add_batch_dim(spec, [batch_size]),
            policy.time_step_spec)
        batched_time_step_spec = cast(ts.TimeStep, batched_time_step_spec)
        batched_policy_state_spec = tf.nest.map_structure(
            lambda spec: add_batch_dim(spec, [batch_size]),
            policy.policy_state_spec)

        policy_step_spec = policy.policy_step_spec
        policy_state_spec = policy.policy_state_spec

        if use_nest_path_signatures:
            batched_time_step_spec = _rename_spec_with_nest_paths(
                batched_time_step_spec)
            batched_policy_state_spec = _rename_spec_with_nest_paths(
                batched_policy_state_spec)
            policy_step_spec = _rename_spec_with_nest_paths(policy_step_spec)
            policy_state_spec = _rename_spec_with_nest_paths(policy_state_spec)
        else:
            _check_spec(batched_time_step_spec)
            _check_spec(batched_policy_state_spec)
            _check_spec(policy_step_spec)
            _check_spec(policy_state_spec)

        if input_fn_and_spec is not None:
            # Store a signature based on input_fn_and_spec
            @common.function()
            def polymorphic_action_fn(example):
                action_inputs = input_fn_and_spec[0](example)
                tf.nest.map_structure(
                    lambda spec, t: tf.Assert(spec.is_compatible_with(t[
                        0]), [t]), action_fn_input_spec, action_inputs)
                return action_fn(*action_inputs)

            @common.function()
            def polymorphic_distribution_fn(example):
                action_inputs = input_fn_and_spec[0](example)
                tf.nest.map_structure(
                    lambda spec, t: tf.Assert(spec.is_compatible_with(t[
                        0]), [t]), action_fn_input_spec, action_inputs)
                return distribution_fn(*action_inputs)

            batched_input_spec = tf.nest.map_structure(
                lambda spec: add_batch_dim(spec, [batch_size]),
                input_fn_and_spec[1])
            # We call get_concrete_function() for its side effect: to ensure the
            # proper ConcreteFunction is stored in the SavedModel.
            polymorphic_action_fn.get_concrete_function(
                example=batched_input_spec)
            polymorphic_distribution_fn.get_concrete_function(
                example=batched_input_spec)

            action_input_spec = (input_fn_and_spec[1], )

        else:
            action_input_spec = action_fn_input_spec
            if batched_policy_state_spec:
                # Store the signature with a required policy state spec
                polymorphic_action_fn = common.function()(action_fn)
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)

                polymorphic_distribution_fn = common.function()(
                    distribution_fn)
                polymorphic_distribution_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
            else:
                # Create a polymorphic action_fn which you can call as
                #  restored.action(time_step)
                # or
                #  restored.action(time_step, ())
                # (without retracing the inner action twice)
                @common.function()
                def polymorphic_action_fn(
                        time_step, policy_state=batched_policy_state_spec):
                    return action_fn(time_step, policy_state)

                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec)

                @common.function()
                def polymorphic_distribution_fn(
                        time_step, policy_state=batched_policy_state_spec):
                    return distribution_fn(time_step, policy_state)

                polymorphic_distribution_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
                polymorphic_distribution_fn.get_concrete_function(
                    time_step=batched_time_step_spec)

        signatures = {
            # CompositeTensors aren't well supported by old-style signature
            # mechanisms, so we do not have a signature for policy.distribution.
            'action':
            _function_with_flat_signature(polymorphic_action_fn,
                                          input_specs=action_input_spec,
                                          output_spec=policy_step_spec,
                                          include_batch_dimension=True,
                                          batch_size=batch_size),
            'get_initial_state':
            _function_with_flat_signature(
                get_initial_state_fn,
                input_specs=get_initial_state_input_specs,
                output_spec=policy_state_spec,
                include_batch_dimension=False),
            'get_train_step':
            _function_with_flat_signature(train_step_fn,
                                          input_specs=(),
                                          output_spec=train_step.dtype,
                                          include_batch_dimension=False),
            'get_metadata':
            _function_with_flat_signature(get_metadata_fn,
                                          input_specs=(),
                                          output_spec=tf.nest.map_structure(
                                              lambda v: v.dtype,
                                              self._metadata),
                                          include_batch_dimension=False),
        }

        saved_policy.action = polymorphic_action_fn
        saved_policy.distribution = polymorphic_distribution_fn
        saved_policy.get_initial_state = get_initial_state_fn
        saved_policy.get_train_step = train_step_fn
        saved_policy.get_metadata = get_metadata_fn
        # Adding variables as an attribute to facilitate updating them.
        saved_policy.model_variables = policy.variables()

        # TODO(b/156779400): Move to a public API for accessing all trackable leaf
        # objects (once it's available).  For now, we have no other way of tracking
        # objects like Tables, Vocabulary files, etc.
        try:
            saved_policy._all_assets = policy._unconditional_checkpoint_dependencies  # pylint: disable=protected-access
        except AttributeError as e:
            if '_self_unconditional' in str(e):
                logging.warning(
                    'Unable to capture all trackable objects in policy "%s".  This '
                    'may be okay.  Error: %s', policy, e)
            else:
                raise e

        self._policy = saved_policy
        self._signatures = signatures
        self._action_input_spec = action_input_spec
        self._policy_step_spec = policy_step_spec
        self._policy_state_spec = policy_state_spec
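
A usage sketch for the saver this constructor belongs to (TF-Agents `PolicySaver`); `agent`
and `example_time_step` are hypothetical placeholders:

import tensorflow as tf
from tf_agents.policies import policy_saver

saver = policy_saver.PolicySaver(agent.policy, batch_size=None)
saver.save('/tmp/saved_policy')

# Later, the policy can be restored without the original Python classes.
restored = tf.saved_model.load('/tmp/saved_policy')
policy_state = restored.get_initial_state(batch_size=1)
policy_step = restored.action(example_time_step, policy_state)
train_step = restored.get_train_step()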
Example #6
    def solve(self,
              dataset: dataset_lib.OffpolicyDataset,
              target_policy: tf_policy.TFPolicy,
              regularizer: float = 1e-8):
        """Solves for density ratios and then approximates target policy value.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.
      regularizer: A small constant to add to matrices before inverting them or
        to floats before taking square root.

    Returns:
      Estimated average per-step reward of the target policy.
    """
        td_residuals = np.zeros([self._dimension, self._dimension])
        total_weights = np.zeros([self._dimension])
        initial_weights = np.zeros([self._dimension])

        episodes, valid_steps = dataset.get_all_episodes(limit=None)
        tfagents_episodes = dataset_lib.convert_to_tfagents_timestep(episodes)

        for episode_num in range(tf.shape(valid_steps)[0]):
            # Precompute probabilities for this episode.
            this_episode = tf.nest.map_structure(lambda t: t[episode_num],
                                                 episodes)
            first_step = tf.nest.map_structure(lambda t: t[0], this_episode)
            this_tfagents_episode = dataset_lib.convert_to_tfagents_timestep(
                this_episode)
            episode_target_log_probabilities = target_policy.distribution(
                this_tfagents_episode).action.log_prob(this_episode.action)
            episode_target_probs = target_policy.distribution(
                this_tfagents_episode).action.probs_parameter()

            for step_num in range(tf.shape(valid_steps)[1] - 1):
                this_step = tf.nest.map_structure(
                    lambda t: t[episode_num, step_num], episodes)
                next_step = tf.nest.map_structure(
                    lambda t: t[episode_num, step_num + 1], episodes)
                if this_step.is_last() or not valid_steps[episode_num,
                                                          step_num]:
                    continue

                weight = 1.0
                nu_index = self._get_index(this_step.observation,
                                           this_step.action)
                td_residuals[nu_index, nu_index] += weight
                total_weights[nu_index] += weight

                policy_ratio = 1.0
                if not self._solve_for_state_action_ratio:
                    policy_ratio = tf.exp(
                        episode_target_log_probabilities[step_num] -
                        this_step.get_log_probability())

                # Need to weight next nu by importance weight.
                next_weight = (weight if self._solve_for_state_action_ratio
                               else policy_ratio * weight)
                next_probs = episode_target_probs[step_num + 1]
                for next_action, next_prob in enumerate(next_probs):
                    next_nu_index = self._get_index(next_step.observation,
                                                    next_action)
                    td_residuals[next_nu_index,
                                 nu_index] += (-next_prob * self._gamma *
                                               next_weight)

                initial_probs = episode_target_probs[0]
                for initial_action, initial_prob in enumerate(initial_probs):
                    initial_nu_index = self._get_index(first_step.observation,
                                                       initial_action)
                    initial_weights[initial_nu_index] += weight * initial_prob

        td_residuals /= np.sqrt(regularizer + total_weights)[None, :]
        td_errors = np.dot(td_residuals, td_residuals.T)
        self._nu = np.linalg.solve(
            td_errors + regularizer * np.eye(self._dimension),
            (1 - self._gamma) * initial_weights)
        self._zeta = np.dot(
            self._nu, td_residuals) / np.sqrt(regularizer + total_weights)
        return self.estimate_average_reward(dataset, target_policy)
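
The closing lines are a regularized least-squares solve. A numpy-only sketch of that structure
with toy random values (illustrative, not the estimator itself):

import numpy as np

dim, gamma, reg = 4, 0.99, 1e-8
td_residuals = np.random.randn(dim, dim)
total_weights = np.abs(np.random.randn(dim))
initial_weights = np.abs(np.random.randn(dim))

# Normalize the residual columns, form the Gram matrix, and solve for nu;
# zeta is then read off from nu and the normalized residuals.
td_residuals = td_residuals / np.sqrt(reg + total_weights)[None, :]
td_errors = td_residuals.dot(td_residuals.T)
nu = np.linalg.solve(td_errors + reg * np.eye(dim), (1 - gamma) * initial_weights)
zeta = nu.dot(td_residuals) / np.sqrt(reg + total_weights)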
Example #7
    def solve_nu_zeta(self,
                      dataset: dataset_lib.OffpolicyDataset,
                      target_policy: tf_policy.TFPolicy,
                      regularizer: float = 1e-6):
        """Solves for density ratios and then approximates target policy value.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.
      regularizer: A small constant to add to matrices before inverting them or
        to floats before taking square root.

    Returns:
      Estimated average per-step reward of the target policy.
    """

        if not hasattr(self, '_td_mat'):
            # Set up env_steps.
            episodes, valid_steps = dataset.get_all_episodes(
                limit=self._limit_episodes)
            total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
            num_episodes = tf.shape(valid_steps)[0]
            num_samples = num_episodes * total_num_steps_per_episode
            valid_and_not_last = tf.logical_and(valid_steps,
                                                episodes.discount > 0)
            valid_indices = tf.squeeze(
                tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])))

            initial_env_step = tf.nest.map_structure(
                lambda t: tf.squeeze(
                    tf.reshape(
                        tf.repeat(t[:, 0:1, ...],
                                  axis=1,
                                  repeats=total_num_steps_per_episode),
                        [num_samples, -1])), episodes)
            initial_env_step = tf.nest.map_structure(
                lambda t: tf.gather(t, valid_indices), initial_env_step)
            tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
                initial_env_step)

            env_step = tf.nest.map_structure(
                lambda t: tf.squeeze(
                    tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                               [num_samples, -1])), episodes)
            env_step = tf.nest.map_structure(
                lambda t: tf.gather(t, valid_indices), env_step)
            tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(
                env_step)

            next_env_step = tf.nest.map_structure(
                lambda t: tf.squeeze(
                    tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                               [num_samples, -1])), episodes)
            next_env_step = tf.nest.map_structure(
                lambda t: tf.gather(t, valid_indices), next_env_step)
            tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
                next_env_step)

            # get probabilities
            initial_target_probs = target_policy.distribution(
                tfagents_initial_env_step).action.probs_parameter()
            next_target_probs = target_policy.distribution(
                tfagents_next_env_step).action.probs_parameter()

            # First, get the nu_loss and data weights
            #current_nu_loss = self._get_nu_loss(initial_env_step, env_step,
            #                                    next_env_step, target_policy)
            #data_weight, _ = self._get_weights(current_nu_loss)

            # # debug only and to reproduce dual dice result, DELETE
            # data_weight = tf.ones_like(data_weight)

            state_action_count = self._get_state_action_counts(env_step)
            counts = tf.reduce_sum(
                tf.one_hot(state_action_count, self._dimension), 0)
            gamma_sample = tf.pow(self._gamma,
                                  tf.cast(env_step.step_num, tf.float32))

            # # debug only and to reproduce dual dice result, DELETE
            # gamma_sample = tf.ones_like(gamma_sample)

            # now we need to expand_dims to include action space in extra dimensions
            #data_weights = tf.reshape(data_weight, [-1, self._num_limits])
            # both are data sample weights for L2 problem, needs to be normalized later
            #gamma_data_weights = tf.reshape(gamma_sample, [-1, 1]) * data_weights

            initial_states = tf.tile(
                tf.reshape(initial_env_step.observation, [-1, 1]),
                [1, self._num_actions])
            initial_actions = tf.tile(
                tf.reshape(tf.range(self._num_actions), [1, -1]),
                [initial_env_step.observation.shape[0], 1])
            initial_nu_indices = self._get_index(initial_states,
                                                 initial_actions)

            # linear term w.r.t. initial distribution
            #b_vec_2 = tf.stack([
            #    tf.reduce_sum(
            #        tf.reshape(
            #            data_weights[:, itr] / tf.reduce_sum(data_weights[:, itr]),
            #            [-1, 1]) * tf.reduce_sum(
            #                tf.one_hot(initial_nu_indices, self._dimension) *
            #                (1 - self._gamma) *
            #                tf.expand_dims(initial_target_probs, axis=-1),
            #                axis=1),
            #        axis=0) for itr in range(self._num_limits)
            #],
            #                   axis=0)

            next_states = tf.tile(
                tf.reshape(next_env_step.observation, [-1, 1]),
                [1, self._num_actions])
            next_actions = tf.tile(
                tf.reshape(tf.range(self._num_actions), [1, -1]),
                [next_env_step.observation.shape[0], 1])
            next_nu_indices = self._get_index(next_states, next_actions)
            next_nu_indices = tf.where(
                tf.expand_dims(next_env_step.is_absorbing(), -1),
                -1 * tf.ones_like(next_nu_indices), next_nu_indices)

            nu_indices = self._get_index(env_step.observation, env_step.action)

            target_log_probabilities = target_policy.distribution(
                tfagents_env_step).action.log_prob(env_step.action)
            if not self._solve_for_state_action_ratio:
                policy_ratio = tf.exp(target_log_probabilities -
                                      env_step.get_log_probability())
            else:
                policy_ratio = tf.ones([
                    target_log_probabilities.shape[0],
                ])
            policy_ratios = tf.tile(tf.reshape(policy_ratio, [-1, 1]),
                                    [1, self._num_actions])

            # the tabular feature vector
            a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
                self._gamma *
                tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
                tf.one_hot(next_nu_indices, self._dimension),
                axis=1)

            # linear term w.r.t. reward
            #b_vec_1 = tf.stack([
            #    tf.reduce_sum(
            #        tf.reshape(
            #            (gamma_data_weights[:, itr] /
            #             tf.reduce_sum(gamma_data_weights[:, itr])) * self._reward_fn(env_step), #/
            #            #tf.cast(state_action_count, tf.float32),
            #            [-1, 1]) * a_vec,
            #        axis=0) for itr in range(self._num_limits)
            #],
            #                   axis=0)
            # quadratic term of feature
            # Get weighted outer product by using einsum to save computing resource!
            #a_mat = tf.stack([
            #    tf.einsum(
            #        'ai, a, aj -> ij', a_vec,
            #        #1.0 / tf.cast(state_action_count, tf.float32),
            #        gamma_data_weights[:, itr] /
            #        tf.reduce_sum(gamma_data_weights[:, itr]),
            #        a_vec)
            #    for itr in range(self._num_limits)
            #],
            #                 axis=0)

            td_mat = tf.einsum('ai, a, aj -> ij',
                               tf.one_hot(nu_indices, self._dimension),
                               1.0 / tf.cast(state_action_count, tf.float32),
                               a_vec)

            weighted_rewards = policy_ratio * self._reward_fn(env_step)

            bias = tf.reduce_sum(
                tf.one_hot(nu_indices, self._dimension) *
                tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
                tf.cast(state_action_count, tf.float32)[:, None],
                axis=0)

            # Initialize
            self._nu = np.ones_like(self._nu) * bias[:, None]
            self._nu2 = np.ones_like(self._nu2) * bias[:, None]

            self._a_vec = a_vec
            self._td_mat = td_mat
            self._bias = bias
            self._weighted_rewards = weighted_rewards
            self._state_action_count = state_action_count
            self._nu_indices = nu_indices
            self._initial_nu_indices = initial_nu_indices
            self._initial_target_probs = initial_target_probs
            self._gamma_sample = gamma_sample
            self._gamma_sample = tf.ones_like(gamma_sample)

        saddle_bellman_residuals = (tf.matmul(self._a_vec, self._nu) -
                                    self._weighted_rewards[:, None])
        saddle_bellman_residuals *= -1 * self._algae_alpha_sign
        saddle_zetas = tf.gather(self._zeta, self._nu_indices)
        saddle_initial_nu_values = tf.reduce_sum(  # Average over actions.
            self._initial_target_probs[:, :, None] *
            tf.gather(self._nu, self._initial_nu_indices),
            axis=1)
        saddle_init_nu_loss = ((1 - self._gamma) * saddle_initial_nu_values *
                               self._algae_alpha_sign)

        saddle_bellman_residuals2 = (tf.matmul(self._a_vec, self._nu2) -
                                     self._weighted_rewards[:, None])
        saddle_bellman_residuals2 *= 1 * self._algae_alpha_sign
        saddle_zetas2 = tf.gather(self._zeta2, self._nu_indices)
        saddle_initial_nu_values2 = tf.reduce_sum(  # Average over actions.
            self._initial_target_probs[:, :, None] *
            tf.gather(self._nu2, self._initial_nu_indices),
            axis=1)
        saddle_init_nu_loss2 = ((1 - self._gamma) * saddle_initial_nu_values2 *
                                -1 * self._algae_alpha_sign)

        saddle_loss = 0.5 * (
            saddle_init_nu_loss + saddle_bellman_residuals * saddle_zetas +
            -tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas) +
            -saddle_init_nu_loss2 + -saddle_bellman_residuals2 * saddle_zetas2
            + tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas2))
        # Binary search to find best alpha.
        left = tf.constant([-8., -8.])
        right = tf.constant([32., 32.])
        for _ in range(16):
            mid = 0.5 * (left + right)
            self._alpha.assign(mid)
            weights, log_weights = self._get_weights(
                saddle_loss * self._gamma_sample[:, None])

            divergence = self._compute_divergence(weights, log_weights)
            divergence_violation = divergence - self._two_sided_limit
            left = tf.where(divergence_violation > 0., mid, left)
            right = tf.where(divergence_violation > 0., right, mid)
        self._alpha.assign(0.5 * (left + right))
        weights, log_weights = self._get_weights(saddle_loss *
                                                 self._gamma_sample[:, None])

        gamma_data_weights = tf.stop_gradient(weights *
                                              self._gamma_sample[:, None])
        #print(tf.concat([gamma_data_weights, saddle_loss], axis=-1))
        avg_saddle_loss = (
            tf.reduce_sum(gamma_data_weights * saddle_loss, axis=0) /
            tf.reduce_sum(gamma_data_weights, axis=0))

        weighted_state_action_count = tf.reduce_sum(
            tf.one_hot(self._nu_indices, self._dimension)[:, :, None] *
            weights[:, None, :],
            axis=0)
        weighted_state_action_count = tf.gather(weighted_state_action_count,
                                                self._nu_indices)
        my_td_mat = tf.einsum(
            'ai, ab, ab, aj -> bij',
            tf.one_hot(self._nu_indices, self._dimension),
            #1.0 / tf.cast(self._state_action_count, tf.float32),
            1.0 / weighted_state_action_count,
            weights,
            self._a_vec)
        my_bias = tf.reduce_sum(
            tf.transpose(weights)[:, :, None] *
            tf.one_hot(self._nu_indices, self._dimension)[None, :, :] *
            tf.reshape(self._weighted_rewards, [1, -1, 1]) *
            #1.0 / tf.cast(self._state_action_count, tf.float32)[None, :, None],
            1.0 / tf.transpose(weighted_state_action_count)[:, :, None],
            axis=1)

        #print('hello', saddle_initial_nu_values[:1], saddle_zetas[:3],
        #      self._nu[:2], my_bias[:, :2], saddle_loss[:4])

        with tf.GradientTape(watch_accessed_variables=False,
                             persistent=True) as tape:
            tape.watch([self._nu, self._nu2, self._alpha])
            bellman_residuals = tf.matmul(
                my_td_mat,
                tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
            bellman_residuals = tf.transpose(tf.squeeze(bellman_residuals, -1))
            bellman_residuals = tf.gather(bellman_residuals, self._nu_indices)
            initial_nu_values = tf.reduce_sum(  # Average over actions.
                self._initial_target_probs[:, :, None] *
                tf.gather(self._nu, self._initial_nu_indices),
                axis=1)

            bellman_residuals *= self._algae_alpha_sign

            init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                            self._algae_alpha_sign)

            nu_loss = (tf.math.square(bellman_residuals) / 2.0 +
                       tf.math.abs(self._algae_alpha) * init_nu_loss)

            loss = (gamma_data_weights * nu_loss /
                    tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

            bellman_residuals2 = tf.matmul(
                my_td_mat,
                tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
            bellman_residuals2 = tf.transpose(
                tf.squeeze(bellman_residuals2, -1))
            bellman_residuals2 = tf.gather(bellman_residuals2,
                                           self._nu_indices)
            initial_nu_values2 = tf.reduce_sum(  # Average over actions.
                self._initial_target_probs[:, :, None] *
                tf.gather(self._nu2, self._initial_nu_indices),
                axis=1)

            bellman_residuals2 *= -1 * self._algae_alpha_sign

            init_nu_loss2 = ((1 - self._gamma) * initial_nu_values2 * -1 *
                             self._algae_alpha_sign)

            nu_loss2 = (tf.math.square(bellman_residuals2) / 2.0 +
                        tf.math.abs(self._algae_alpha) * init_nu_loss2)

            loss2 = (gamma_data_weights * nu_loss2 /
                     tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

            divergence = self._compute_divergence(weights, log_weights)
            divergence_violation = divergence - self._two_sided_limit

            alpha_loss = (-tf.exp(self._alpha) *
                          tf.stop_gradient(divergence_violation))

            extra_loss = tf.reduce_sum(tf.math.square(self._nu[-1, :]))
            extra_loss2 = tf.reduce_sum(tf.math.square(self._nu2[-1, :]))
            nu_grad = tape.gradient(loss + extra_loss, [self._nu])[0]
            nu_grad2 = tape.gradient(loss2 + extra_loss2, [self._nu2])[0]
        avg_loss = tf.reduce_sum(0.5 * (loss - loss2) /
                                 tf.math.abs(self._algae_alpha),
                                 axis=0)
        nu_jacob = tape.jacobian(nu_grad, [self._nu])[0]
        nu_hess = tf.stack(
            [nu_jacob[:, i, :, i] for i in range(self._num_limits)], axis=0)

        nu_jacob2 = tape.jacobian(nu_grad2, [self._nu2])[0]
        nu_hess2 = tf.stack(
            [nu_jacob2[:, i, :, i] for i in range(self._num_limits)], axis=0)

        for idx, div in enumerate(divergence):
            tf.summary.scalar('divergence%d' % idx, div)

        #alpha_grads = tape.gradient(alpha_loss, [self._alpha])
        #alpha_grad_op = self._alpha_optimizer.apply_gradients(
        #    zip(alpha_grads, [self._alpha]))
        #self._alpha.assign(tf.minimum(8., tf.maximum(-8., self._alpha)))

        #print(self._alpha, tf.concat([weights, nu_loss], -1))
        #regularizer = 0.1
        nu_transformed = tf.transpose(
            tf.squeeze(
                tf.linalg.solve(
                    nu_hess + regularizer * tf.eye(self._dimension),
                    tf.expand_dims(-tf.transpose(nu_grad), axis=-1))))
        self._nu = self._nu + 0.1 * nu_transformed
        nu_transformed2 = tf.transpose(
            tf.squeeze(
                tf.linalg.solve(
                    nu_hess2 + regularizer * tf.eye(self._dimension),
                    tf.expand_dims(-tf.transpose(nu_grad2), axis=-1))))
        self._nu2 = self._nu2 + 0.1 * nu_transformed2

        print(avg_loss * self._algae_alpha_sign,
              avg_saddle_loss * self._algae_alpha_sign, self._nu[:2],
              divergence)
        #print(init_nu_loss[:8], init_nu_loss[-8:])
        #print(bellman_residuals[:8])
        #print(self._nu[:3], self._zeta[:3])

        zetas = tf.matmul(my_td_mat,
                          tf.transpose(self._nu)[:, :, None]) - my_bias[:, :,
                                                                        None]
        zetas = tf.transpose(tf.squeeze(zetas, -1))
        zetas *= -self._algae_alpha_sign
        zetas /= tf.math.abs(self._algae_alpha)
        self._zeta = self._zeta + 0.1 * (zetas - self._zeta)

        zetas2 = tf.matmul(my_td_mat,
                           tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :,
                                                                          None]
        zetas2 = tf.transpose(tf.squeeze(zetas2, -1))
        zetas2 *= 1 * self._algae_alpha_sign
        zetas2 /= tf.math.abs(self._algae_alpha)
        self._zeta2 = self._zeta2 + 0.1 * (zetas2 - self._zeta2)

        #self._zeta = (
        #    tf.einsum('ij,ja-> ia', self._td_mat, self._nu) -
        #    tf.transpose(my_bias))
        #self._zeta *= -tf.reshape(self._algae_alpha_sign, [1, self._num_limits])
        #self._zeta /= tf.math.abs(self._algae_alpha)
        return [
            avg_saddle_loss * self._algae_alpha_sign,
            avg_loss * self._algae_alpha_sign, divergence
        ]
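
Each call ends with a damped, regularized Newton update of `nu` (and `nu2`). A numpy sketch of
that single step with toy values and one limit (the actual code batches this over
`self._num_limits`):

import numpy as np

dim, reg, damping = 3, 1e-6, 0.1
nu = np.zeros(dim)
nu_grad = np.random.randn(dim)
hess_sqrt = np.random.randn(dim, dim)
nu_hess = hess_sqrt @ hess_sqrt.T  # symmetric PSD stand-in for the true Hessian

# Solve (H + reg * I) d = -g and take a damped step, mirroring the tf.linalg.solve calls above.
newton_direction = np.linalg.solve(nu_hess + reg * np.eye(dim), -nu_grad)
nu = nu + damping * newton_direction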
Example #8
  def prepare_dataset(self, dataset: dataset_lib.OffpolicyDataset,
                      target_policy: tf_policy.TFPolicy):
    """Performs pre-computations on dataset to make solving easier."""
    episodes, valid_steps = dataset.get_all_episodes(limit=self._limit_episodes)
    total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
    num_episodes = tf.shape(valid_steps)[0]
    num_samples = num_episodes * total_num_steps_per_episode
    valid_and_not_last = tf.logical_and(valid_steps, episodes.discount > 0)
    valid_indices = tf.squeeze(
        tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])))

    # Flatten all tensors so that each data sample is a tuple of
    # (initial_env_step, env_step, next_env_step).
    initial_env_step = tf.nest.map_structure(
        lambda t: tf.squeeze(
            tf.reshape(
                tf.repeat(
                    t[:, 0:1, ...], axis=1, repeats=total_num_steps_per_episode
                ), [num_samples, -1])), episodes)
    initial_env_step = tf.nest.map_structure(
        lambda t: tf.gather(t, valid_indices), initial_env_step)
    tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
        initial_env_step)

    env_step = tf.nest.map_structure(
        lambda t: tf.squeeze(
            tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                       [num_samples, -1])), episodes)
    env_step = tf.nest.map_structure(lambda t: tf.gather(t, valid_indices),
                                     env_step)
    tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

    next_env_step = tf.nest.map_structure(
        lambda t: tf.squeeze(
            tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                       [num_samples, -1])), episodes)
    next_env_step = tf.nest.map_structure(lambda t: tf.gather(t, valid_indices),
                                          next_env_step)
    tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
        next_env_step)

    # Get target probabilities for initial and next steps.
    initial_target_probs = target_policy.distribution(
        tfagents_initial_env_step).action.probs_parameter()
    next_target_probs = target_policy.distribution(
        tfagents_next_env_step).action.probs_parameter()

    # Map states and actions to indices into tabular representation.
    initial_states = tf.tile(
        tf.reshape(initial_env_step.observation, [-1, 1]),
        [1, self._num_actions])
    initial_actions = tf.tile(
        tf.reshape(tf.range(self._num_actions), [1, -1]),
        [initial_env_step.observation.shape[0], 1])
    initial_nu_indices = self._get_index(initial_states, initial_actions)

    next_states = tf.tile(
        tf.reshape(next_env_step.observation, [-1, 1]), [1, self._num_actions])
    next_actions = tf.tile(
        tf.reshape(tf.range(self._num_actions), [1, -1]),
        [next_env_step.observation.shape[0], 1])
    next_nu_indices = self._get_index(next_states, next_actions)
    next_nu_indices = tf.where(
        tf.expand_dims(next_env_step.is_absorbing(), -1),
        -1 * tf.ones_like(next_nu_indices), next_nu_indices)

    nu_indices = self._get_index(env_step.observation, env_step.action)

    target_log_probabilities = target_policy.distribution(
        tfagents_env_step).action.log_prob(env_step.action)
    if not self._solve_for_state_action_ratio:
      policy_ratio = tf.exp(target_log_probabilities -
                            env_step.get_log_probability())
    else:
      policy_ratio = tf.ones([
          target_log_probabilities.shape[0],
      ])
    policy_ratios = tf.tile(
        tf.reshape(policy_ratio, [-1, 1]), [1, self._num_actions])

    # Bellman residual matrix of size [n_data, n_dim].
    a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
        self._gamma *
        tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
        tf.one_hot(next_nu_indices, self._dimension),
        axis=1)

    state_action_count = self._get_state_action_counts(env_step)
    # Bellman residual matrix of size [n_dim, n_dim].
    td_mat = tf.einsum('ai, a, aj -> ij', tf.one_hot(nu_indices,
                                                     self._dimension),
                       1.0 / tf.cast(state_action_count, tf.float32), a_vec)

    # Reward vector of size [n_data].
    weighted_rewards = policy_ratio * self._reward_fn(env_step)

    # Reward vector of size [n_dim].
    bias = tf.reduce_sum(
        tf.one_hot(nu_indices, self._dimension) *
        tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
        tf.cast(state_action_count, tf.float32)[:, None],
        axis=0)

    # Initialize.
    self._nu = np.ones_like(self._nu) * bias[:, None]
    self._nu2 = np.ones_like(self._nu2) * bias[:, None]

    self._a_vec = a_vec
    self._td_mat = td_mat
    self._bias = bias
    self._weighted_rewards = weighted_rewards
    self._state_action_count = state_action_count
    self._nu_indices = nu_indices
    self._initial_nu_indices = initial_nu_indices
    self._initial_target_probs = initial_target_probs
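
The flattening into (initial_env_step, env_step, next_env_step) triples is easiest to see on a
toy tensor; a small sketch of the same repeat/slice/reshape pattern (illustrative only):

import tensorflow as tf

num_episodes, episode_len = 2, 4
steps_per_episode = episode_len - 1
t = tf.reshape(tf.range(num_episodes * episode_len), [num_episodes, episode_len])

initial = tf.reshape(tf.repeat(t[:, 0:1], repeats=steps_per_episode, axis=1), [-1])
current = tf.reshape(t[:, 0:steps_per_episode], [-1])
nxt = tf.reshape(t[:, 1:steps_per_episode + 1], [-1])
# `initial` repeats each episode's first entry once per transition, while
# (`current`, `nxt`) are the aligned step/next-step pairs.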
Example #9
  def solve(self,
            dataset: dataset_lib.OffpolicyDataset,
            target_policy: tf_policy.TFPolicy,
            regularizer: float = 1e-8):
    """Solves for Q-values and then approximates target policy value.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.
      regularizer: A small constant to add before dividing.

    Returns:
      Estimated average per-step reward of the target policy.
    """
    num_estimates = 1 + int(self._num_qvalues)
    transition_matrix = np.zeros(
        [self._dimension, self._dimension, num_estimates])
    reward_vector = np.zeros(
        [self._dimension, num_estimates, self._num_perturbations])
    total_weights = np.zeros([self._dimension, num_estimates])

    episodes, valid_steps = dataset.get_all_episodes(limit=self._limit_episodes)
    #all_rewards = self._reward_fn(episodes)
    #reward_std = np.ma.MaskedArray(all_rewards, valid_steps).std()
    tfagents_episodes = dataset_lib.convert_to_tfagents_timestep(episodes)

    sample_weights = np.array(valid_steps, dtype=np.int64)
    if not self._bootstrap or self._num_qvalues is None:
      sample_weights = (
          sample_weights[:, :, None] * np.ones([1, 1, num_estimates]))
    else:
      probs = np.reshape(sample_weights, [-1]) / np.sum(sample_weights)
      weights = np.random.multinomial(
          np.sum(sample_weights), probs,
          size=self._num_qvalues).astype(np.float32)
      weights = np.reshape(
          np.transpose(weights),
          list(np.shape(sample_weights)) + [self._num_qvalues])
      sample_weights = np.concatenate([sample_weights[:, :, None], weights],
                                      axis=-1)

    for episode_num in range(tf.shape(valid_steps)[0]):
      # Precompute probabilities for this episode.
      this_episode = tf.nest.map_structure(lambda t: t[episode_num], episodes)
      this_tfagents_episode = dataset_lib.convert_to_tfagents_timestep(
          this_episode)
      episode_target_log_probabilities = target_policy.distribution(
          this_tfagents_episode).action.log_prob(this_episode.action)
      episode_target_probs = target_policy.distribution(
          this_tfagents_episode).action.probs_parameter()

      for step_num in range(tf.shape(valid_steps)[1] - 1):
        this_step = tf.nest.map_structure(lambda t: t[episode_num, step_num],
                                          episodes)
        next_step = tf.nest.map_structure(
            lambda t: t[episode_num, step_num + 1], episodes)
        this_tfagents_step = dataset_lib.convert_to_tfagents_timestep(this_step)
        next_tfagents_step = dataset_lib.convert_to_tfagents_timestep(next_step)
        this_weights = sample_weights[episode_num, step_num, :]
        if this_step.is_last() or not valid_steps[episode_num, step_num]:
          continue

        weight = this_weights
        this_index = self._get_index(this_step.observation, this_step.action)

        reward_vector[this_index, :, :] += np.expand_dims(
            self._reward_fn(this_step) * weight, -1)
        if self._num_qvalues is not None:
          random_noise = np.random.binomial(this_weights[1:].astype('int64'),
                                            0.5)
          reward_vector[this_index, 1:, :] += (
              self._perturbation_scale[None, :] *
              (2 * random_noise - this_weights[1:])[:, None])

        total_weights[this_index] += weight

        policy_ratio = 1.0
        if not self._solve_for_state_action_value:
          policy_ratio = tf.exp(episode_target_log_probabilities[step_num] -
                                this_step.get_log_probability())

        # Need to weight next nu by importance weight.
        next_weight = (
            weight if self._solve_for_state_action_value else policy_ratio *
            weight)
        if next_step.is_absorbing():
          next_index = -1  # Absorbing state.
          transition_matrix[this_index, next_index] += next_weight
        else:
          next_probs = episode_target_probs[step_num + 1]
          for next_action, next_prob in enumerate(next_probs):
            next_index = self._get_index(next_step.observation, next_action)
            transition_matrix[this_index, next_index] += next_prob * next_weight
    print('Done processing data.')

    transition_matrix /= (regularizer + total_weights)[:, None, :]
    reward_vector /= (regularizer + total_weights)[:, :, None]
    reward_vector[np.where(np.equal(total_weights,
                                    0.0))] = self._default_reward_value
    reward_vector[-1, :, :] = 0.0  # Terminal absorbing state has 0 reward.

    self._point_qvalues = np.linalg.solve(
        np.eye(self._dimension) - self._gamma * transition_matrix[:, :, 0],
        reward_vector[:, 0])
    if self._num_qvalues is not None:
      self._ensemble_qvalues = np.linalg.solve(
          (np.eye(self._dimension) -
           self._gamma * np.transpose(transition_matrix, [2, 0, 1])),
          np.transpose(reward_vector, [1, 0, 2]))

    return self.estimate_average_reward(dataset, target_policy)
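
The final solve is the tabular Bellman identity Q = r + gamma * P Q, i.e.
Q = (I - gamma * P)^(-1) r. A toy numpy illustration on a hypothetical 3-state chain:

import numpy as np

dim, gamma = 3, 0.9
transition_matrix = np.array([[0.9, 0.1, 0.0],
                              [0.0, 0.9, 0.1],
                              [0.0, 0.0, 1.0]])
reward_vector = np.array([1.0, 0.5, 0.0])
qvalues = np.linalg.solve(np.eye(dim) - gamma * transition_matrix, reward_vector)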
Example #10
    def prepare_dataset(self, dataset: dataset_lib.OffpolicyDataset,
                        target_policy: tf_policy.TFPolicy):
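        """Precomputes TD residuals, total weights, and initial weights from `dataset`."""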
        episodes, valid_steps = dataset.get_all_episodes()
        tfagents_episodes = dataset_lib.convert_to_tfagents_timestep(episodes)

        for episode_num in range(tf.shape(valid_steps)[0]):
            # Precompute probabilities for this episode.
            this_episode = tf.nest.map_structure(lambda t: t[episode_num],
                                                 episodes)
            first_step = tf.nest.map_structure(lambda t: t[0], this_episode)
            this_tfagents_episode = dataset_lib.convert_to_tfagents_timestep(
                this_episode)
            episode_target_log_probabilities = target_policy.distribution(
                this_tfagents_episode).action.log_prob(this_episode.action)
            episode_target_probs = target_policy.distribution(
                this_tfagents_episode).action.probs_parameter()

            for step_num in range(tf.shape(valid_steps)[1] - 1):
                this_step = tf.nest.map_structure(
                    lambda t: t[episode_num, step_num], episodes)
                next_step = tf.nest.map_structure(
                    lambda t: t[episode_num, step_num + 1], episodes)
                if this_step.is_last() or not valid_steps[episode_num,
                                                          step_num]:
                    continue

                weight = 1.0
                nu_index = self._get_index(this_step.observation,
                                           this_step.action)
                self._td_residuals[nu_index, nu_index] += -weight
                self._total_weights[nu_index] += weight

                policy_ratio = 1.0
                if not self._solve_for_state_action_ratio:
                    policy_ratio = tf.exp(
                        episode_target_log_probabilities[step_num] -
                        this_step.get_log_probability())

                # Need to weight next nu by importance weight.
                next_weight = (weight if self._solve_for_state_action_ratio
                               else policy_ratio * weight)
                next_probs = episode_target_probs[step_num + 1]
                for next_action, next_prob in enumerate(next_probs):
                    next_nu_index = self._get_index(next_step.observation,
                                                    next_action)
                    self._td_residuals[next_nu_index,
                                       nu_index] += (next_prob * self._gamma *
                                                     next_weight)

                initial_probs = episode_target_probs[0]
                for initial_action, initial_prob in enumerate(initial_probs):
                    initial_nu_index = self._get_index(first_step.observation,
                                                       initial_action)
                    self._initial_weights[
                        initial_nu_index] += weight * initial_prob

        self._initial_weights = tf.cast(self._initial_weights, tf.float32)
        self._total_weights = tf.cast(self._total_weights, tf.float32)
        self._td_residuals = self._td_residuals / np.sqrt(
            1e-8 + self._total_weights)[None, :]
        self._td_errors = tf.cast(
            np.dot(self._td_residuals, self._td_residuals.T), tf.float32)
        self._td_residuals = tf.cast(self._td_residuals, tf.float32)
Example #11
    def get_is_weighted_reward_samples(self,
                                       dataset: dataset_lib.OffpolicyDataset,
                                       target_policy: tf_policy.TFPolicy,
                                       episode_limit: Optional[int] = None,
                                       eps: Optional[float] = 1e-8):
        """Get the IS weighted reweard samples."""
        episodes, valid_steps = dataset.get_all_episodes(limit=episode_limit)
        total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
        num_episodes = tf.shape(valid_steps)[0]
        num_samples = num_episodes * total_num_steps_per_episode

        init_env_step = tf.nest.map_structure(lambda t: t[:, 0, ...], episodes)
        env_step = tf.nest.map_structure(
            lambda t: tf.squeeze(
                tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                           [num_samples, -1])), episodes)
        next_env_step = tf.nest.map_structure(
            lambda t: tf.squeeze(
                tf.reshape(t[:, 1:1 + total_num_steps_per_episode, ...],
                           [num_samples, -1])), episodes)
        tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

        gamma_weights = tf.reshape(
            tf.pow(self._gamma, tf.cast(env_step.step_num, tf.float32)),
            [num_episodes, total_num_steps_per_episode])

        rewards = (-self._get_q_value(env_step) + self._reward_fn(env_step) +
                   self._gamma * next_env_step.discount *
                   self._get_v_value(next_env_step, target_policy))
        rewards = tf.reshape(rewards,
                             [num_episodes, total_num_steps_per_episode])

        init_values = self._get_v_value(init_env_step, target_policy)
        init_offset = (1 - self._gamma) * init_values

        target_log_probabilities = target_policy.distribution(
            tfagents_env_step).action.log_prob(env_step.action)
        if tf.rank(target_log_probabilities) > 1:
            target_log_probabilities = tf.reduce_sum(target_log_probabilities,
                                                     -1)
        if self._policy_network is not None:
            baseline_policy_log_probability = self._get_log_prob(
                self._policy_network, env_step)
            if tf.rank(baseline_policy_log_probability) > 1:
                baseline_policy_log_probability = tf.reduce_sum(
                    baseline_policy_log_probability, -1)
            policy_log_ratios = tf.reshape(
                tf.maximum(
                    -1.0 / eps, target_log_probabilities -
                    baseline_policy_log_probability),
                [num_episodes, total_num_steps_per_episode])
        else:
            policy_log_ratios = tf.reshape(
                tf.maximum(
                    -1.0 / eps,
                    target_log_probabilities - env_step.get_log_probability()),
                [num_episodes, total_num_steps_per_episode])
        valid_steps_in = valid_steps[:, 0:total_num_steps_per_episode]
        mask = tf.cast(
            tf.logical_and(valid_steps_in, episodes.discount[:, :-1] > 0.),
            tf.float32)

        masked_rewards = tf.where(mask > 0, rewards, tf.zeros_like(rewards))
        clipped_policy_log_ratios = mask * self.clip_log_factor(
            policy_log_ratios)

        if self._mode in ['trajectory-wise', 'weighted-trajectory-wise']:
            trajectory_avg_rewards = tf.reduce_sum(
                masked_rewards * gamma_weights, axis=1) / tf.reduce_sum(
                    gamma_weights, axis=1)
            trajectory_log_ratios = tf.reduce_sum(clipped_policy_log_ratios,
                                                  axis=1)
            if self._mode == 'trajectory-wise':
                trajectory_avg_rewards *= tf.exp(trajectory_log_ratios)
                return init_offset + trajectory_avg_rewards
            else:
                offset = tf.reduce_max(trajectory_log_ratios)
                normalized_clipped_ratios = tf.exp(trajectory_log_ratios -
                                                   offset)
                normalized_clipped_ratios /= tf.maximum(
                    eps, tf.reduce_mean(normalized_clipped_ratios))
                trajectory_avg_rewards *= normalized_clipped_ratios
                return init_offset + trajectory_avg_rewards

        elif self._mode in ['step-wise', 'weighted-step-wise']:
            trajectory_log_ratios = mask * tf.cumsum(policy_log_ratios, axis=1)
            if self._mode == 'step-wise':
                trajectory_avg_rewards = tf.reduce_sum(
                    masked_rewards * gamma_weights *
                    tf.exp(trajectory_log_ratios),
                    axis=1) / tf.reduce_sum(gamma_weights, axis=1)
                return init_offset + trajectory_avg_rewards
            else:
                # Average over data, for each time step.
                offset = tf.reduce_max(trajectory_log_ratios,
                                       axis=0)  # TODO: Handle mask.
                normalized_imp_weights = tf.exp(trajectory_log_ratios - offset)
                normalized_imp_weights /= tf.maximum(
                    eps,
                    tf.reduce_sum(mask * normalized_imp_weights, axis=0) /
                    tf.maximum(eps, tf.reduce_sum(mask, axis=0)))[None, :]

                trajectory_avg_rewards = tf.reduce_sum(
                    masked_rewards * gamma_weights * normalized_imp_weights,
                    axis=1) / tf.reduce_sum(gamma_weights, axis=1)
                return init_offset + trajectory_avg_rewards
        else:
            raise ValueError('Estimator is not implemented!')
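
A standalone sketch of the self-normalized weighting used in the 'weighted-trajectory-wise'
branch, with toy values (subtracting the max log-ratio keeps the exponentials numerically
stable):

import tensorflow as tf

eps = 1e-8
trajectory_log_ratios = tf.constant([-1.0, 0.5, 2.0])
trajectory_avg_rewards = tf.constant([1.0, 2.0, 3.0])

offset = tf.reduce_max(trajectory_log_ratios)
normalized_ratios = tf.exp(trajectory_log_ratios - offset)
normalized_ratios /= tf.maximum(eps, tf.reduce_mean(normalized_ratios))
weighted_reward_samples = trajectory_avg_rewards * normalized_ratios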