def main(unused_argv):
  tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

  means = [0.1, 0.2, 0.3, 0.45, 0.5]
  env = bern_env.BernoulliPyEnvironment(means=means, batch_size=BATCH_SIZE)
  environment = tf_py_environment.TFPyEnvironment(env)

  def optimal_reward_fn(unused_observation):
    return np.max(means)

  def optimal_action_fn(unused_observation):
    return np.int32(np.argmax(means))

  if FLAGS.agent == 'BernTS':
    agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
        time_step_spec=environment.time_step_spec(),
        action_spec=environment.action_spec(),
        dtype=tf.float64,
        batch_size=BATCH_SIZE)
  else:
    raise ValueError('Only BernoulliTS is supported for now.')

  regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
  suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
      optimal_action_fn)

  trainer.train(
      root_dir=FLAGS.root_dir,
      agent=agent,
      environment=environment,
      training_loops=TRAINING_LOOPS,
      steps_per_loop=STEPS_PER_LOOP,
      additional_metrics=[regret_metric, suboptimal_arms_metric],
      save_policy=False)
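# main() above references FLAGS, BATCH_SIZE, TRAINING_LOOPS and STEPS_PER_LOOP
# without showing where they come from. A minimal sketch of the module
# preamble and entry point it assumes, modeled on the tf_agents bandits
# examples layout; the flag defaults and constant values are illustrative
# guesses, not the canonical ones, and in the actual module this preamble
# would precede main().
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf

from tf_agents.bandits.agents import bernoulli_thompson_sampling_agent as bern_ts_agent
from tf_agents.bandits.agents.examples.v2 import trainer
from tf_agents.bandits.environments import bernoulli_py_environment as bern_env
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.environments import tf_py_environment

flags.DEFINE_string('root_dir', '/tmp/bernoulli_ts',
                    'Root directory for checkpoints and metrics.')
flags.DEFINE_string('agent', 'BernTS',
                    'Which agent to train. Only "BernTS" is supported.')
FLAGS = flags.FLAGS

BATCH_SIZE = 8  # Illustrative value.
TRAINING_LOOPS = 200  # Illustrative value.
STEPS_PER_LOOP = 2  # Illustrative value.


if __name__ == '__main__':
  app.run(main)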
def testInitializeAgent(self):
  agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
      self._time_step_spec, self._action_spec)
  init_op = agent.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertIsNone(sess.run(init_op))
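# The tests in this file read self._time_step_spec and self._action_spec, but
# the fixture is not shown. A minimal setUp() sketch, assuming a 2-dimensional
# float observation and 3 arms (consistent with the 3-entry action masks and
# 2-element observations used below); the class name and spec shapes are
# assumptions, not taken from the original file:
def setUp(self):
  super(BernoulliThompsonSamplingAgentTest, self).setUp()
  self._obs_spec = tensor_spec.TensorSpec([2], tf.float32)
  self._time_step_spec = ts.time_step_spec(self._obs_spec)
  self._action_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int32, shape=(), minimum=0, maximum=2)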
def testPolicy(self):
  agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
      self._time_step_spec, self._action_spec, batch_size=2)
  # A batch of 2 observations, matching batch_size=2.
  observations = tf.constant([[1, 1], [1, 1]], dtype=tf.float32)
  time_steps = ts.restart(observations, batch_size=2)
  policy = agent.policy
  action_step = policy.action(time_steps)
  # Batch size 2.
  self.assertAllEqual([2], action_step.action.shape)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertEqual(action_step.action.dtype, tf.int32)
def testTrainAgentWithMask(self):
  time_step_spec = ts.time_step_spec((tensor_spec.TensorSpec([], tf.float32),
                                      tensor_spec.TensorSpec([3], tf.int32)))
  agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
      time_step_spec, self._action_spec, batch_size=2)
  observations = (np.array([1, 1], dtype=np.float32),
                  np.array([[0, 0, 1], [0, 0, 1]], dtype=np.int32))
  actions = np.array([0, 1], dtype=np.int32)
  rewards = np.array([1.0, 0.0], dtype=np.float32)
  initial_step, final_step = _get_initial_and_final_steps_with_action_mask(
      observations, rewards)
  action_step = _get_action_step(actions)
  experience = _get_experience(initial_step, action_step, final_step)
  loss, _ = agent.train(experience, None)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  # The loss is -sum(rewards).
  self.assertAllClose(self.evaluate(loss), -1.0)
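# _get_initial_and_final_steps_with_action_mask is not defined in this
# excerpt. A plausible sketch, assuming the observation is an
# (observation, mask) tuple, with zero reward and discount 1.0 on the
# initial step; the exact body is an assumption:
def _get_initial_and_final_steps_with_action_mask(observations, rewards):
  batch_size = observations[0].shape[0]
  initial_step = ts.TimeStep(
      tf.constant(ts.StepType.FIRST, dtype=tf.int32, shape=[batch_size]),
      tf.constant(0.0, dtype=tf.float32, shape=[batch_size]),
      tf.constant(1.0, dtype=tf.float32, shape=[batch_size]),
      (tf.constant(observations[0]), tf.constant(observations[1])))
  final_step = ts.TimeStep(
      tf.constant(ts.StepType.LAST, dtype=tf.int32, shape=[batch_size]),
      tf.constant(rewards, dtype=tf.float32),
      tf.constant(1.0, dtype=tf.float32, shape=[batch_size]),
      (tf.constant(observations[0]), tf.constant(observations[1])))
  return initial_step, final_step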
def testTrainAgent(self):
  # A batch of 2 observations, one per action/reward.
  observations = np.array([[1, 1], [1, 1]], dtype=np.float32)
  actions = np.array([0, 1], dtype=np.int32)
  rewards = np.array([0.0, 1.0], dtype=np.float32)
  initial_step, final_step = _get_initial_and_final_steps(
      observations, rewards)
  action_step = _get_action_step(actions)
  experience = _get_experience(initial_step, action_step, final_step)
  agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
      self._time_step_spec, self._action_spec, batch_size=2)
  init_op = agent.initialize()
  if not tf.executing_eagerly():
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertIsNone(sess.run(init_op))
  loss, _ = agent.train(experience, None)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  # The loss is -sum(rewards).
  self.assertAllClose(self.evaluate(loss), -1.0)
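# _get_initial_and_final_steps, _get_action_step, and _get_experience are
# also not defined in this excerpt. Plausible sketches modeled on the
# tf_agents bandit test utilities, assuming
# `from tf_agents.bandits.drivers import driver_utils` and
# `from tf_agents.trajectories import policy_step`; the exact bodies are
# assumptions:
def _get_initial_and_final_steps(observations, rewards):
  batch_size = observations.shape[0]
  initial_step = ts.TimeStep(
      tf.constant(ts.StepType.FIRST, dtype=tf.int32, shape=[batch_size]),
      tf.constant(0.0, dtype=tf.float32, shape=[batch_size]),
      tf.constant(1.0, dtype=tf.float32, shape=[batch_size]),
      tf.constant(observations, dtype=tf.float32))
  final_step = ts.TimeStep(
      tf.constant(ts.StepType.LAST, dtype=tf.int32, shape=[batch_size]),
      tf.constant(rewards, dtype=tf.float32),
      tf.constant(1.0, dtype=tf.float32, shape=[batch_size]),
      tf.constant(observations, dtype=tf.float32))
  return initial_step, final_step


def _get_action_step(action):
  return policy_step.PolicyStep(action=tf.convert_to_tensor(action))


def _get_experience(initial_step, action_step, final_step):
  single_experience = driver_utils.trajectory_for_bandit(
      initial_step, action_step, final_step)
  # The agent expects batched trajectories with a time dimension; add one.
  return tf.nest.map_structure(
      lambda x: tf.expand_dims(tf.convert_to_tensor(x), 1), single_experience)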
def testCreateAgent(self):
  agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
      self._time_step_spec, self._action_spec)
  self.assertIsNotNone(agent.policy)
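# Conventional test entry point (not shown in the excerpt) so the file can be
# run as a script.
if __name__ == '__main__':
  tf.test.main()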