    def testBadMask(self):
        input_tensor = tf.reshape(tf.range(12, dtype=tf.float32), shape=[3, 4])
        mask = [[1, 0, 0, 1], [0, 0, 0, 0], [1, 0, 1, 1]]
        # The middle row's mask allows no action; this version of masked_argmax
        # is expected to return -1 there and the best allowed index elsewhere.
        expected = [3, -1, 3]
        actual = self.evaluate(
            policy_utilities.masked_argmax(input_tensor, tf.constant(mask)))
        self.assertAllEqual(actual, expected)
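
Based on the expected output above, masked_argmax returns, for each row, the index of the largest entry among the allowed (mask == 1) positions, and this version of the test expects -1 for a row whose mask is all zeros (a later variant of the same test further down this page expects an InvalidArgumentError instead). A minimal, self-contained sketch of that behavior, not the TF-Agents implementation:

import tensorflow as tf

def masked_argmax_sketch(input_tensor, mask, output_type=tf.int32):
    # Push disallowed entries to -inf so they can never win the argmax.
    neg_inf = tf.fill(tf.shape(input_tensor), float('-inf'))
    masked_values = tf.where(tf.cast(mask, tf.bool), input_tensor, neg_inf)
    argmax = tf.argmax(masked_values, axis=-1, output_type=output_type)
    # Rows with an all-zero mask have no valid action; report -1 for them.
    has_valid = tf.reduce_any(tf.cast(mask, tf.bool), axis=-1)
    return tf.where(has_valid, argmax, -tf.ones_like(argmax))

values = tf.reshape(tf.range(12, dtype=tf.float32), shape=[3, 4])
mask = tf.constant([[1, 0, 0, 1], [0, 0, 0, 0], [1, 0, 1, 1]])
result = masked_argmax_sketch(values, mask)  # -> [3, -1, 3], matching `expected`
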
Example #2
    def _get_actions_from_reward_layer(self, encoded_observation, mask):
        # Get the predicted expected reward.
        est_mean_reward = self._reward_layer(encoded_observation)
        if mask is None:
            greedy_actions = tf.argmax(est_mean_reward,
                                       axis=-1,
                                       output_type=tf.int32)
        else:
            greedy_actions = policy_utilities.masked_argmax(
                est_mean_reward, mask, output_type=tf.int32)

        # Add epsilon greedy on top, if needed.
        if self._epsilon_greedy:
            batch_size = (tf.compat.dimension_value(
                encoded_observation.shape[0])
                          or tf.shape(encoded_observation)[0])
            if mask is None:
                random_actions = tf.random.uniform([batch_size],
                                                   maxval=self._num_actions,
                                                   dtype=tf.int32)
            else:
                zero_logits = tf.cast(tf.zeros_like(mask), tf.float32)
                masked_categorical = masked.MaskedCategorical(zero_logits,
                                                              mask,
                                                              dtype=tf.int32)
                random_actions = masked_categorical.sample()

            rng = tf.random.uniform([batch_size], maxval=1.0)
            cond = tf.greater(rng, self._epsilon_greedy)
            chosen_actions = tf.compat.v1.where(cond, greedy_actions,
                                                random_actions)
        else:
            chosen_actions = greedy_actions

        return chosen_actions
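
The epsilon-greedy mixing at the end of the method can be exercised on its own. The sketch below mirrors the same pattern (a uniformly random action with probability epsilon, the greedy action otherwise); the function name and inputs are illustrative, not part of the library:

import tensorflow as tf

def epsilon_greedy_mix(greedy_actions, num_actions, epsilon):
    # With probability `epsilon`, replace the greedy action with a random one.
    batch_size = tf.shape(greedy_actions)[0]
    random_actions = tf.random.uniform(
        [batch_size], maxval=num_actions, dtype=tf.int32)
    rng = tf.random.uniform([batch_size], maxval=1.0)
    return tf.where(tf.greater(rng, epsilon), greedy_actions, random_actions)

greedy = tf.constant([2, 0, 1], dtype=tf.int32)
mixed = epsilon_greedy_mix(greedy, num_actions=4, epsilon=0.1)
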
Example #3
    def _get_actions_from_linucb(self, encoded_observation, mask):
        encoded_observation = tf.cast(encoded_observation, dtype=self._dtype)

        p_values = []
        for k in range(self._num_actions):
            a_inv_x = linalg.conjugate_gradient_solve(
                self._cov_matrix[k] +
                tf.eye(self._encoding_dim, dtype=self._dtype),
                tf.linalg.matrix_transpose(encoded_observation))
            mean_reward_est = tf.einsum('j,jk->k', self._data_vector[k],
                                        a_inv_x)

            ci = tf.reshape(
                tf.linalg.tensor_diag_part(
                    tf.matmul(encoded_observation, a_inv_x)), [-1, 1])
            p_values.append(
                tf.reshape(mean_reward_est, [-1, 1]) +
                self._alpha * tf.sqrt(ci))

        stacked_p_values = tf.squeeze(tf.stack(p_values, axis=-1), axis=[1])
        if mask is None:
            chosen_actions = tf.argmax(stacked_p_values,
                                       axis=-1,
                                       output_type=tf.int32)
        else:
            chosen_actions = policy_utilities.masked_argmax(
                stacked_p_values, mask, output_type=tf.int32)
        return chosen_actions
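
Each iteration of the loop above computes the optimistic LinUCB score for one arm: the estimated mean reward plus alpha times the square root of the confidence term x^T (A + I)^{-1} x. A toy, self-contained version of that per-arm computation, using tf.linalg.solve in place of the library's conjugate-gradient solver and made-up values for the covariance matrix and data vector:

import tensorflow as tf

encoding_dim = 3
alpha = 1.0
cov_matrix = tf.eye(encoding_dim)                     # stand-in for self._cov_matrix[k]
data_vector = tf.constant([1.0, 0.5, -0.2])           # stand-in for self._data_vector[k]
encoded_observation = tf.constant([[0.1, 0.2, 0.3]])  # a batch of one observation

a = cov_matrix + tf.eye(encoding_dim)
a_inv_x = tf.linalg.solve(a, tf.linalg.matrix_transpose(encoded_observation))
mean_reward_est = tf.einsum('j,jk->k', data_vector, a_inv_x)
ci = tf.linalg.tensor_diag_part(tf.matmul(encoded_observation, a_inv_x))
p_value = mean_reward_est + alpha * tf.sqrt(ci)       # optimistic score for this arm
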
Example #4
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        if self._observation_and_action_constraint_splitter:
            observation, mask = self._observation_and_action_constraint_splitter(
                observation)
        predicted_reward_values, policy_state = self._reward_network(
            observation, time_step.step_type, policy_state)
        predicted_reward_values.shape.with_rank_at_least(2)
        predicted_reward_values.shape.with_rank_at_most(3)
        if predicted_reward_values.shape[-1] != self._expected_num_actions:
            raise ValueError(
                'The number of actions ({}) does not match the reward_network output'
                ' size ({}).'.format(self._expected_num_actions,
                                     predicted_reward_values.shape[-1]))
        if self._observation_and_action_constraint_splitter:
            actions = policy_utilities.masked_argmax(
                predicted_reward_values,
                mask,
                output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(predicted_reward_values,
                                axis=-1,
                                output_type=self.action_spec.dtype)
        actions += self._action_offset
        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state)
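
In these policies the mask comes from an observation_and_action_constraint_splitter: a callable that takes the raw observation and returns an (observation, mask) pair, as the unpacking above shows. A minimal sketch of such a splitter, assuming a hypothetical dict observation with 'context' and 'allowed_actions' keys:

def example_splitter(raw_observation):
    # Hypothetical keys; the real observation structure depends on the environment.
    return raw_observation['context'], raw_observation['allowed_actions']
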
Example #5
    def _action(self, time_step, policy_state, seed):
        seed_stream = tfd.SeedStream(seed=seed, salt='ts_policy')
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)

        observation = tf.cast(observation,
                              dtype=self._parameter_estimators[0].dtype)
        mean_estimates, scales = _get_means_and_variances(
            self._parameter_estimators, self._weight_covariance_matrices,
            observation)
        mu_sampler = tfd.Normal(loc=tf.stack(mean_estimates, axis=-1),
                                scale=tf.sqrt(tf.stack(scales, axis=-1)))
        reward_samples = mu_sampler.sample(seed=seed_stream())
        if observation_and_action_constraint_splitter is not None:
            actions = policy_utilities.masked_argmax(
                reward_samples, mask, output_type=self._action_spec.dtype)
        else:
            actions = tf.argmax(reward_samples,
                                axis=-1,
                                output_type=self._action_spec.dtype)
        return policy_step.PolicyStep(actions, policy_state)
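
The Thompson-sampling step above draws one reward sample per arm from a Normal posterior and then takes the (masked) argmax. A self-contained sketch of just the sampling-and-argmax part, with made-up means and variances:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Illustrative values standing in for the estimates computed by the policy.
mean_estimates = tf.constant([[0.1, 0.5, 0.3]])  # shape [batch_size, num_actions]
variances = tf.constant([[0.2, 0.2, 0.2]])
mu_sampler = tfd.Normal(loc=mean_estimates, scale=tf.sqrt(variances))
reward_samples = mu_sampler.sample()
actions = tf.argmax(reward_samples, axis=-1, output_type=tf.int32)
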
Example #6
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)
        predicted_reward_values, policy_state = self._reward_network(
            observation, time_step.step_type, policy_state)
        predicted_reward_values.shape.with_rank_at_least(2)
        predicted_reward_values.shape.with_rank_at_most(3)
        if predicted_reward_values.shape[-1] != self._expected_num_actions:
            raise ValueError(
                'The number of actions ({}) does not match the reward_network output'
                ' size ({}).'.format(self._expected_num_actions,
                                     predicted_reward_values.shape[-1]))
        if observation_and_action_constraint_splitter is not None:
            actions = policy_utilities.masked_argmax(
                predicted_reward_values,
                mask,
                output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(predicted_reward_values,
                                axis=-1,
                                output_type=self.action_spec.dtype)
        actions += self._action_offset

        policy_info = policy_utilities.PolicyInfo(
            predicted_rewards_mean=(
                predicted_reward_values
                if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                in self._emit_policy_info else ()))

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)
Example #7
    def testBadMask(self):
        input_tensor = tf.reshape(tf.range(12, dtype=tf.float32), shape=[3, 4])
        mask = [[1, 0, 0, 1], [0, 0, 0, 0], [1, 0, 1, 1]]
        # In this version, a row whose mask allows no action is rejected
        # outright instead of yielding -1.
        with self.assertRaises(tf.errors.InvalidArgumentError):
            self.evaluate(
                policy_utilities.masked_argmax(input_tensor,
                                               tf.constant(mask)))
Example #8
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)
        batch_size = tf.shape(observation)[0]

        predictions, policy_state = self._reward_network(
            observation, time_step.step_type, policy_state)

        if isinstance(self._reward_network,
                      heteroscedastic_q_network.HeteroscedasticQNetwork):
            predicted_reward_values = predictions.q_value_logits
        else:
            predicted_reward_values = predictions

        predicted_reward_values.shape.with_rank_at_least(2)
        predicted_reward_values.shape.with_rank_at_most(3)
        if predicted_reward_values.shape[-1] != self._expected_num_actions:
            raise ValueError(
                'The number of actions ({}) does not match the reward_network output'
                ' size ({}).'.format(self._expected_num_actions,
                                     predicted_reward_values.shape[-1]))
        if observation_and_action_constraint_splitter is not None:
            actions = policy_utilities.masked_argmax(
                predicted_reward_values,
                mask,
                output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(predicted_reward_values,
                                axis=-1,
                                output_type=self.action_spec.dtype)
        actions += self._action_offset

        bandit_policy_values = tf.fill(
            [batch_size, 1], policy_utilities.BanditPolicyType.GREEDY)

        policy_info = policy_utilities.PolicyInfo(
            predicted_rewards_mean=(
                predicted_reward_values
                if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                in self._emit_policy_info else ()),
            bandit_policy_type=(bandit_policy_values if
                                policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                                in self._emit_policy_info else ()))

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)
Example #9
    def _get_actions_from_linucb(
        self, encoded_observation: types.Float, mask: Optional[types.Tensor]
    ) -> Tuple[types.Int, types.Float, types.Float]:
        encoded_observation = tf.cast(encoded_observation, dtype=self._dtype)

        p_values = []
        est_rewards = []
        for k in range(self._num_actions):
            encoded_observation_for_arm = self._get_encoded_observation_for_arm(
                encoded_observation, k)
            model_index = policy_utilities.get_model_index(
                k, self._accepts_per_arm_features)
            a_inv_x = linalg.conjugate_gradient_solve(
                self._cov_matrix[model_index] +
                tf.eye(self._encoding_dim, dtype=self._dtype),
                tf.linalg.matrix_transpose(encoded_observation_for_arm))
            mean_reward_est = tf.einsum('j,jk->k',
                                        self._data_vector[model_index],
                                        a_inv_x)
            est_rewards.append(mean_reward_est)

            ci = tf.reshape(
                tf.linalg.tensor_diag_part(
                    tf.matmul(encoded_observation_for_arm, a_inv_x)), [-1, 1])
            p_values.append(
                tf.reshape(mean_reward_est, [-1, 1]) +
                self._alpha * tf.sqrt(ci))

        stacked_p_values = tf.squeeze(tf.stack(p_values, axis=-1), axis=[1])
        if mask is None:
            chosen_actions = tf.argmax(stacked_p_values,
                                       axis=-1,
                                       output_type=tf.int32)
        else:
            chosen_actions = policy_utilities.masked_argmax(
                stacked_p_values, mask, output_type=tf.int32)

        est_mean_reward = tf.cast(tf.stack(est_rewards, axis=-1), tf.float32)
        return chosen_actions, est_mean_reward, tf.cast(
            stacked_p_values, tf.float32)
Example #10
    def _action(self, time_step, policy_state, seed):
        seed_stream = tfp.util.SeedStream(seed=seed, salt='ts_policy')
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)

        observation = tf.cast(observation,
                              dtype=self._parameter_estimators[0].dtype)
        mean_estimates, scales = _get_means_and_variances(
            self._parameter_estimators, self._weight_covariance_matrices,
            observation)
        mu_sampler = tfd.Normal(loc=tf.stack(mean_estimates, axis=-1),
                                scale=tf.sqrt(tf.stack(scales, axis=-1)))
        reward_samples = mu_sampler.sample(seed=seed_stream())
        if observation_and_action_constraint_splitter is not None:
            actions = policy_utilities.masked_argmax(
                reward_samples, mask, output_type=self._action_spec.dtype)
        else:
            actions = tf.argmax(reward_samples,
                                axis=-1,
                                output_type=self._action_spec.dtype)

        policy_info = policy_utilities.PolicyInfo(
            predicted_rewards_sampled=(
                reward_samples
                if policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED
                in self._emit_policy_info else ()),
            predicted_rewards_mean=(
                tf.stack(mean_estimates, axis=-1)
                if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                in self._emit_policy_info else ()))

        return policy_step.PolicyStep(actions, policy_state, policy_info)
Example #11
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)
        observation = tf.nest.map_structure(
            lambda o: tf.cast(o, dtype=self._dtype), observation)
        global_observation, arm_observations = self._split_observation(
            observation)

        if self._add_bias:
            # The bias is added via a constant 1 feature.
            global_observation = tf.concat(
                [global_observation,
                 tf.ones([tf.shape(global_observation)[0], 1],
                         dtype=self._dtype)],
                axis=1)
        # Check the shape of the observation matrix. The observations can be
        # batched.
        if not global_observation.shape.is_compatible_with(
            [None, self._global_context_dim]):
            raise ValueError(
                'Global observation shape is expected to be {}. Got {}.'.
                format([None, self._global_context_dim],
                       global_observation.shape.as_list()))
        global_observation = tf.reshape(global_observation,
                                        [-1, self._global_context_dim])

        est_rewards = []
        confidence_intervals = []
        for k in range(self._num_actions):
            current_observation = self._get_current_observation(
                global_observation, arm_observations, k)
            model_index = self._get_model_index(k)
            if self._use_eigendecomp:
                q_t_b = tf.matmul(
                    self._eig_matrix[model_index],
                    tf.linalg.matrix_transpose(current_observation),
                    transpose_a=True)
                lambda_inv = tf.divide(
                    tf.ones_like(self._eig_vals[model_index]),
                    self._eig_vals[model_index] + self._tikhonov_weight)
                a_inv_x = tf.matmul(self._eig_matrix[model_index],
                                    tf.einsum('j,jk->jk', lambda_inv, q_t_b))
            else:
                a_inv_x = linalg.conjugate_gradient_solve(
                    self._cov_matrix[model_index] + self._tikhonov_weight *
                    tf.eye(self._overall_context_dim, dtype=self._dtype),
                    tf.linalg.matrix_transpose(current_observation))
            est_mean_reward = tf.einsum('j,jk->k',
                                        self._data_vector[model_index],
                                        a_inv_x)
            est_rewards.append(est_mean_reward)

            ci = tf.reshape(
                tf.linalg.tensor_diag_part(
                    tf.matmul(current_observation, a_inv_x)), [-1, 1])
            confidence_intervals.append(ci)

        if self._exploration_strategy == ExplorationStrategy.optimistic:
            optimistic_estimates = [
                tf.reshape(mean_reward, [-1, 1]) +
                self._alpha * tf.sqrt(confidence)
                for mean_reward, confidence in zip(est_rewards,
                                                   confidence_intervals)
            ]
            # Keeping the batch dimension during the squeeze, even if batch_size == 1.
            rewards_for_argmax = tf.squeeze(tf.stack(optimistic_estimates,
                                                     axis=-1),
                                            axis=[1])
        elif self._exploration_strategy == ExplorationStrategy.sampling:
            mu_sampler = tfd.Normal(
                loc=tf.stack(est_rewards, axis=-1),
                scale=self._alpha * tf.sqrt(
                    tf.squeeze(tf.stack(confidence_intervals, axis=-1),
                               axis=1)))
            rewards_for_argmax = mu_sampler.sample()
        else:
            raise ValueError('Exploration strategy %s not implemented.' %
                             self._exploration_strategy)
        if observation_and_action_constraint_splitter is not None:
            chosen_actions = policy_utilities.masked_argmax(
                rewards_for_argmax, mask, output_type=self._action_spec.dtype)
        else:
            chosen_actions = tf.argmax(rewards_for_argmax,
                                       axis=-1,
                                       output_type=self._action_spec.dtype)

        action_distributions = tfp.distributions.Deterministic(
            loc=chosen_actions)

        policy_info = self._populate_policy_info(arm_observations,
                                                 chosen_actions,
                                                 rewards_for_argmax,
                                                 est_rewards)

        return policy_step.PolicyStep(action_distributions, policy_state,
                                      policy_info)
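
When self._use_eigendecomp is set, the policy above replaces the linear solve with an eigendecomposition: with A = Q diag(lambda) Q^T, the regularized inverse applied to x, (A + w*I)^{-1} x, is formed as Q diag(1 / (lambda + w)) Q^T x. A toy check of that identity with made-up values (a sketch, not the library code):

import tensorflow as tf

w = 0.1                                    # stand-in for self._tikhonov_weight
a = tf.constant([[2.0, 0.5], [0.5, 1.0]])  # stand-in for the covariance matrix
x = tf.constant([[0.3, 0.7]])              # one observation row
eig_vals, eig_matrix = tf.linalg.eigh(a)
q_t_b = tf.matmul(eig_matrix, tf.linalg.matrix_transpose(x), transpose_a=True)
lambda_inv = tf.ones_like(eig_vals) / (eig_vals + w)
a_inv_x = tf.matmul(eig_matrix, tf.einsum('j,jk->jk', lambda_inv, q_t_b))
# Same result as solving (A + w*I) y = x^T directly:
direct = tf.linalg.solve(a + w * tf.eye(2), tf.linalg.matrix_transpose(x))
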
Example #12
    def testMaskedArgmax(self, input_tensor, mask, expected):
        actual = policy_utilities.masked_argmax(
            tf.constant(input_tensor, dtype=tf.float32), tf.constant(mask))
        self.assertAllEqual(actual, expected)
Example #13
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        if self.observation_and_action_constraint_splitter is not None:
            observation, _ = self.observation_and_action_constraint_splitter(
                observation)

        predictions, policy_state = self._reward_network(
            observation, time_step.step_type, policy_state)
        batch_size = tf.shape(predictions)[0]

        if isinstance(self._reward_network,
                      heteroscedastic_q_network.HeteroscedasticQNetwork):
            predicted_reward_values = predictions.q_value_logits
        else:
            predicted_reward_values = predictions

        predicted_reward_values.shape.with_rank_at_least(2)
        predicted_reward_values.shape.with_rank_at_most(3)
        if (predicted_reward_values.shape[-1] is not None and
                predicted_reward_values.shape[-1] != self._expected_num_actions):
            raise ValueError(
                'The number of actions ({}) does not match the reward_network output'
                ' size ({}).'.format(self._expected_num_actions,
                                     predicted_reward_values.shape[-1]))

        mask = constr.construct_mask_from_multiple_sources(
            time_step.observation,
            self._observation_and_action_constraint_splitter,
            self._constraints, self._expected_num_actions)

        # Argmax.
        if mask is not None:
            actions = policy_utilities.masked_argmax(
                predicted_reward_values,
                mask,
                output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(predicted_reward_values,
                                axis=-1,
                                output_type=self.action_spec.dtype)

        actions += self._action_offset

        bandit_policy_values = tf.fill(
            [batch_size, 1], policy_utilities.BanditPolicyType.GREEDY)

        if self._accepts_per_arm_features:
            # Saving the features for the chosen action to the policy_info.
            def gather_observation(obs):
                return tf.gather(params=obs, indices=actions, batch_dims=1)

            chosen_arm_features = tf.nest.map_structure(
                gather_observation,
                observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
            policy_info = policy_utilities.PerArmPolicyInfo(
                log_probability=tf.zeros([batch_size], tf.float32)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_reward_values
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()),
                chosen_arm_features=chosen_arm_features)
        else:
            policy_info = policy_utilities.PolicyInfo(
                log_probability=tf.zeros([batch_size], tf.float32)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_reward_values
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()))

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)
Example #14
    def _distribution(
            self, time_step: ts.TimeStep,
            policy_state: Sequence[types.TensorSpec]
    ) -> policy_step.PolicyStep:
        observation = time_step.observation
        if self.observation_and_action_constraint_splitter is not None:
            observation, _ = self.observation_and_action_constraint_splitter(
                observation)
        predicted_objective_values_tensor, policy_state = self._predict(
            observation, time_step.step_type, policy_state)
        scalarized_reward = scalarize_objectives(
            predicted_objective_values_tensor, self._scalarizer)
        batch_size = scalarized_reward.shape[0]
        mask = policy_utilities.construct_mask_from_multiple_sources(
            time_step.observation,
            self._observation_and_action_constraint_splitter, (),
            self._expected_num_actions)

        # Argmax.
        if mask is not None:
            actions = policy_utilities.masked_argmax(
                scalarized_reward, mask, output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(scalarized_reward,
                                axis=-1,
                                output_type=self.action_spec.dtype)

        actions += self._action_offset

        bandit_policy_values = tf.fill(
            [batch_size, 1], policy_utilities.BanditPolicyType.GREEDY)

        if self._accepts_per_arm_features:
            # Saving the features for the chosen action to the policy_info.
            def gather_observation(obs):
                return tf.gather(params=obs, indices=actions, batch_dims=1)

            chosen_arm_features = tf.nest.map_structure(
                gather_observation,
                observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
            policy_info = policy_utilities.PerArmPolicyInfo(
                log_probability=tf.zeros([batch_size], tf.float32)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_objective_values_tensor
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()),
                chosen_arm_features=chosen_arm_features)
        else:
            policy_info = policy_utilities.PolicyInfo(
                log_probability=tf.zeros([batch_size], tf.float32)
                if policy_utilities.InfoFields.LOG_PROBABILITY
                in self._emit_policy_info else (),
                predicted_rewards_mean=(
                    predicted_objective_values_tensor
                    if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                    in self._emit_policy_info else ()),
                bandit_policy_type=(
                    bandit_policy_values
                    if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
                    in self._emit_policy_info else ()))

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)
Example #15
    def _distribution(self, time_step, policy_state):
        observation = time_step.observation
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)
        if observation_and_action_constraint_splitter is not None:
            observation, mask = observation_and_action_constraint_splitter(
                observation)
        observation = tf.cast(observation, dtype=self._dtype)
        if self._add_bias:
            # The bias is added via a constant 1 feature.
            observation = tf.concat(
                [observation,
                 tf.ones([tf.shape(observation)[0], 1], dtype=self._dtype)],
                axis=1)
        # Check the shape of the observation matrix. The observations can be
        # batched.
        if not observation.shape.is_compatible_with([None, self._context_dim]):
            raise ValueError(
                'Observation shape is expected to be {}. Got {}.'.format(
                    [None, self._context_dim], observation.shape.as_list()))
        observation = tf.reshape(observation, [-1, self._context_dim])

        p_values = []
        est_rewards = []
        for k in range(self._num_actions):
            if self._use_eigendecomp:
                q_t_b = tf.matmul(self._eig_matrix[k],
                                  tf.linalg.matrix_transpose(observation),
                                  transpose_a=True)
                lambda_inv = tf.divide(
                    tf.ones_like(self._eig_vals[k]),
                    self._eig_vals[k] + self._tikhonov_weight)
                a_inv_x = tf.matmul(self._eig_matrix[k],
                                    tf.einsum('j,jk->jk', lambda_inv, q_t_b))
            else:
                a_inv_x = linalg.conjugate_gradient_solve(
                    self._cov_matrix[k] +
                    self._tikhonov_weight * tf.eye(self._context_dim),
                    tf.linalg.matrix_transpose(observation))
            est_mean_reward = tf.einsum('j,jk->k', self._data_vector[k],
                                        a_inv_x)
            est_rewards.append(est_mean_reward)

            ci = tf.reshape(
                tf.linalg.tensor_diag_part(tf.matmul(observation, a_inv_x)),
                [-1, 1])
            p_values.append(
                tf.reshape(est_mean_reward, [-1, 1]) +
                self._alpha * tf.sqrt(ci))

        # Keeping the batch dimension during the squeeze, even if batch_size == 1.
        optimistic_reward_estimates = tf.squeeze(tf.stack(p_values, axis=-1),
                                                 axis=[1])
        if observation_and_action_constraint_splitter is not None:
            chosen_actions = policy_utilities.masked_argmax(
                optimistic_reward_estimates,
                mask,
                output_type=self._action_spec.dtype)
        else:
            chosen_actions = tf.argmax(optimistic_reward_estimates,
                                       axis=-1,
                                       output_type=self._action_spec.dtype)
        action_distributions = tfp.distributions.Deterministic(
            loc=chosen_actions)

        policy_info = policy_utilities.PolicyInfo(
            predicted_rewards_mean=(
                tf.stack(est_rewards, axis=-1)
                if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
                in self._emit_policy_info else ()))

        return policy_step.PolicyStep(action_distributions, policy_state,
                                      policy_info)
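
The _add_bias branches in the last two examples fold the bias term into the linear model by appending a constant 1 feature to every observation row. A minimal sketch of that step with made-up observations:

import tensorflow as tf

observation = tf.constant([[0.2, 0.4], [0.6, 0.8]])
observation_with_bias = tf.concat(
    [observation, tf.ones([tf.shape(observation)[0], 1])], axis=1)
# Each row now ends with a trailing 1.0, so the bias is learned as one more weight.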