def testMixturePolicyNegativeProb(self):
  context_dim = 11
  observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
  time_step_spec = ts.time_step_spec(observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec(
      shape=(), dtype=tf.int32, minimum=0, maximum=9, name='action')
  sub_policies = [
      ConstantPolicy(action_spec, time_step_spec, i) for i in range(10)
  ]
  weights = [0, 0, 0.2, 0, 0, -0.3, 0, 0, 0.5, 0]
  policy = mixture_policy.MixturePolicy(weights, sub_policies)
  batch_size = 15
  time_step = ts.TimeStep(
      tf.constant(
          ts.StepType.FIRST,
          dtype=tf.int32,
          shape=[batch_size],
          name='step_type'),
      tf.constant(0.0, dtype=tf.float32, shape=[batch_size], name='reward'),
      tf.constant(1.0, dtype=tf.float32, shape=[batch_size], name='discount'),
      tf.constant(
          list(range(batch_size * context_dim)),
          dtype=tf.float32,
          shape=[batch_size, context_dim],
          name='observation'))
  with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                               'Negative probability'):
    policy.action(time_step)
def testMixturePolicyDynamicBatchSize(self):
  context_dim = 35
  observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
  time_step_spec = ts.time_step_spec(observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec(
      shape=(), dtype=tf.int32, minimum=0, maximum=9, name='action')
  sub_policies = [
      ConstantPolicy(action_spec, time_step_spec, i) for i in range(10)
  ]
  weights = [0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.5, 0]
  dist = tfd.Categorical(probs=weights)
  policy = mixture_policy.MixturePolicy(dist, sub_policies)
  # The batch size is a tensor, so it is only known at execution time.
  batch_size = tf.random.uniform(
      shape=(), minval=10, maxval=15, dtype=tf.int32)
  time_step = ts.TimeStep(
      tf.fill(
          tf.expand_dims(batch_size, axis=0),
          ts.StepType.FIRST,
          name='step_type'),
      tf.zeros(shape=[batch_size], dtype=tf.float32, name='reward'),
      tf.ones(shape=[batch_size], dtype=tf.float32, name='discount'),
      tf.reshape(
          tf.range(
              tf.cast(batch_size * context_dim, dtype=tf.float32),
              dtype=tf.float32),
          shape=[-1, context_dim],
          name='observation'))
  action_step = policy.action(time_step)
  actions, bsize = self.evaluate([action_step.action, batch_size])
  self.assertAllEqual(actions.shape, [bsize])
  # Only the sub-policies with nonzero weight (2, 5, and 8) may be chosen.
  self.assertAllInSet(actions, [2, 5, 8])
  # Saving and restoring the policy should preserve its behavior, including
  # for a batch size it has not seen before.
  saver = policy_saver.PolicySaver(policy)
  location = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver.save(location)
  loaded_policy = tf.compat.v2.saved_model.load(location)
  new_batch_size = 3
  new_time_step = ts.TimeStep(
      tf.fill(
          tf.expand_dims(new_batch_size, axis=0),
          ts.StepType.FIRST,
          name='step_type'),
      tf.zeros(shape=[new_batch_size], dtype=tf.float32, name='reward'),
      tf.ones(shape=[new_batch_size], dtype=tf.float32, name='discount'),
      tf.reshape(
          tf.range(
              tf.cast(new_batch_size * context_dim, dtype=tf.float32),
              dtype=tf.float32),
          shape=[-1, context_dim],
          name='observation'))
  new_action = self.evaluate(loaded_policy.action(new_time_step).action)
  self.assertAllEqual(new_action.shape, [new_batch_size])
  self.assertAllInSet(new_action, [2, 5, 8])
def testMixturePolicyInconsistentSpecs(self):
  context_dim = 11
  observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
  time_step_spec = ts.time_step_spec(observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec(
      shape=(), dtype=tf.int32, minimum=0, maximum=9, name='action')
  sub_policies = [
      ConstantPolicy(action_spec, time_step_spec, i) for i in range(9)
  ]
  wrong_obs_spec = tensor_spec.TensorSpec([context_dim + 1], tf.float32)
  wrong_time_step_spec = ts.time_step_spec(wrong_obs_spec)
  wrong_policy = ConstantPolicy(action_spec, wrong_time_step_spec, 9)
  sub_policies.append(wrong_policy)
  weights = [0, 0, 0.2, 0, 0, -0.3, 0, 0, 0.5, 0]
  with self.assertRaisesRegexp(AssertionError,
                               'Inconsistent time step specs'):
    mixture_policy.MixturePolicy(weights, sub_policies)
def __init__(self,
             mixture_distribution: types.Distribution,
             agents: Sequence[tf_agent.TFAgent],
             name: Optional[Text] = None):
  """Initializes an instance of `MixtureAgent`.

  Args:
    mixture_distribution: An instance of a `tfd.Categorical` distribution.
      This distribution is used by the mixture policy to draw sub-policies.
      The parameters of the distribution are trained by the mixture agent.
    agents: List of instances of TF-Agents bandit agents. These agents will
      be trained and used to select actions. The length of this list should
      match the number of categories of `mixture_distribution`.
    name: The name of this instance of `MixtureAgent`.
  """
  tf.Module.__init__(self, name=name)
  time_step_spec = agents[0].time_step_spec
  action_spec = agents[0].action_spec
  self._original_info_spec = agents[0].policy.info_spec
  error_message = None
  for agent in agents[1:]:
    if action_spec != agent.action_spec:
      error_message = 'Inconsistent action specs.'
    if time_step_spec != agent.time_step_spec:
      error_message = 'Inconsistent time step specs.'
    if self._original_info_spec != agent.policy.info_spec:
      error_message = 'Inconsistent info specs.'
  if error_message is not None:
    raise ValueError(error_message)
  self._agents = agents
  self._num_agents = len(agents)
  self._mixture_distribution = mixture_distribution
  policies = [agent.collect_policy for agent in agents]
  policy = mixture_policy.MixturePolicy(mixture_distribution, policies)
  super(MixtureAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      policy,
      train_sequence_length=None)
  self._as_trajectory = data_converter.AsTrajectory(
      self.data_context, sequence_length=None)
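# A minimal sketch (not part of the library) of the kind of mixture
# distribution `MixtureAgent` expects: a `tfd.Categorical` whose parameters
# are trainable variables, so the mixture agent can update them during
# training. The variable name `mixture_logits` and the choice of three
# sub-agents are illustrative assumptions only.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

mixture_logits = tf.Variable(tf.zeros([3]), name='mixture_logits')
mixture_distribution = tfd.Categorical(logits=mixture_logits)
# `mixture_distribution` would then be passed to a concrete `MixtureAgent`
# subclass, together with a list of three sub-agents that share the same
# time step spec, action spec, and policy info spec.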
def testMixturePolicyChoices(self):
  context_dim = 34
  observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
  time_step_spec = ts.time_step_spec(observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec(
      shape=(), dtype=tf.int32, minimum=0, maximum=9, name='action')
  sub_policies = [
      ConstantPolicy(action_spec, time_step_spec, i) for i in range(10)
  ]
  weights = [0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.5, 0]
  dist = tfd.Categorical(probs=weights)
  policy = mixture_policy.MixturePolicy(dist, sub_policies)
  batch_size = 15
  time_step = ts.TimeStep(
      tf.constant(
          ts.StepType.FIRST,
          dtype=tf.int32,
          shape=[batch_size],
          name='step_type'),
      tf.constant(0.0, dtype=tf.float32, shape=[batch_size], name='reward'),
      tf.constant(1.0, dtype=tf.float32, shape=[batch_size], name='discount'),
      tf.constant(
          list(range(batch_size * context_dim)),
          dtype=tf.float32,
          shape=[batch_size, context_dim],
          name='observation'))
  action_step = policy.action(time_step)
  actions, infos = self.evaluate([action_step.action, action_step.info])
  tf.nest.assert_same_structure(policy.info_spec, infos)
  self.assertAllEqual(actions.shape, [batch_size])
  self.assertAllInSet(actions, [2, 5, 8])
def __init__(self, mixture_weights, agents, name=None):
  """Initializes an instance of `StaticMixtureAgent`.

  Args:
    mixture_weights: (list of floats) The (possibly unnormalized) probability
      distribution based on which the agent chooses the sub-agents.
    agents: List of instances of TF-Agents bandit agents. These agents will
      be trained and used to select actions. The length of this list should
      match that of `mixture_weights`.
    name: The name of this instance of `StaticMixtureAgent`.
  """
  tf.Module.__init__(self, name=name)
  time_step_spec = agents[0].time_step_spec
  action_spec = agents[0].action_spec
  self._original_info_spec = agents[0].policy.info_spec
  error_message = None
  for agent in agents[1:]:
    if action_spec != agent.action_spec:
      error_message = 'Inconsistent action specs.'
    if time_step_spec != agent.time_step_spec:
      error_message = 'Inconsistent time step specs.'
    if self._original_info_spec != agent.policy.info_spec:
      error_message = 'Inconsistent info specs.'
  if len(mixture_weights) != len(agents):
    error_message = '`mixture_weights` and `agents` must have equal length.'
  if error_message is not None:
    raise ValueError(error_message)
  self._agents = agents
  self._num_agents = len(agents)
  self._mixture_weights = mixture_weights
  policies = [agent.collect_policy for agent in agents]
  policy = mixture_policy.MixturePolicy(mixture_weights, policies)
  super(StaticMixtureAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      policy,
      train_sequence_length=None)
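# A small, self-contained sketch of how `mixture_weights` translate into
# sub-agent choices: `StaticMixtureAgent` hands the weights to
# `mixture_policy.MixturePolicy`, which samples sub-policies in proportion to
# the weights (as the tests above check with `assertAllInSet`). The weight
# values below are illustrative assumptions only.
import tensorflow_probability as tfp

tfd = tfp.distributions

mixture_weights = [0.0, 0.2, 0.3, 0.5]
dist = tfd.Categorical(probs=mixture_weights)
# Indices of the sub-agents whose policies would act in ten draws; index 0 is
# never drawn because its weight is zero.
sub_agent_ids = dist.sample(10)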