Example #1
 def testMixturePolicyNegativeProb(self):
     context_dim = 11
     observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
     time_step_spec = ts.time_step_spec(observation_spec)
     action_spec = tensor_spec.BoundedTensorSpec(shape=(),
                                                 dtype=tf.int32,
                                                 minimum=0,
                                                 maximum=9,
                                                 name='action')
     sub_policies = [
         ConstantPolicy(action_spec, time_step_spec, i) for i in range(10)
     ]
     weights = [0, 0, 0.2, 0, 0, -0.3, 0, 0, 0.5, 0]
     policy = mixture_policy.MixturePolicy(weights, sub_policies)
     batch_size = 15
     time_step = ts.TimeStep(
         tf.constant(ts.StepType.FIRST,
                     dtype=tf.int32,
                     shape=[batch_size],
                     name='step_type'),
         tf.constant(0.0,
                     dtype=tf.float32,
                     shape=[batch_size],
                     name='reward'),
         tf.constant(1.0,
                     dtype=tf.float32,
                     shape=[batch_size],
                     name='discount'),
         tf.constant(list(range(batch_size * context_dim)),
                     dtype=tf.float32,
                     shape=[batch_size, context_dim],
                     name='observation'))
     with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                 'Negative probability'):
         policy.action(time_step)
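The test above hinges on `MixturePolicy` rejecting negative weights at action time with an `InvalidArgumentError`. A minimal sketch of such a check, using `tf.debugging.assert_non_negative`; this is illustrative only, not the library's actual implementation:

 import tensorflow as tf

 def validate_mixture_weights(weights):
     # Raises tf.errors.InvalidArgumentError in eager mode if any weight
     # is negative; a stand-in for the check the test exercises.
     weights = tf.convert_to_tensor(weights, dtype=tf.float32)
     tf.debugging.assert_non_negative(weights, message='Negative probability')

 validate_mixture_weights([0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.5, 0])     # passes
 # validate_mixture_weights([0, 0, 0.2, 0, 0, -0.3, 0, 0, 0.5, 0])  # raises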
Example #2
 def testMixturePolicyDynamicBatchSize(self):
     context_dim = 35
     observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
     time_step_spec = ts.time_step_spec(observation_spec)
     action_spec = tensor_spec.BoundedTensorSpec(shape=(),
                                                 dtype=tf.int32,
                                                 minimum=0,
                                                 maximum=9,
                                                 name='action')
     sub_policies = [
         ConstantPolicy(action_spec, time_step_spec, i) for i in range(10)
     ]
     weights = [0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.5, 0]
     dist = tfd.Categorical(probs=weights)
     policy = mixture_policy.MixturePolicy(dist, sub_policies)
     batch_size = tf.random.uniform(shape=(),
                                    minval=10,
                                    maxval=15,
                                    dtype=tf.int32)
     time_step = ts.TimeStep(
         tf.fill(tf.expand_dims(batch_size, axis=0),
                 ts.StepType.FIRST,
                 name='step_type'),
         tf.zeros(shape=[batch_size], dtype=tf.float32, name='reward'),
         tf.ones(shape=[batch_size], dtype=tf.float32, name='discount'),
         tf.reshape(tf.range(tf.cast(batch_size * context_dim,
                                     dtype=tf.float32),
                             dtype=tf.float32),
                    shape=[-1, context_dim],
                    name='observation'))
     action_step = policy.action(time_step)
     actions, bsize = self.evaluate([action_step.action, batch_size])
     self.assertAllEqual(actions.shape, [bsize])
     self.assertAllInSet(actions, [2, 5, 8])
     saver = policy_saver.PolicySaver(policy)
     location = os.path.join(self.get_temp_dir(), 'saved_policy')
     saver.save(location)
     loaded_policy = tf.compat.v2.saved_model.load(location)
     new_batch_size = 3
     new_time_step = ts.TimeStep(
         tf.fill(tf.expand_dims(new_batch_size, axis=0),
                 ts.StepType.FIRST,
                 name='step_type'),
         tf.zeros(shape=[new_batch_size], dtype=tf.float32, name='reward'),
         tf.ones(shape=[new_batch_size], dtype=tf.float32, name='discount'),
         tf.reshape(tf.range(tf.cast(new_batch_size * context_dim,
                                     dtype=tf.float32),
                             dtype=tf.float32),
                    shape=[-1, context_dim],
                    name='observation'))
     new_action = self.evaluate(loaded_policy.action(new_time_step).action)
     self.assertAllEqual(new_action.shape, [new_batch_size])
     self.assertAllInSet(new_action, [2, 5, 8])
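Internally, a mixture policy of this sort has to pick one sub-policy per batch element. A self-contained sketch of that selection step, assuming the candidate actions of all sub-policies are already stacked into a single tensor (a simplification: the real policy composes nested `PolicyStep` structures):

 import tensorflow as tf
 import tensorflow_probability as tfp

 tfd = tfp.distributions

 def mix_actions(mixture_dist, stacked_actions):
     # stacked_actions: [num_policies, batch_size] tensor holding the
     # action each sub-policy would take for each batch element.
     batch_size = tf.shape(stacked_actions)[1]
     choices = mixture_dist.sample(batch_size)             # [batch_size]
     indices = tf.stack([choices, tf.range(batch_size)], axis=-1)
     return tf.gather_nd(stacked_actions, indices)         # [batch_size]

 dist = tfd.Categorical(probs=[0.0, 0.2, 0.8])
 actions = mix_actions(dist, tf.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2]]))
 # Every entry of `actions` is 1 or 2, since policy 0 has zero weight.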
Example #3
 def testMixturePolicyInconsistentSpecs(self):
     context_dim = 11
     observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
     time_step_spec = ts.time_step_spec(observation_spec)
     action_spec = tensor_spec.BoundedTensorSpec(shape=(),
                                                 dtype=tf.int32,
                                                 minimum=0,
                                                 maximum=9,
                                                 name='action')
     sub_policies = [
         ConstantPolicy(action_spec, time_step_spec, i) for i in range(9)
     ]
     wrong_obs_spec = tensor_spec.TensorSpec([context_dim + 1], tf.float32)
     wrong_time_step_spec = ts.time_step_spec(wrong_obs_spec)
     wrong_policy = ConstantPolicy(action_spec, wrong_time_step_spec, 9)
     sub_policies.append(wrong_policy)
     weights = [0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.5, 0]
     with self.assertRaisesRegex(AssertionError,
                                 'Inconsistent time step specs'):
         mixture_policy.MixturePolicy(weights, sub_policies)
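The `AssertionError` expected here would be raised by a constructor-time consistency pass over the sub-policies' specs. A sketch of what that check can look like, assumed rather than taken from the library, and mirroring the spec loops in the agent constructors below:

 def check_sub_policy_specs(sub_policies):
     # All sub-policies must agree on the time step spec; the message
     # matches the one the test above asserts on.
     first_spec = sub_policies[0].time_step_spec
     for policy in sub_policies[1:]:
         assert first_spec == policy.time_step_spec, (
             'Inconsistent time step specs')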
Example #4
    def __init__(self,
                 mixture_distribution: types.Distribution,
                 agents: Sequence[tf_agent.TFAgent],
                 name: Optional[Text] = None):
        """Initializes an instance of `MixtureAgent`.

    Args:
      mixture_distribution: An instance of `tfd.Categorical` distribution. This
        distribution is used to draw sub-policies by the mixture policy. The
        parameters of the distribution is trained by the mixture agent.
      agents: List of instances of TF-Agents bandit agents. These agents will be
        trained and used to select actions. The length of this list should match
        that of `mixture_weights`.
      name: The name of this instance of `MixtureAgent`.
    """
        tf.Module.__init__(self, name=name)
        time_step_spec = agents[0].time_step_spec
        action_spec = agents[0].action_spec
        self._original_info_spec = agents[0].policy.info_spec
        error_message = None
        for agent in agents[1:]:
            if action_spec != agent.action_spec:
                error_message = 'Inconsistent action specs.'
            if time_step_spec != agent.time_step_spec:
                error_message = 'Inconsistent time step specs.'
            if self._original_info_spec != agent.policy.info_spec:
                error_message = 'Inconsistent info specs.'
        if error_message is not None:
            raise ValueError(error_message)
        self._agents = agents
        self._num_agents = len(agents)
        self._mixture_distribution = mixture_distribution
        policies = [agent.collect_policy for agent in agents]
        policy = mixture_policy.MixturePolicy(mixture_distribution, policies)
        super(MixtureAgent, self).__init__(time_step_spec,
                                           action_spec,
                                           policy,
                                           policy,
                                           train_sequence_length=None)
        self._as_trajectory = data_converter.AsTrajectory(self.data_context,
                                                          sequence_length=None)
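Per the docstring, the parameters of `mixture_distribution` are what `MixtureAgent` trains. A hypothetical construction that makes this concrete: backing the categorical with a `tf.Variable` over logits keeps the mixture trainable (`agent_a` and `agent_b` are placeholders for any two bandit agents with identical specs):

 import tensorflow as tf
 import tensorflow_probability as tfp

 tfd = tfp.distributions

 # Trainable mixture over two hypothetical sub-agents.
 mixture_logits = tf.Variable([0.0, 0.0], name='mixture_logits')
 mixture_dist = tfd.Categorical(logits=mixture_logits)
 # agent = MixtureAgent(mixture_dist, [agent_a, agent_b])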
Example #5
 def testMixturePolicyChoices(self):
     context_dim = 34
     observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
     time_step_spec = ts.time_step_spec(observation_spec)
     action_spec = tensor_spec.BoundedTensorSpec(shape=(),
                                                 dtype=tf.int32,
                                                 minimum=0,
                                                 maximum=9,
                                                 name='action')
     sub_policies = [
         ConstantPolicy(action_spec, time_step_spec, i) for i in range(10)
     ]
     weights = [0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.5, 0]
     dist = tfd.Categorical(probs=weights)
     policy = mixture_policy.MixturePolicy(dist, sub_policies)
     batch_size = 15
     time_step = ts.TimeStep(
         tf.constant(ts.StepType.FIRST,
                     dtype=tf.int32,
                     shape=[batch_size],
                     name='step_type'),
         tf.constant(0.0,
                     dtype=tf.float32,
                     shape=[batch_size],
                     name='reward'),
         tf.constant(1.0,
                     dtype=tf.float32,
                     shape=[batch_size],
                     name='discount'),
         tf.constant(list(range(batch_size * context_dim)),
                     dtype=tf.float32,
                     shape=[batch_size, context_dim],
                     name='observation'))
     action_step = policy.action(time_step)
     actions, infos = self.evaluate([action_step.action, action_step.info])
     tf.nest.assert_same_structure(policy.info_spec, infos)
     self.assertAllEqual(actions.shape, [batch_size])
     self.assertAllInSet(actions, [2, 5, 8])
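The `assertAllInSet` checks in these tests follow directly from the weight vector: categories with zero probability are never sampled, so only sub-policies 2, 5, and 8 can be chosen, and each `ConstantPolicy` emits its own index as the action. A quick standalone verification:

 import tensorflow_probability as tfp

 tfd = tfp.distributions

 dist = tfd.Categorical(probs=[0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.5, 0])
 samples = dist.sample(1000)
 # Every sampled index lands in {2, 5, 8}; zero-weight categories never
 # occur, which is exactly what assertAllInSet verifies above.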
Example #6
    def __init__(self, mixture_weights, agents, name=None):
        """Initializes an instance of `StaticMixtureAgent`.

    Args:
      mixture_weights: (list of floats) The (possibly unnormalized) probability
        distribution based on which the agent chooses the sub-agents.
      agents: List of instances of TF-Agents bandit agents. These agents will be
        trained and used to select actions. The length of this list should match
        that of `mixture_weights`.
      name: The name of this instance of `StaticMixtureAgent`.
    """
        tf.Module.__init__(self, name=name)
        time_step_spec = agents[0].time_step_spec
        action_spec = agents[0].action_spec
        self._original_info_spec = agents[0].policy.info_spec
        error_message = None
        for agent in agents[1:]:
            if action_spec != agent.action_spec:
                error_message = 'Inconsistent action specs.'
            if time_step_spec != agent.time_step_spec:
                error_message = 'Inconsistent time step specs.'
            if self._original_info_spec != agent.policy.info_spec:
                error_message = 'Inconsistent info specs.'
        if len(mixture_weights) != len(agents):
            error_message = '`mixture_weights` and `agents` must have equal length.'
        if error_message is not None:
            raise ValueError(error_message)
        self._agents = agents
        self._num_agents = len(agents)
        self._mixture_weights = mixture_weights
        policies = [agent.collect_policy for agent in agents]
        policy = mixture_policy.MixturePolicy(mixture_weights, policies)
        super(StaticMixtureAgent, self).__init__(time_step_spec,
                                                 action_spec,
                                                 policy,
                                                 policy,
                                                 train_sequence_length=None)
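The docstring allows unnormalized `mixture_weights` because a categorical sampler normalizes implicitly: weights of [2.0, 3.0, 5.0] behave the same as [0.2, 0.3, 0.5]. A small sketch, with the commented agent line as hypothetical usage (`agent_a`, `agent_b`, `agent_c` stand in for compatible bandit agents):

 import tensorflow_probability as tfp

 tfd = tfp.distributions

 # Sampling from unnormalized weights draws indices in a 2:3:5 ratio.
 dist = tfd.Categorical(probs=[2.0, 3.0, 5.0])
 samples = dist.sample(10)
 # agent = StaticMixtureAgent([2.0, 3.0, 5.0], [agent_a, agent_b, agent_c])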