Example #1
 def _populate_policy_info_spec(self, context_spec):
     predicted_rewards_mean = ()
     if (policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
             in self._emit_policy_info):
         predicted_rewards_mean = tensor_spec.TensorSpec(
             [self._num_actions], dtype=self._dtype)
     predicted_rewards_sampled = ()
     if (policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED
             in self._emit_policy_info):
         predicted_rewards_sampled = tensor_spec.TensorSpec(
             [self._num_actions], dtype=self._dtype)
     if self._accepts_per_arm_features:
         # The features for the chosen arm are saved to policy_info.
         arm_spec = context_spec[bandit_spec_utils.PER_ARM_FEATURE_KEY]
         chosen_arm_features_info = tensor_spec.TensorSpec(
             dtype=arm_spec.dtype,
             shape=arm_spec.shape[1:],
             name='chosen_arm_features')
         info_spec = policy_utilities.PerArmPolicyInfo(
             predicted_rewards_mean=predicted_rewards_mean,
             predicted_rewards_sampled=predicted_rewards_sampled,
             chosen_arm_features=chosen_arm_features_info)
     else:
         info_spec = policy_utilities.PolicyInfo(
             predicted_rewards_mean=predicted_rewards_mean,
             predicted_rewards_sampled=predicted_rewards_sampled)
     return info_spec
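
The spec above drops the leading `num_actions` dimension from the per-arm feature spec, so `chosen_arm_features` describes the features of a single arm. A minimal sketch of that shape slicing, assuming plain `tf.TensorSpec` behaves like the `tensor_spec.TensorSpec` used here:

import tensorflow as tf

# Hypothetical per-arm spec: [num_actions, arm_feature_dim].
arm_spec = tf.TensorSpec(shape=[10, 4], dtype=tf.float32)
# Slicing off the leading num_actions dimension leaves one arm's features.
chosen_arm_spec = tf.TensorSpec(
    shape=arm_spec.shape[1:], dtype=arm_spec.dtype, name='chosen_arm_features')
print(chosen_arm_spec.shape)  # (4,)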
Example #2
 def _populate_policy_info_spec(self, observation_spec,
                                observation_and_action_constraint_splitter):
     predicted_rewards_mean = ()
     if (policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
             in self._emit_policy_info):
         predicted_rewards_mean = tensor_spec.TensorSpec(
             [self._num_actions], dtype=self._dtype)
     predicted_rewards_sampled = ()
     if (policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED
             in self._emit_policy_info):
         predicted_rewards_sampled = tensor_spec.TensorSpec(
             [self._num_actions], dtype=self._dtype)
     if self._accepts_per_arm_features:
         # The features for the chosen arm are saved to policy_info.
         chosen_arm_features_info = (
             policy_utilities.create_chosen_arm_features_info_spec(
                 observation_spec))
         info_spec = policy_utilities.PerArmPolicyInfo(
             predicted_rewards_mean=predicted_rewards_mean,
             predicted_rewards_sampled=predicted_rewards_sampled,
             chosen_arm_features=chosen_arm_features_info)
     else:
         info_spec = policy_utilities.PolicyInfo(
             predicted_rewards_mean=predicted_rewards_mean,
             predicted_rewards_sampled=predicted_rewards_sampled)
     return info_spec
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)
        predicted_reward_values, policy_state = self._reward_network(
            observation, time_step.step_type, policy_state)
        predicted_reward_values.shape.with_rank_at_least(2)
        predicted_reward_values.shape.with_rank_at_most(3)
        if predicted_reward_values.shape[-1] != self._expected_num_actions:
            raise ValueError(
                'The number of actions ({}) does not match the reward_network output'
                ' size ({}).'.format(self._expected_num_actions,
                                     predicted_reward_values.shape[-1]))
        if observation_and_action_constraint_splitter is not None:
            actions = policy_utilities.masked_argmax(
                predicted_reward_values,
                mask,
                output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(predicted_reward_values,
                                axis=-1,
                                output_type=self.action_spec.dtype)
        actions += self._action_offset

        policy_info = policy_utilities.PolicyInfo(
            predicted_rewards_mean=(
                predicted_reward_values
                if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                in self._emit_policy_info else ()))

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)
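
When a mask is supplied, `policy_utilities.masked_argmax` restricts the greedy choice to allowed actions. A rough sketch of the assumed semantics (not the library implementation):

import tensorflow as tf

def masked_argmax_sketch(values, mask, output_type=tf.int32):
    # Push disallowed entries to the smallest representable value so the
    # argmax can only land on positions where mask == 1.
    neg_inf = tf.fill(tf.shape(values), values.dtype.min)
    masked_values = tf.where(tf.cast(mask, tf.bool), values, neg_inf)
    return tf.argmax(masked_values, axis=-1, output_type=output_type)

values = tf.constant([[1.0, 5.0, 3.0]])
mask = tf.constant([[1, 0, 1]])
print(masked_argmax_sketch(values, mask).numpy())  # [2]; index 1 is masked out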
Example #4
  def _action(self, time_step, policy_state, seed):
    observation = time_step.observation
    mask = None
    observation_and_action_constraint_splitter = (
        self.observation_and_action_constraint_splitter)
    if observation_and_action_constraint_splitter is not None:
      observation, mask = observation_and_action_constraint_splitter(
          observation)
    # Check the shape of the observation matrix.
    if not observation.shape.is_compatible_with([None, self._context_dim]):
      raise ValueError('Observation shape is expected to be {}. Got {}.'.format(
          [None, self._context_dim], observation.shape.as_list()))

    observation = tf.cast(observation, dtype=self._dtype)

    # Pass the observations through the encoding network.
    encoded_observation, _ = self._encoding_network(observation)

    chosen_actions, est_mean_rewards = tf.cond(
        self._actions_from_reward_layer,
        lambda: self._get_actions_from_reward_layer(encoded_observation, mask),
        lambda: self._get_actions_from_linucb(encoded_observation, mask))

    policy_info = policy_utilities.PolicyInfo(
        predicted_rewards_mean=(
            est_mean_rewards if
            policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in
            self._emit_policy_info else ()))
    return policy_step.PolicyStep(chosen_actions, policy_state, policy_info)
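
`tf.cond` requires both branches to return matching structures, which is why `_get_actions_from_reward_layer` and `_get_actions_from_linucb` each return an `(actions, estimated_rewards)` pair. A toy version of the pattern with hypothetical stand-in branches:

import tensorflow as tf

actions_from_reward_layer = tf.constant(False)

def get_actions_from_reward_layer():
    # Hypothetical stand-in for _get_actions_from_reward_layer.
    return tf.constant([1, 0]), tf.constant([[0.1, 0.9], [0.8, 0.2]])

def get_actions_from_linucb():
    # Hypothetical stand-in for _get_actions_from_linucb.
    return tf.constant([0, 1]), tf.constant([[0.6, 0.3], [0.2, 0.7]])

chosen_actions, est_mean_rewards = tf.cond(
    actions_from_reward_layer,
    get_actions_from_reward_layer,
    get_actions_from_linucb)
print(chosen_actions.numpy())  # [0 1], since actions_from_reward_layer is False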
Example #5
 def _populate_policy_info(self, arm_observations, chosen_actions,
                           rewards_for_argmax, est_rewards):
     if self._accepts_per_arm_features:
         # Saving the features for the chosen action to the policy_info.
         chosen_arm_features = tf.gather(params=arm_observations,
                                         indices=chosen_actions,
                                         batch_dims=1)
         policy_info = policy_utilities.PerArmPolicyInfo(
             predicted_rewards_sampled=(
                 rewards_for_argmax
                 if policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED
                 in self._emit_policy_info else ()),
             predicted_rewards_mean=(
                 tf.stack(est_rewards, axis=-1)
                 if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                 in self._emit_policy_info else ()),
             chosen_arm_features=chosen_arm_features)
     else:
         policy_info = policy_utilities.PolicyInfo(
             predicted_rewards_sampled=(
                 rewards_for_argmax
                 if policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED
                 in self._emit_policy_info else ()),
             predicted_rewards_mean=(
                 tf.stack(est_rewards, axis=-1)
                 if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                 in self._emit_policy_info else ()))
     return policy_info
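
`tf.gather` with `batch_dims=1` is what selects, for each batch element, the feature row of the arm that was actually chosen. A small self-contained check of the shapes involved:

import tensorflow as tf

# arm_observations: [batch_size=2, num_actions=3, arm_feature_dim=2]
arm_observations = tf.reshape(tf.range(12.0), [2, 3, 2])
chosen_actions = tf.constant([2, 0])
chosen_arm_features = tf.gather(
    params=arm_observations, indices=chosen_actions, batch_dims=1)
print(chosen_arm_features.numpy())  # [[4. 5.]   arm 2 of batch element 0
                                    #  [6. 7.]]  arm 0 of batch element 1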
Example #6
def _get_action_step(action, num_agents, num_actions):
    batch_size = tf.shape(action)[0]
    choices = tf.constant(num_agents - 1, shape=action.shape, dtype=tf.int32)
    return policy_step.PolicyStep(
        action=tf.convert_to_tensor(action),
        info={
            mixture_policy.MIXTURE_AGENT_ID:
            choices,
            mixture_policy.SUBPOLICY_INFO:
            policy_utilities.PolicyInfo(
                predicted_rewards_mean=tf.zeros([batch_size, num_actions]))
        })
Example #7
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)
        batch_size = tf.shape(observation)[0]

        predictions, policy_state = self._reward_network(
            observation, time_step.step_type, policy_state)

        if isinstance(self._reward_network,
                      heteroscedastic_q_network.HeteroscedasticQNetwork):
            predicted_reward_values = predictions.q_value_logits
        else:
            predicted_reward_values = predictions

        predicted_reward_values.shape.with_rank_at_least(2)
        predicted_reward_values.shape.with_rank_at_most(3)
        if predicted_reward_values.shape[-1] != self._expected_num_actions:
            raise ValueError(
                'The number of actions ({}) does not match the reward_network output'
                ' size ({}).'.format(self._expected_num_actions,
                                     predicted_reward_values.shape[-1]))
        if observation_and_action_constraint_splitter is not None:
            actions = policy_utilities.masked_argmax(
                predicted_reward_values,
                mask,
                output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(predicted_reward_values,
                                axis=-1,
                                output_type=self.action_spec.dtype)
        actions += self._action_offset

        bandit_policy_values = tf.fill(
            [batch_size, 1], policy_utilities.BanditPolicyType.GREEDY)

        policy_info = policy_utilities.PolicyInfo(
            predicted_rewards_mean=(
                predicted_reward_values
                if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                in self._emit_policy_info else ()),
            bandit_policy_type=(bandit_policy_values if
                                policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                                in self._emit_policy_info else ()))

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)
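
The greedy action is wrapped in `tfp.distributions.Deterministic` so that `_distribution` still returns a distribution in its `PolicyStep`, even though sampling it always yields the argmax action. A minimal sketch:

import tensorflow as tf
import tensorflow_probability as tfp

predicted_reward_values = tf.constant([[0.2, 0.9, 0.4],
                                       [0.8, 0.1, 0.3]])
actions = tf.argmax(predicted_reward_values, axis=-1, output_type=tf.int32)
greedy_distribution = tfp.distributions.Deterministic(loc=actions)
print(greedy_distribution.sample().numpy())  # [1 0], identical on every draw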
    def _action(self, time_step, policy_state, seed):
        seed_stream = tfp.util.SeedStream(seed=seed, salt='ts_policy')
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)

        observation = tf.cast(observation,
                              dtype=self._parameter_estimators[0].dtype)
        mean_estimates, scales = _get_means_and_variances(
            self._parameter_estimators, self._weight_covariance_matrices,
            observation)
        mu_sampler = tfd.Normal(loc=tf.stack(mean_estimates, axis=-1),
                                scale=tf.sqrt(tf.stack(scales, axis=-1)))
        reward_samples = mu_sampler.sample(seed=seed_stream())
        if observation_and_action_constraint_splitter is not None:
            actions = policy_utilities.masked_argmax(
                reward_samples, mask, output_type=self._action_spec.dtype)
        else:
            actions = tf.argmax(reward_samples,
                                axis=-1,
                                output_type=self._action_spec.dtype)

        policy_info = policy_utilities.PolicyInfo(
            predicted_rewards_sampled=(
                reward_samples
                if policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED
                in self._emit_policy_info else ()),
            predicted_rewards_mean=(
                tf.stack(mean_estimates, axis=-1)
                if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                in self._emit_policy_info else ()))

        return policy_step.PolicyStep(actions, policy_state, policy_info)
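
The Thompson sampling step simply draws one reward per arm from the Normal posterior and takes the argmax over the draws. A compact sketch with made-up posterior means and variances:

import tensorflow as tf
import tensorflow_probability as tfp

mean_estimates = tf.constant([[0.10, 0.50, 0.30]])  # [batch_size, num_actions]
variances = tf.constant([[0.04, 0.25, 0.01]])
mu_sampler = tfp.distributions.Normal(
    loc=mean_estimates, scale=tf.sqrt(variances))
reward_samples = mu_sampler.sample(seed=123)        # one draw per arm
actions = tf.argmax(reward_samples, axis=-1, output_type=tf.int32)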
    def __init__(self,
                 time_step_spec=None,
                 action_spec=None,
                 reward_network=None,
                 observation_and_action_constraint_splitter=None,
                 emit_policy_info=(),
                 name=None):
        """Builds a GreedyRewardPredictionPolicy given a reward tf_agents.Network.

    This policy takes a tf_agents.Network predicting rewards and generates the
    action corresponding to the largest predicted reward.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      reward_network: An instance of a `tf_agents.network.Network`,
        callable via `network(observation, step_type) -> (output, final_state)`.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the network and 2) the
        mask.  The mask should be a 0-1 `Tensor` of shape
        `[batch_size, num_actions]`. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
    """
        flat_action_spec = tf.nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise NotImplementedError(
                'action_spec can only contain a single BoundedTensorSpec.')

        action_spec = flat_action_spec[0]
        if (not tensor_spec.is_bounded(action_spec)
                or not tensor_spec.is_discrete(action_spec)
                or action_spec.shape.rank > 1
                or action_spec.shape.num_elements() != 1):
            raise NotImplementedError(
                'action_spec must be a BoundedTensorSpec of type int32 and shape (). '
                'Found {}.'.format(action_spec))
        self._expected_num_actions = action_spec.maximum - action_spec.minimum + 1
        self._action_offset = action_spec.minimum
        reward_network.create_variables()
        self._reward_network = reward_network

        self._emit_policy_info = emit_policy_info
        predicted_rewards_mean = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
            predicted_rewards_mean = tensor_spec.TensorSpec(
                [self._expected_num_actions])
        bandit_policy_type = ()
        if policy_utilities.InfoFields.BANDIT_POLICY_TYPE in emit_policy_info:
            bandit_policy_type = (
                policy_utilities.create_bandit_policy_type_tensor_spec(
                    shape=[1]))
        info_spec = policy_utilities.PolicyInfo(
            predicted_rewards_mean=predicted_rewards_mean,
            bandit_policy_type=bandit_policy_type)

        super(GreedyRewardPredictionPolicy,
              self).__init__(time_step_spec,
                             action_spec,
                             policy_state_spec=reward_network.state_spec,
                             clip=False,
                             info_spec=info_spec,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             name=name)
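
The number of actions and the action offset follow directly from the bounds of the scalar, discrete action spec. A quick illustration, assuming a caller constructs the spec as below:

import tensorflow as tf
from tf_agents.specs import tensor_spec

action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=2, maximum=6, name='action')
expected_num_actions = action_spec.maximum - action_spec.minimum + 1  # 5
action_offset = action_spec.minimum                                   # 2
# The argmax yields indices in [0, 4]; adding the offset maps them to [2, 6].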
Example #10
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        if self.observation_and_action_constraint_splitter is not None:
            observation, _ = self.observation_and_action_constraint_splitter(
                observation)

        predictions, policy_state = self._reward_network(
            observation, time_step.step_type, policy_state)
        batch_size = tf.shape(predictions)[0]

        if isinstance(self._reward_network,
                      heteroscedastic_q_network.HeteroscedasticQNetwork):
            predicted_reward_values = predictions.q_value_logits
        else:
            predicted_reward_values = predictions

        predicted_reward_values.shape.with_rank_at_least(2)
        predicted_reward_values.shape.with_rank_at_most(3)
        if (predicted_reward_values.shape[-1] is not None and
                predicted_reward_values.shape[-1] != self._expected_num_actions):
            raise ValueError(
                'The number of actions ({}) does not match the reward_network output'
                ' size ({}).'.format(self._expected_num_actions,
                                     predicted_reward_values.shape[-1]))

        mask = constr.construct_mask_from_multiple_sources(
            time_step.observation,
            self._observation_and_action_constraint_splitter,
            self._constraints, self._expected_num_actions)

        # Argmax.
        if mask is not None:
            actions = policy_utilities.masked_argmax(
                predicted_reward_values,
                mask,
                output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(predicted_reward_values,
                                axis=-1,
                                output_type=self.action_spec.dtype)

        actions += self._action_offset

        bandit_policy_values = tf.fill(
            [batch_size, 1], policy_utilities.BanditPolicyType.GREEDY)

        if self._accepts_per_arm_features:
            # Saving the features for the chosen action to the policy_info.
            def gather_observation(obs):
                return tf.gather(params=obs, indices=actions, batch_dims=1)

            chosen_arm_features = tf.nest.map_structure(
                gather_observation,
                observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
            policy_info = policy_utilities.PerArmPolicyInfo(
                log_probability=tf.zeros([batch_size], tf.float32)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_reward_values
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()),
                chosen_arm_features=chosen_arm_features)
        else:
            policy_info = policy_utilities.PolicyInfo(
                log_probability=tf.zeros([batch_size], tf.float32)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_reward_values
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()))

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)
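
For per-arm policies the observation is a nest, so the chosen-arm gather from Example #5 is applied leaf by leaf with `tf.nest.map_structure`. A small sketch with a hypothetical two-feature per-arm nest:

import tensorflow as tf

per_arm_obs = {
    'ids': tf.random.uniform([2, 3, 4]),         # [batch, num_actions, dim_a]
    'embeddings': tf.random.uniform([2, 3, 7]),  # [batch, num_actions, dim_b]
}
actions = tf.constant([1, 2])
chosen_arm_features = tf.nest.map_structure(
    lambda obs: tf.gather(params=obs, indices=actions, batch_dims=1),
    per_arm_obs)
print(chosen_arm_features['ids'].shape)         # (2, 4)
print(chosen_arm_features['embeddings'].shape)  # (2, 7)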
    def __init__(
            self,
            time_step_spec: Optional[ts.TimeStep],
            action_spec: Optional[NestedBoundedTensorSpec],
            scalarizer: multi_objective_scalarizer.Scalarizer,
            objective_networks: Sequence[Network],
            observation_and_action_constraint_splitter: types.Splitter = None,
            accepts_per_arm_features: bool = False,
            emit_policy_info: Tuple[Text] = (),
            name: Optional[Text] = None):
        """Builds a GreedyMultiObjectiveNeuralPolicy based on multiple networks.

    This policy takes an iterable of `tf_agents.Network`, each responsible for
    predicting a specific objective, along with a `Scalarizer` object to
    generate an action by maximizing the scalarized objective, i.e., the output
    of the `Scalarizer` applied to the multiple predicted objectives by the
    networks.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of `BoundedTensorSpec` representing the actions.
      scalarizer: A
       `tf_agents.bandits.multi_objective.multi_objective_scalarizer.Scalarizer`
        object that implements scalarization of multiple objectives into a
        single scalar reward.
      objective_networks: A Sequence of `tf_agents.network.Network` objects to
        be used by the policy. Each network will be called with
        call(observation, step_type) and is expected to provide a prediction for
        a specific objective for all actions.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the network and 2) the
        mask.  The mask should be a 0-1 `Tensor` of shape `[batch_size,
        num_actions]`. This function should also work with a `TensorSpec` as
        input, and should output `TensorSpec` objects for the observation and
        mask.
      accepts_per_arm_features: (bool) Whether the policy accepts per-arm
        features.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
      NotImplementedError: If `action_spec` is not a `BoundedTensorSpec` of type
        int32 and shape ().
      ValueError: If `objective_networks` has fewer than two networks.
      ValueError: If `accepts_per_arm_features` is true but `time_step_spec` is
        None.
    """
        flat_action_spec = tf.nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise NotImplementedError(
                'action_spec can only contain a single BoundedTensorSpec.')

        action_spec = flat_action_spec[0]
        if (not tensor_spec.is_bounded(action_spec)
                or not tensor_spec.is_discrete(action_spec)
                or action_spec.shape.rank > 1
                or action_spec.shape.num_elements() != 1):
            raise NotImplementedError(
                'action_spec must be a BoundedTensorSpec of type int32 and shape (). '
                'Found {}.'.format(action_spec))
        self._expected_num_actions = action_spec.maximum - action_spec.minimum + 1
        self._action_offset = action_spec.minimum
        policy_state_spec = []
        for network in objective_networks:
            policy_state_spec.append(network.state_spec)
            network.create_variables()
        self._objective_networks = objective_networks
        self._scalarizer = scalarizer
        self._num_objectives = len(self._objective_networks)
        if self._num_objectives < 2:
            raise ValueError(
                'Number of objectives should be at least two, but found to be {}'
                .format(self._num_objectives))

        self._emit_policy_info = emit_policy_info
        predicted_rewards_mean = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
            predicted_rewards_mean = tensor_spec.TensorSpec(
                [self._num_objectives, self._expected_num_actions])
        bandit_policy_type = ()
        if policy_utilities.InfoFields.BANDIT_POLICY_TYPE in emit_policy_info:
            bandit_policy_type = (
                policy_utilities.create_bandit_policy_type_tensor_spec(
                    shape=[1]))
        if accepts_per_arm_features:
            if time_step_spec is None:
                raise ValueError(
                    'time_step_spec should not be None for per-arm-features policies, '
                    'but found to be None.')
            # The features for the chosen arm are saved to policy_info.
            chosen_arm_features_info = (
                policy_utilities.create_chosen_arm_features_info_spec(
                    time_step_spec.observation,
                    observation_and_action_constraint_splitter))
            info_spec = policy_utilities.PerArmPolicyInfo(
                predicted_rewards_mean=predicted_rewards_mean,
                bandit_policy_type=bandit_policy_type,
                chosen_arm_features=chosen_arm_features_info)
        else:
            info_spec = policy_utilities.PolicyInfo(
                predicted_rewards_mean=predicted_rewards_mean,
                bandit_policy_type=bandit_policy_type)

        self._accepts_per_arm_features = accepts_per_arm_features

        super(GreedyMultiObjectiveNeuralPolicy,
              self).__init__(time_step_spec,
                             action_spec,
                             policy_state_spec=policy_state_spec,
                             clip=False,
                             info_spec=info_spec,
                             emit_log_probability='log_probability'
                             in emit_policy_info,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             name=name)
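
The `Scalarizer` itself is not shown in this excerpt. Purely as an illustration, a linear scalarization of the per-objective predictions could look like the sketch below; the weights and the einsum are assumptions, not the library's `Scalarizer` API:

import tensorflow as tf

# [batch_size, num_objectives, num_actions] predictions from the objective networks.
objective_values = tf.constant([[[0.2, 0.9, 0.4],
                                 [0.5, 0.1, 0.6]]])
weights = tf.constant([0.7, 0.3])  # one weight per objective
scalarized_reward = tf.einsum('o,boa->ba', weights, objective_values)
actions = tf.argmax(scalarized_reward, axis=-1, output_type=tf.int32)
print(actions.numpy())  # [1]: 0.7*0.9 + 0.3*0.1 = 0.66 beats the other columns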
    def __init__(self,
                 action_spec,
                 time_step_spec,
                 weight_covariance_matrices,
                 parameter_estimators,
                 observation_and_action_constraint_splitter=None,
                 emit_policy_info=(),
                 name=None):
        """Initializes `LinearThompsonSamplingPolicy`.

    The `weight_covariance_matrices` and `parameter_estimators`
      arguments may either be `Tensor`s or `tf.Variable`s. If they are
      variables, then any assignment to those variables will be reflected in the
      output of the policy.

    Args:
      action_spec: Array spec containing action specification.
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      weight_covariance_matrices: A list of `B` inverse matrices from the paper.
        The list has `num_actions` elements of shape
        `[context_dim, context_dim]`.
      parameter_estimators: List of `f` vectors from the paper. The list has
        `num_actions` elements, each of shape `[context_dim]`.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the bandit policy and 2)
        the mask. The mask should be a 0-1 `Tensor` of shape
        `[batch_size, num_actions]`. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      name: The name of this policy.
    """
        if not isinstance(weight_covariance_matrices, (list, tuple)):
            raise ValueError(
                'weight_covariance_matrices must be a list of matrices (Tensors).')
        self._weight_covariance_matrices = weight_covariance_matrices

        if not isinstance(parameter_estimators, (list, tuple)):
            raise ValueError(
                'parameter_estimators must be a list of vectors (Tensors).')
        self._parameter_estimators = parameter_estimators

        self._action_spec = action_spec
        self._num_actions = action_spec.maximum + 1
        if observation_and_action_constraint_splitter is not None:
            context_shape = observation_and_action_constraint_splitter(
                time_step_spec.observation)[0].shape.as_list()
        else:
            context_shape = time_step_spec.observation.shape.as_list()

        self._context_dim = (tf.compat.dimension_value(context_shape[0])
                             if context_shape else 1)
        self._variables = [
            x for x in weight_covariance_matrices + parameter_estimators
            if isinstance(x, tf.Variable)
        ]
        for t in self._parameter_estimators:
            _assert_shape([self._context_dim], t.shape.as_list(),
                          'Parameter estimators')
        for t in self._weight_covariance_matrices:
            _assert_shape([self._context_dim, self._context_dim],
                          t.shape.as_list(), 'Weight covariance')

        self._emit_policy_info = emit_policy_info
        self._dtype = self._weight_covariance_matrices[0].dtype
        predicted_rewards_sampled = ()
        if (policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED
                in emit_policy_info):
            predicted_rewards_sampled = tensor_spec.TensorSpec(
                [self._num_actions], dtype=self._dtype)
        predicted_rewards_mean = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
            predicted_rewards_mean = tensor_spec.TensorSpec(
                [self._num_actions], dtype=self._dtype)
        info_spec = policy_utilities.PolicyInfo(
            predicted_rewards_sampled=predicted_rewards_sampled,
            predicted_rewards_mean=predicted_rewards_mean)

        super(LinearThompsonSamplingPolicy,
              self).__init__(time_step_spec=time_step_spec,
                             action_spec=action_spec,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             info_spec=info_spec)
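
`_get_means_and_variances` is not included in this excerpt. Under the usual linear Thompson sampling model, one arm's posterior mean and variance would be computed roughly as follows, where `B` and `f` stand in for that arm's entries of `weight_covariance_matrices` and `parameter_estimators` (an assumption about the math, not the exact helper):

import tensorflow as tf

B = tf.constant([[2.0, 0.0], [0.0, 2.0]])   # [context_dim, context_dim]
f = tf.constant([1.0, 3.0])                 # [context_dim]
observation = tf.constant([[0.5, 1.0]])     # [batch_size, context_dim]

b_inv_x = tf.linalg.solve(B, tf.transpose(observation))  # B^{-1} x, [context_dim, batch]
mean_estimate = tf.einsum('j,jb->b', f, b_inv_x)          # f^T B^{-1} x per example
variance = tf.einsum('bj,jb->b', observation, b_inv_x)    # x^T B^{-1} x per example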
Example #13
    def __init__(self,
                 encoding_network,
                 encoding_dim,
                 reward_layer,
                 epsilon_greedy,
                 actions_from_reward_layer,
                 cov_matrix,
                 data_vector,
                 num_samples,
                 time_step_spec=None,
                 alpha=1.0,
                 emit_policy_info=(),
                 emit_log_probability=False,
                 accepts_per_arm_features=False,
                 observation_and_action_constraint_splitter=None,
                 name=None):
        """Initializes `NeuralLinUCBPolicy`.

    Args:
      encoding_network: network that encodes the observations.
      encoding_dim: (int) dimension of the encoded observations.
      reward_layer: final layer that predicts the expected reward per arm. In
        case the policy accepts per-arm features, the output of this layer has
        to be a scalar. This is because in the per-arm case, all encoded
        observations have to go through the same computation to get the reward
        estimates. The `num_actions` dimension of the encoded observation is
        treated as a batch dimension in the reward layer.
      epsilon_greedy: (float) representing the probability of choosing a random
        action instead of the greedy action.
      actions_from_reward_layer: (bool) whether to get actions from the reward
        layer or from LinUCB.
      cov_matrix: list of the covariance matrices. There exists one covariance
        matrix per arm, unless the policy accepts per-arm features, in which
        case this list must have a single element.
      data_vector: list of the data vectors. A data vector is a weighted sum
        of the observations, where the weight is the corresponding reward. Each
        arm has its own data vector, unless the policy accepts per-arm features,
        in which case this list must have a single element.
      num_samples: list of number of samples per arm. If the policy accepts per-
        arm features, this is a single-element list counting the number of
        steps.
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      alpha: (float) non-negative weight multiplying the confidence intervals.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      emit_log_probability: (bool) whether to emit log probabilities.
      accepts_per_arm_features: (bool) Whether the policy accepts per-arm
        features.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the bandit policy and 2)
        the mask. The mask should be a 0-1 `Tensor` of shape
        `[batch_size, num_actions]`. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      name: The name of this policy.
    """
        encoding_network.create_variables()
        self._encoding_network = encoding_network
        self._reward_layer = reward_layer
        self._encoding_dim = encoding_dim

        if accepts_per_arm_features and reward_layer.units != 1:
            raise ValueError(
                'The output dimension of the reward layer must be 1, got'
                ' {}'.format(reward_layer.units))

        if not isinstance(cov_matrix, (list, tuple)):
            raise ValueError(
                'cov_matrix must be a list of matrices (Tensors).')
        self._cov_matrix = cov_matrix

        if not isinstance(data_vector, (list, tuple)):
            raise ValueError(
                'data_vector must be a list of vectors (Tensors).')
        self._data_vector = data_vector

        if not isinstance(num_samples, (list, tuple)):
            raise ValueError(
                'num_samples must be a list of vectors (Tensors).')
        self._num_samples = num_samples

        self._alpha = alpha
        self._actions_from_reward_layer = actions_from_reward_layer
        self._epsilon_greedy = epsilon_greedy
        self._dtype = self._data_vector[0].dtype

        if len(cov_matrix) != len(data_vector):
            raise ValueError(
                'The size of list cov_matrix must match the size of '
                'list data_vector. Got {} for cov_matrix and {} '
                'for data_vector.'.format(len(self._cov_matrix),
                                          len(data_vector)))
        if len(num_samples) != len(cov_matrix):
            raise ValueError('The size of num_samples must match the size of '
                             'list cov_matrix. Got {} for num_samples and {} '
                             'for cov_matrix.'.format(len(self._num_samples),
                                                      len(cov_matrix)))

        self._accepts_per_arm_features = accepts_per_arm_features
        if observation_and_action_constraint_splitter is not None:
            context_spec, _ = observation_and_action_constraint_splitter(
                time_step_spec.observation)
        else:
            context_spec = time_step_spec.observation
        if accepts_per_arm_features:
            self._num_actions = context_spec[
                bandit_spec_utils.PER_ARM_FEATURE_KEY].shape.as_list()[0]
            self._num_models = 1
        else:
            self._num_actions = len(cov_matrix)
            self._num_models = self._num_actions
        (self._global_context_dim,
         self._arm_context_dim) = bandit_spec_utils.get_context_dims_from_spec(
             context_spec, accepts_per_arm_features)
        cov_matrix_dim = tf.compat.dimension_value(cov_matrix[0].shape[0])
        if self._encoding_dim != cov_matrix_dim:
            raise ValueError('The dimension of matrix `cov_matrix` must match '
                             'encoding dimension {}. '
                             'Got {} for `cov_matrix`.'.format(
                                 self._encoding_dim, cov_matrix_dim))
        data_vector_dim = tf.compat.dimension_value(data_vector[0].shape[0])
        if self._encoding_dim != data_vector_dim:
            raise ValueError(
                'The dimension of vector `data_vector` must match '
                'encoding dimension {}. '
                'Got {} for `data_vector`.'.format(self._encoding_dim,
                                                   data_vector_dim))
        action_spec = tensor_spec.BoundedTensorSpec(
            shape=(),
            dtype=tf.int32,
            minimum=0,
            maximum=self._num_actions - 1,
            name='action')

        self._emit_policy_info = emit_policy_info
        predicted_rewards_mean = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
            predicted_rewards_mean = tensor_spec.TensorSpec(
                [self._num_actions], dtype=tf.float32)
        if accepts_per_arm_features:
            chosen_arm_features_info = tensor_spec.TensorSpec(
                dtype=tf.float32,
                shape=[self._arm_context_dim],
                name='chosen_arm_features')
            info_spec = policy_utilities.PerArmPolicyInfo(
                predicted_rewards_mean=predicted_rewards_mean,
                chosen_arm_features=chosen_arm_features_info)
        else:
            info_spec = policy_utilities.PolicyInfo(
                predicted_rewards_mean=predicted_rewards_mean)

        super(NeuralLinUCBPolicy,
              self).__init__(time_step_spec=time_step_spec,
                             action_spec=action_spec,
                             emit_log_probability=emit_log_probability,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             info_spec=info_spec,
                             name=name)
Example #14
    def __init__(self,
                 action_spec,
                 cov_matrix,
                 data_vector,
                 num_samples,
                 time_step_spec=None,
                 exploration_strategy=ExplorationStrategy.optimistic,
                 alpha=1.0,
                 eig_vals=(),
                 eig_matrix=(),
                 tikhonov_weight=1.0,
                 add_bias=False,
                 emit_policy_info=(),
                 emit_log_probability=False,
                 observation_and_action_constraint_splitter=None,
                 name=None):
        """Initializes `LinearBanditPolicy`.

    The `cov_matrix` and `data_vector` arguments (the `A` and `b` from the
    paper) may be either `Tensor`s or `tf.Variable`s. If they are variables,
    then any assignments to those variables will be reflected in the output of
    the policy.

    Args:
      action_spec: `TensorSpec` containing action specification.
      cov_matrix: list of the covariance matrices A in the paper. There exists
        one A matrix per arm.
      data_vector: list of the b vectors in the paper. The b vector is a
        weighted sum of the observations, where the weight is the corresponding
        reward. Each arm has its own vector b.
      num_samples: list of number of samples per arm.
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      exploration_strategy: An enum of type `ExplorationStrategy`; the strategy
        used for choosing actions to incorporate exploration. Currently
        supported strategies are `optimistic` and `sampling`.
      alpha: a float value used to scale the confidence intervals.
      eig_vals: list of eigenvalues for each covariance matrix (one per arm).
      eig_matrix: list of eigenvectors for each covariance matrix (one per arm).
      tikhonov_weight: (float) tikhonov regularization term.
      add_bias: If true, a bias term will be added to the linear reward
        estimation.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      emit_log_probability: Whether to emit log probabilities.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the bandit policy and 2)
        the mask. The mask should be a 0-1 `Tensor` of shape `[batch_size,
        num_actions]`. This function should also work with a `TensorSpec` as
        input, and should output `TensorSpec` objects for the observation and
        mask.
      name: The name of this policy.
    """
        if not isinstance(cov_matrix, (list, tuple)):
            raise ValueError(
                'cov_matrix must be a list of matrices (Tensors).')
        self._cov_matrix = cov_matrix

        if not isinstance(data_vector, (list, tuple)):
            raise ValueError(
                'data_vector must be a list of vectors (Tensors).')
        self._data_vector = data_vector

        if not isinstance(num_samples, (list, tuple)):
            raise ValueError(
                'num_samples must be a list of vectors (Tensors).')
        self._num_samples = num_samples

        if not isinstance(eig_vals, (list, tuple)):
            raise ValueError('eig_vals must be a list of vectors (Tensors).')
        self._eig_vals = eig_vals

        if not isinstance(eig_matrix, (list, tuple)):
            raise ValueError('eig_matrix must be a list of vectors (Tensors).')
        self._eig_matrix = eig_matrix

        self._exploration_strategy = exploration_strategy
        if exploration_strategy == ExplorationStrategy.sampling:
            # We do not have a way to calculate log probabilities for TS yet.
            emit_log_probability = False

        self._alpha = alpha
        self._use_eigendecomp = False
        if eig_matrix:
            self._use_eigendecomp = True
        self._tikhonov_weight = tikhonov_weight
        self._add_bias = add_bias

        if len(cov_matrix) != len(data_vector):
            raise ValueError(
                'The size of list cov_matrix must match the size of '
                'list data_vector. Got {} for cov_matrix and {} '
                'for data_vector.'.format(len(self._cov_matrix),
                                          len(data_vector)))
        if len(num_samples) != len(cov_matrix):
            raise ValueError('The size of num_samples must match the size of '
                             'list cov_matrix. Got {} for num_samples and {} '
                             'for cov_matrix.'.format(len(self._num_samples),
                                                      len(cov_matrix)))
        if tf.nest.is_nested(action_spec):
            raise ValueError('Nested `action_spec` is not supported.')

        self._num_actions = action_spec.maximum + 1
        if self._num_actions != len(cov_matrix):
            raise ValueError(
                'The number of elements in `cov_matrix` ({}) must match '
                'the number of actions derived from `action_spec` ({}).'.
                format(len(cov_matrix), self._num_actions))
        if observation_and_action_constraint_splitter is not None:
            context_shape = observation_and_action_constraint_splitter(
                time_step_spec.observation)[0].shape.as_list()
        else:
            context_shape = time_step_spec.observation.shape.as_list()
        self._context_dim = (tf.compat.dimension_value(context_shape[0])
                             if context_shape else 1)
        if self._add_bias:
            # The bias is added via a constant 1 feature.
            self._context_dim += 1
        cov_matrix_dim = tf.compat.dimension_value(cov_matrix[0].shape[0])
        if self._context_dim != cov_matrix_dim:
            raise ValueError('The dimension of matrix `cov_matrix` must match '
                             'context dimension {}. '
                             'Got {} for `cov_matrix`.'.format(
                                 self._context_dim, cov_matrix_dim))

        data_vector_dim = tf.compat.dimension_value(data_vector[0].shape[0])
        if self._context_dim != data_vector_dim:
            raise ValueError(
                'The dimension of vector `data_vector` must match '
                'context dimension {}. '
                'Got {} for `data_vector`.'.format(self._context_dim,
                                                   data_vector_dim))

        self._dtype = self._data_vector[0].dtype
        self._emit_policy_info = emit_policy_info
        predicted_rewards_mean = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
            predicted_rewards_mean = tensor_spec.TensorSpec(
                [self._num_actions], dtype=self._dtype)
        predicted_rewards_sampled = ()
        if (policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED
                in emit_policy_info):
            predicted_rewards_sampled = tensor_spec.TensorSpec(
                [self._num_actions], dtype=self._dtype)
        info_spec = policy_utilities.PolicyInfo(
            predicted_rewards_mean=predicted_rewards_mean,
            predicted_rewards_sampled=predicted_rewards_sampled)

        super(LinearBanditPolicy,
              self).__init__(time_step_spec=time_step_spec,
                             action_spec=action_spec,
                             info_spec=info_spec,
                             emit_log_probability=emit_log_probability,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             name=name)
Example #15
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)
        observation = tf.cast(observation, dtype=self._dtype)
        if self._add_bias:
            # The bias is added via a constant 1 feature.
            observation = tf.concat(
                [observation,
                 tf.ones([tf.shape(observation)[0], 1], dtype=self._dtype)],
                axis=1)
        # Check the shape of the observation matrix. The observations can be
        # batched.
        if not observation.shape.is_compatible_with([None, self._context_dim]):
            raise ValueError(
                'Observation shape is expected to be {}. Got {}.'.format(
                    [None, self._context_dim], observation.shape.as_list()))
        observation = tf.reshape(observation, [-1, self._context_dim])

        est_rewards = []
        confidence_intervals = []
        for k in range(self._num_actions):
            if self._use_eigendecomp:
                q_t_b = tf.matmul(self._eig_matrix[k],
                                  tf.linalg.matrix_transpose(observation),
                                  transpose_a=True)
                lambda_inv = tf.divide(
                    tf.ones_like(self._eig_vals[k]),
                    self._eig_vals[k] + self._tikhonov_weight)
                a_inv_x = tf.matmul(self._eig_matrix[k],
                                    tf.einsum('j,jk->jk', lambda_inv, q_t_b))
            else:
                a_inv_x = linalg.conjugate_gradient_solve(
                    self._cov_matrix[k] + self._tikhonov_weight *
                    tf.eye(self._context_dim, dtype=self._dtype),
                    tf.linalg.matrix_transpose(observation))
            est_mean_reward = tf.einsum('j,jk->k', self._data_vector[k],
                                        a_inv_x)
            est_rewards.append(est_mean_reward)

            ci = tf.reshape(
                tf.linalg.tensor_diag_part(tf.matmul(observation, a_inv_x)),
                [-1, 1])
            confidence_intervals.append(ci)

        if self._exploration_strategy == ExplorationStrategy.optimistic:
            optimistic_estimates = [
                tf.reshape(mean_reward, [-1, 1]) +
                self._alpha * tf.sqrt(confidence)
                for mean_reward, confidence in zip(est_rewards,
                                                   confidence_intervals)
            ]
            # Keeping the batch dimension during the squeeze, even if batch_size == 1.
            rewards_for_argmax = tf.squeeze(tf.stack(optimistic_estimates,
                                                     axis=-1),
                                            axis=[1])
        elif self._exploration_strategy == ExplorationStrategy.sampling:
            mu_sampler = tfd.Normal(
                loc=tf.stack(est_rewards, axis=-1),
                scale=self._alpha * tf.sqrt(
                    tf.squeeze(tf.stack(confidence_intervals, axis=-1),
                               axis=1)))
            rewards_for_argmax = mu_sampler.sample()
        else:
            raise ValueError('Exploration strategy %s not implemented.' %
                             self._exploration_strategy)
        if observation_and_action_constraint_splitter is not None:
            chosen_actions = policy_utilities.masked_argmax(
                rewards_for_argmax, mask, output_type=self._action_spec.dtype)
        else:
            chosen_actions = tf.argmax(rewards_for_argmax,
                                       axis=-1,
                                       output_type=self._action_spec.dtype)

        action_distributions = tfp.distributions.Deterministic(
            loc=chosen_actions)

        policy_info = policy_utilities.PolicyInfo(
            predicted_rewards_sampled=(
                rewards_for_argmax
                if policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED
                in self._emit_policy_info else ()),
            predicted_rewards_mean=(
                tf.stack(est_rewards, axis=-1)
                if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                in self._emit_policy_info else ()))

        return policy_step.PolicyStep(action_distributions, policy_state,
                                      policy_info)
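
When `eig_vals`/`eig_matrix` are provided, the regularized solve `(A + tikhonov_weight * I)^{-1} x` above goes through the eigendecomposition `A = Q diag(eig_vals) Q^T` instead of a conjugate-gradient solve. A toy sketch of that path with identity eigenvectors:

import tensorflow as tf

eig_vals = tf.constant([1.0, 3.0])
eig_matrix = tf.constant([[1.0, 0.0], [0.0, 1.0]])  # columns are the eigenvectors Q
observation = tf.constant([[0.5, 1.0]])             # [batch_size, context_dim]
tikhonov_weight = 1.0

q_t_x = tf.matmul(eig_matrix, tf.transpose(observation), transpose_a=True)  # Q^T x
lambda_inv = 1.0 / (eig_vals + tikhonov_weight)                             # 1 / (eig + reg)
a_inv_x = tf.matmul(eig_matrix, tf.einsum('j,jb->jb', lambda_inv, q_t_x))   # Q diag(.) Q^T x
print(a_inv_x.numpy())  # [[0.25], [0.25]] for these toy values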
Example #16
  def __init__(self,
               encoding_network,
               encoding_dim,
               reward_layer,
               epsilon_greedy,
               actions_from_reward_layer,
               cov_matrix,
               data_vector,
               num_samples,
               time_step_spec=None,
               alpha=1.0,
               emit_policy_info=(),
               emit_log_probability=False,
               observation_and_action_constraint_splitter=None,
               name=None):
    """Initializes `NeuralLinUCBPolicy`.

    Args:
      encoding_network: network that encodes the observations.
      encoding_dim: (int) dimension of the encoded observations.
      reward_layer: final layer that predicts the expected reward per arm.
      epsilon_greedy: (float) representing the probability of choosing a random
        action instead of the greedy action.
      actions_from_reward_layer: (bool) whether to get actions from the reward
        layer or from LinUCB.
      cov_matrix: list of the covariance matrices. There exists one covariance
        matrix per arm.
      data_vector: list of the data vectors. A data vector is a weighted sum
        of the observations, where the weight is the corresponding reward. Each
        arm has its own data vector.
      num_samples: list of number of samples per arm.
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      alpha: (float) non-negative weight multiplying the confidence intervals.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      emit_log_probability: (bool) whether to emit log probabilities.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the bandit policy and 2)
        the mask. The mask should be a 0-1 `Tensor` of shape
        `[batch_size, num_actions]`. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      name: The name of this policy.
    """
    self._encoding_network = encoding_network
    self._reward_layer = reward_layer
    self._encoding_dim = encoding_dim

    if not isinstance(cov_matrix, (list, tuple)):
      raise ValueError('cov_matrix must be a list of matrices (Tensors).')
    self._cov_matrix = cov_matrix

    if not isinstance(data_vector, (list, tuple)):
      raise ValueError('data_vector must be a list of vectors (Tensors).')
    self._data_vector = data_vector

    if not isinstance(num_samples, (list, tuple)):
      raise ValueError('num_samples must be a list of vectors (Tensors).')
    self._num_samples = num_samples

    self._alpha = alpha
    self._actions_from_reward_layer = actions_from_reward_layer
    self._epsilon_greedy = epsilon_greedy
    self._dtype = self._data_vector[0].dtype

    if len(cov_matrix) != len(data_vector):
      raise ValueError('The size of list cov_matrix must match the size of '
                       'list data_vector. Got {} for cov_matrix and {} '
                       'for data_vector.'.format(
                           len(self._cov_matrix), len(data_vector)))
    if len(num_samples) != len(cov_matrix):
      raise ValueError('The size of num_samples must match the size of '
                       'list cov_matrix. Got {} for num_samples and {} '
                       'for cov_matrix.'.format(
                           len(self._num_samples), len(cov_matrix)))

    self._num_actions = len(cov_matrix)
    assert self._num_actions
    if observation_and_action_constraint_splitter is not None:
      context_shape = observation_and_action_constraint_splitter(
          time_step_spec.observation)[0].shape.as_list()
    else:
      context_shape = time_step_spec.observation.shape.as_list()
    self._context_dim = (
        tf.compat.dimension_value(context_shape[0]) if context_shape else 1)
    cov_matrix_dim = tf.compat.dimension_value(cov_matrix[0].shape[0])
    if self._encoding_dim != cov_matrix_dim:
      raise ValueError('The dimension of matrix `cov_matrix` must match '
                       'encoding dimension {}. '
                       'Got {} for `cov_matrix`.'.format(
                           self._encoding_dim, cov_matrix_dim))
    data_vector_dim = tf.compat.dimension_value(data_vector[0].shape[0])
    if self._encoding_dim != data_vector_dim:
      raise ValueError('The dimension of vector `data_vector` must match '
                       'encoding dimension {}. '
                       'Got {} for `data_vector`.'.format(
                           self._encoding_dim, data_vector_dim))
    action_spec = tensor_spec.BoundedTensorSpec(
        shape=(),
        dtype=tf.int32,
        minimum=0,
        maximum=self._num_actions - 1,
        name='action')

    self._emit_policy_info = emit_policy_info
    info_spec = policy_utilities.PolicyInfo()

    super(NeuralLinUCBPolicy, self).__init__(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        emit_log_probability=emit_log_probability,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter),
        info_spec=info_spec,
        name=name)
def _get_action_step(action):
    return policy_step.PolicyStep(action=tf.convert_to_tensor(action),
                                  info=policy_utilities.PolicyInfo())
    def _distribution(
            self, time_step: ts.TimeStep,
            policy_state: Sequence[types.TensorSpec]
    ) -> policy_step.PolicyStep:
        observation = time_step.observation
        if self.observation_and_action_constraint_splitter is not None:
            observation, _ = self.observation_and_action_constraint_splitter(
                observation)
        predicted_objective_values_tensor, policy_state = self._predict(
            observation, time_step.step_type, policy_state)
        scalarized_reward = scalarize_objectives(
            predicted_objective_values_tensor, self._scalarizer)
        batch_size = scalarized_reward.shape[0]
        mask = policy_utilities.construct_mask_from_multiple_sources(
            time_step.observation,
            self._observation_and_action_constraint_splitter, (),
            self._expected_num_actions)

        # Argmax.
        if mask is not None:
            actions = policy_utilities.masked_argmax(
                scalarized_reward, mask, output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(scalarized_reward,
                                axis=-1,
                                output_type=self.action_spec.dtype)

        actions += self._action_offset

        bandit_policy_values = tf.fill(
            [batch_size, 1], policy_utilities.BanditPolicyType.GREEDY)

        if self._accepts_per_arm_features:
            # Saving the features for the chosen action to the policy_info.
            def gather_observation(obs):
                return tf.gather(params=obs, indices=actions, batch_dims=1)

            chosen_arm_features = tf.nest.map_structure(
                gather_observation,
                observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
            policy_info = policy_utilities.PerArmPolicyInfo(
                log_probability=tf.zeros([batch_size], tf.float32)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_objective_values_tensor
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()),
                chosen_arm_features=chosen_arm_features)
        else:
            policy_info = policy_utilities.PolicyInfo(
                log_probability=tf.zeros([batch_size], tf.float32)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_objective_values_tensor
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()))

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)