Ejemplo n.º 1
0
  def testComputeFeasibilityMask(self):
    """Feasibility probabilities from one SimpleConstraint, no action mask.

    Calls `compute_feasibility_probability` on a batch of two observations
    with three actions and asserts that every entry of the result is 0.5.
    """
    batch_size, num_actions = 2, 3
    obs_spec = tensor_spec.TensorSpec([2], tf.float32)
    constraint = SimpleConstraint(
        ts.time_step_spec(obs_spec),
        tensor_spec.BoundedTensorSpec((), tf.int32, 0, 2))

    # One observation row per batch element.
    batch_observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    actual = policy_utilities.compute_feasibility_probability(
        batch_observations,
        [constraint],
        batch_size=batch_size,
        num_actions=num_actions,
        action_mask=None)

    expected = 0.5 * np.ones([batch_size, num_actions])
    self.assertAllEqual(expected, self.evaluate(actual))
Ejemplo n.º 2
0
  def testComputeFeasibilityMaskWithActionMask(self):
    """Feasibility probabilities from a NeuralConstraint with an action mask.

    Builds a `NeuralConstraint` around a `DummyNet` and asserts that the
    returned probabilities equal the supplied action mask cast to float32.
    """
    batch_size, num_actions = 2, 3
    obs_spec = tensor_spec.TensorSpec([2], tf.float32)
    step_spec = ts.time_step_spec(obs_spec)
    act_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 2)
    constraint = constraints.NeuralConstraint(
        step_spec, act_spec, constraint_network=DummyNet(obs_spec))

    batch_observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    # Exactly one action is left unmasked in each batch row.
    action_mask = tf.constant([[0, 0, 1], [0, 1, 0]], dtype=tf.int32)
    actual = policy_utilities.compute_feasibility_probability(
        batch_observations,
        [constraint],
        batch_size=batch_size,
        num_actions=num_actions,
        action_mask=action_mask)

    expected = self.evaluate(tf.cast(action_mask, tf.float32))
    self.assertAllEqual(expected, self.evaluate(actual))
    def _distribution(self, time_step, policy_state):
        """Builds a deterministic distribution over the greedy action.

        The reward network is evaluated on the observation (after the optional
        observation/mask splitter has been applied) and the action with the
        highest predicted reward is selected. When constraints are configured,
        a feasibility mask is sampled from a Bernoulli distribution and used
        for a masked argmax; when only a splitter is configured, the mask it
        returns is used directly.

        Args:
          time_step: The current time step; its `observation` and `step_type`
            are read.
          policy_state: State passed to the reward network; the state the
            network returns is forwarded in the result.

        Returns:
          A `policy_step.PolicyStep` containing a
          `tfp.distributions.Deterministic` located at the chosen actions
          (shifted by `self._action_offset`), the policy state returned by the
          reward network, and the policy info.

        Raises:
          ValueError: If the statically known last dimension of the predicted
            rewards differs from `self._expected_num_actions`. The rank checks
            on the predicted rewards (at least 2, at most 3) are delegated to
            `TensorShape.with_rank_at_least` / `with_rank_at_most`.
        """
        observation = time_step.observation
        mask = None
        splitter = self.observation_and_action_constraint_splitter
        if splitter is not None:
            observation, mask = splitter(observation)

        predictions, policy_state = self._reward_network(
            observation, time_step.step_type, policy_state)

        if isinstance(self._reward_network,
                      heteroscedastic_q_network.HeteroscedasticQNetwork):
            predicted_reward_values = predictions.q_value_logits
        else:
            predicted_reward_values = predictions

        # Read the batch size from the reward tensor itself. In the
        # heteroscedastic case `predictions` is a structure exposing
        # `q_value_logits`, so its own leading dimension is not reliably the
        # batch dimension.
        batch_size = tf.shape(predicted_reward_values)[0]

        predicted_reward_values.shape.with_rank_at_least(2)
        predicted_reward_values.shape.with_rank_at_most(3)
        # Only compare when the last dimension is statically known; an
        # unknown dimension cannot be validated at graph-construction time.
        num_outputs = predicted_reward_values.shape[-1]
        if (num_outputs is not None
                and num_outputs != self._expected_num_actions):
            # Report the same axis that was checked (the last one).
            raise ValueError(
                'The number of actions ({}) does not match the reward_network output'
                ' size ({}).'.format(self._expected_num_actions, num_outputs))

        if self._constraints:
            # Action feasibility computation.
            feasibility_prob = policy_utilities.compute_feasibility_probability(
                observation, self._constraints, batch_size,
                self._expected_num_actions, mask)
            # Probabilistic masking.
            mask = tfp.distributions.Bernoulli(probs=feasibility_prob).sample()

        # Argmax, restricted to the mask whenever one may be present.
        if self._constraints or splitter is not None:
            actions = policy_utilities.masked_argmax(
                predicted_reward_values,
                mask,
                output_type=self.action_spec.dtype)
        else:
            actions = tf.argmax(predicted_reward_values,
                                axis=-1,
                                output_type=self.action_spec.dtype)

        actions += self._action_offset

        bandit_policy_values = tf.fill(
            [batch_size, 1], policy_utilities.BanditPolicyType.GREEDY)

        # Each optional info field is emitted only when requested; otherwise
        # it is left as an empty tuple.
        emitted = self._emit_policy_info
        rewards_mean_info = (
            predicted_reward_values
            if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emitted
            else ())
        policy_type_info = (
            bandit_policy_values
            if policy_utilities.InfoFields.BANDIT_POLICY_TYPE in emitted
            else ())

        if self._accepts_per_arm_features:
            # Saving the features for the chosen action to the policy_info.
            def gather_observation(obs):
                return tf.gather(params=obs, indices=actions, batch_dims=1)

            chosen_arm_features = tf.nest.map_structure(
                gather_observation,
                observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
            policy_info = policy_utilities.PerArmPolicyInfo(
                predicted_rewards_mean=rewards_mean_info,
                bandit_policy_type=policy_type_info,
                chosen_arm_features=chosen_arm_features)
        else:
            policy_info = policy_utilities.PolicyInfo(
                predicted_rewards_mean=rewards_mean_info,
                bandit_policy_type=policy_type_info)

        return policy_step.PolicyStep(
            tfp.distributions.Deterministic(loc=actions), policy_state,
            policy_info)