def testWrongPolicyInfoType(self):
  """A non-bandit policy_step.PolicyInfo should be returned unchanged."""
  dims = (10, 1)
  log_probability = tf.fill(dims, value=-0.5)
  info = policy_step.PolicyInfo(log_probability=log_probability)
  input_tensor = tf.fill(dims, value=_GREEDY)
  result = policy_utilities.set_bandit_policy_type(info, input_tensor)
  self.assertNotIsInstance(result, policy_utilities.PolicyInfo)
  self.assertAllEqual(info.log_probability, result.log_probability)
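
# A possible companion test (a sketch, not part of the original snippet; it reuses
# the same imports and fixtures as the method above). It assumes that
# policy_utilities.PolicyInfo exposes log_probability and bandit_policy_type
# fields, so set_bandit_policy_type can fill the latter in when the type matches.
def testSetBanditPolicyTypeSketch(self):
  dims = (10, 1)
  log_probability = tf.fill(dims, value=-0.5)
  info = policy_utilities.PolicyInfo(log_probability=log_probability)
  input_tensor = tf.fill(dims, value=_GREEDY)
  result = policy_utilities.set_bandit_policy_type(info, input_tensor)
  self.assertIsInstance(result, policy_utilities.PolicyInfo)
  self.assertAllEqual(input_tensor, result.bandit_policy_type)
  self.assertAllEqual(info.log_probability, result.log_probability)
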
    def _action(self, time_step, policy_state, seed):
        observation_and_action_constraint_splitter = (
            self.observation_and_action_constraint_splitter)

        if observation_and_action_constraint_splitter is not None:
            _, mask = observation_and_action_constraint_splitter(
                time_step.observation)

            zero_logits = tf.cast(tf.zeros_like(mask), tf.float32)
            masked_categorical = masked.MaskedCategorical(zero_logits, mask)
            # Modified to accommodate scalar action specs. The upstream version was:
            #   action_ = tf.cast(masked_categorical.sample() + self.action_spec.minimum,
            #                     self.action_spec.dtype)
            # Reshaping to [1] assumes a single (unbatched) sample per call.
            action_ = tf.reshape(
                tf.cast(masked_categorical.sample() + self.action_spec.minimum,
                        self.action_spec.dtype), [1])

            # If the action spec says each action should be shaped (1,), add another
            # dimension so the final shape is (B, 1) rather than (B,).
            if self.action_spec.shape.rank == 1:
                action_ = tf.expand_dims(action_, axis=-1)
        else:
            outer_dims = nest_utils.get_outer_shape(time_step,
                                                    self._time_step_spec)

            action_ = tensor_spec.sample_spec_nest(self._action_spec,
                                                   seed=seed,
                                                   outer_dims=outer_dims)

        # TODO(b/78181147): Investigate why this control dependency is required.
        if time_step is not None:
            with tf.control_dependencies(tf.nest.flatten(time_step)):
                action_ = tf.nest.map_structure(tf.identity, action_)
        step = policy_step.PolicyStep(action_, policy_state)

        if self.emit_log_probability:
            if observation_and_action_constraint_splitter is not None:
                log_probability = masked_categorical.log_prob(
                    action_ - self.action_spec.minimum)
            else:
                action_probability = tf.nest.map_structure(
                    _uniform_probability, self._action_spec)
                log_probability = tf.nest.map_structure(
                    tf.math.log, action_probability)

            info = policy_step.PolicyInfo(log_probability=log_probability)
            return step._replace(info=info)

        return step
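
# A batch-safe variant of the scalar-action workaround above (a sketch, not
# upstream TF-Agents code; _masked_random_action is a hypothetical helper).
# Reshaping with [-1] keeps a batched sample of shape (B,) intact and still turns
# an unbatched scalar sample into shape (1,), whereas reshaping to [1] only works
# when exactly one action is sampled per call.
import tensorflow as tf


def _masked_random_action(masked_categorical, action_spec):
    # sample() yields shape (B,) when the logits are (B, num_actions) and a scalar
    # when they are unbatched; [-1] normalizes both cases to a rank-1 tensor.
    action = tf.cast(masked_categorical.sample() + action_spec.minimum,
                     action_spec.dtype)
    action = tf.reshape(action, [-1])
    # Keep the (B, 1) convention when the action spec itself is shaped (1,).
    if action_spec.shape.rank == 1:
        action = tf.expand_dims(action, axis=-1)
    return action
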
  def _action(self, time_step, policy_state, seed):
    outer_dims = nest_utils.get_outer_shape(time_step, self._time_step_spec)

    action_ = tensor_spec.sample_spec_nest(
        self._action_spec, seed=seed, outer_dims=outer_dims)
    # TODO(b/78181147): Investigate why this control dependency is required.
    if time_step is not None:
      with tf.control_dependencies(tf.nest.flatten(time_step)):
        action_ = tf.nest.map_structure(tf.identity, action_)
    step = policy_step.PolicyStep(action_, policy_state)

    if self.emit_log_probability:
      action_probability = tf.nest.map_structure(_uniform_probability,
                                                 self._action_spec)
      log_probability = tf.nest.map_structure(tf.math.log, action_probability)
      info = policy_step.PolicyInfo(log_probability=log_probability)
      return step._replace(info=info)

    return step
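
# A minimal usage sketch for the method above (not part of the original snippet;
# the specs, shapes, and batch size below are illustrative assumptions). With
# emit_log_probability=True, the returned PolicyStep carries uniform
# log-probabilities in step.info, i.e. log(1/4) for four possible actions.
import tensorflow as tf
from tf_agents.policies import random_tf_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tensor_spec.TensorSpec((4,), tf.float32)
action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, minimum=0, maximum=3)
policy = random_tf_policy.RandomTFPolicy(
    ts.time_step_spec(observation_spec), action_spec,
    emit_log_probability=True)

time_step = ts.restart(tf.zeros((2, 4), tf.float32), batch_size=2)
step = policy.action(time_step)
print(step.action, step.info.log_probability)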