def testComputeFeasibilityMask(self):
  observation_spec = tensor_spec.TensorSpec([2], tf.float32)
  time_step_spec = ts.time_step_spec(observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 2)
  simple_constraint = SimpleConstraint(time_step_spec, action_spec)
  observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  feasibility_prob = policy_utilities.compute_feasibility_probability(
      observations, [simple_constraint], batch_size=2, num_actions=3,
      action_mask=None)
  self.assertAllEqual(0.5 * np.ones([2, 3]),
                      self.evaluate(feasibility_prob))
def testComputeFeasibilityMaskWithActionMask(self):
  observation_spec = tensor_spec.TensorSpec([2], tf.float32)
  time_step_spec = ts.time_step_spec(observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 2)
  constraint_net = DummyNet(observation_spec)
  neural_constraint = constraints.NeuralConstraint(
      time_step_spec, action_spec, constraint_network=constraint_net)
  observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  action_mask = tf.constant([[0, 0, 1], [0, 1, 0]], dtype=tf.int32)
  feasibility_prob = policy_utilities.compute_feasibility_probability(
      observations, [neural_constraint], batch_size=2, num_actions=3,
      action_mask=action_mask)
  self.assertAllEqual(self.evaluate(tf.cast(action_mask, tf.float32)),
                      self.evaluate(feasibility_prob))
def _distribution(self, time_step, policy_state):
  observation = time_step.observation
  mask = None
  observation_and_action_constraint_splitter = (
      self.observation_and_action_constraint_splitter)
  if observation_and_action_constraint_splitter is not None:
    observation, mask = observation_and_action_constraint_splitter(
        observation)

  predictions, policy_state = self._reward_network(
      observation, time_step.step_type, policy_state)
  batch_size = tf.shape(predictions)[0]

  if isinstance(self._reward_network,
                heteroscedastic_q_network.HeteroscedasticQNetwork):
    predicted_reward_values = predictions.q_value_logits
  else:
    predicted_reward_values = predictions

  predicted_reward_values.shape.with_rank_at_least(2)
  predicted_reward_values.shape.with_rank_at_most(3)
  if predicted_reward_values.shape[-1] != self._expected_num_actions:
    raise ValueError(
        'The number of actions ({}) does not match the reward_network '
        'output size ({}).'.format(self._expected_num_actions,
                                   predicted_reward_values.shape[-1]))

  if self._constraints:
    # Action feasibility computation.
    feasibility_prob = policy_utilities.compute_feasibility_probability(
        observation, self._constraints, batch_size,
        self._expected_num_actions, mask)
    # Probabilistic masking: sample a binary feasibility mask per action.
    mask = tfp.distributions.Bernoulli(probs=feasibility_prob).sample()

  # Argmax over the predicted rewards, restricted to the mask if present.
  if self._constraints or (observation_and_action_constraint_splitter
                           is not None):
    actions = policy_utilities.masked_argmax(
        predicted_reward_values, mask, output_type=self.action_spec.dtype)
  else:
    actions = tf.argmax(
        predicted_reward_values, axis=-1, output_type=self.action_spec.dtype)

  actions += self._action_offset

  bandit_policy_values = tf.fill([batch_size, 1],
                                 policy_utilities.BanditPolicyType.GREEDY)

  if self._accepts_per_arm_features:
    # Saving the features for the chosen action to the policy_info.
    def gather_observation(obs):
      return tf.gather(params=obs, indices=actions, batch_dims=1)

    chosen_arm_features = tf.nest.map_structure(
        gather_observation,
        observation[bandit_spec_utils.PER_ARM_FEATURE_KEY])
    policy_info = policy_utilities.PerArmPolicyInfo(
        predicted_rewards_mean=(
            predicted_reward_values
            if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
            in self._emit_policy_info else ()),
        bandit_policy_type=(
            bandit_policy_values
            if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
            in self._emit_policy_info else ()),
        chosen_arm_features=chosen_arm_features)
  else:
    policy_info = policy_utilities.PolicyInfo(
        predicted_rewards_mean=(
            predicted_reward_values
            if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
            in self._emit_policy_info else ()),
        bandit_policy_type=(
            bandit_policy_values
            if policy_utilities.InfoFields.BANDIT_POLICY_TYPE
            in self._emit_policy_info else ()))

  return policy_step.PolicyStep(
      tfp.distributions.Deterministic(loc=actions), policy_state,
      policy_info)
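# --- Illustrative sketch (not part of the TF-Agents sources) ---
# A minimal, self-contained example of the probabilistic masking step used in
# _distribution above: per-action feasibility probabilities are sampled into a
# binary mask with a Bernoulli distribution, and the greedy action is taken
# over the masked rewards. The tensor values are made-up assumptions, and
# tf.where stands in for policy_utilities.masked_argmax only to keep the
# sketch dependency-free; it is not claimed to be identical to that helper.
import tensorflow as tf
import tensorflow_probability as tfp

predicted_rewards = tf.constant([[1.0, 5.0, 3.0],
                                 [2.0, 0.5, 4.0]])
# Hypothetical per-action feasibility probabilities from the constraints.
feasibility_prob = tf.constant([[1.0, 0.0, 1.0],
                                [1.0, 1.0, 0.5]])
# Probabilistic masking: sample a binary mask from Bernoulli(feasibility).
mask = tfp.distributions.Bernoulli(probs=feasibility_prob).sample()
# Masked argmax: push rewards of masked-out actions to the dtype minimum
# before taking the argmax, so only feasible actions can be selected.
neg_inf = tf.fill(tf.shape(predicted_rewards), predicted_rewards.dtype.min)
masked_rewards = tf.where(tf.cast(mask, tf.bool), predicted_rewards, neg_inf)
greedy_actions = tf.argmax(masked_rewards, axis=-1, output_type=tf.int32)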