Example #1
def main(unused_argv):
    tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

    means = [0.1, 0.2, 0.3, 0.45, 0.5]
    env = bern_env.BernoulliPyEnvironment(means=means, batch_size=BATCH_SIZE)
    environment = tf_py_environment.TFPyEnvironment(env)

    def optimal_reward_fn(unused_observation):
        return np.max(means)

    def optimal_action_fn(unused_observation):
        return np.int32(np.argmax(means))

    if FLAGS.agent == 'BernTS':
        agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
            time_step_spec=environment.time_step_spec(),
            action_spec=environment.action_spec(),
            dtype=tf.float64,
            batch_size=BATCH_SIZE)
    else:
        raise ValueError('Only BernoulliTS is supported for now.')

    regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
    suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
        optimal_action_fn)

    trainer.train(root_dir=FLAGS.root_dir,
                  agent=agent,
                  environment=environment,
                  training_loops=TRAINING_LOOPS,
                  steps_per_loop=STEPS_PER_LOOP,
                  additional_metrics=[regret_metric, suboptimal_arms_metric],
                  save_policy=False)
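
The snippet above references module-level names (FLAGS, BATCH_SIZE, TRAINING_LOOPS, STEPS_PER_LOOP) and import aliases that are defined elsewhere in the example script. Below is a minimal sketch of that surrounding boilerplate, assuming the standard TF-Agents source layout; the flag defaults and the BATCH_SIZE/TRAINING_LOOPS/STEPS_PER_LOOP values are illustrative placeholders, not the original script's settings.

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf

from tf_agents.bandits.agents import bernoulli_thompson_sampling_agent as bern_ts_agent
from tf_agents.bandits.agents.examples.v2 import trainer
from tf_agents.bandits.environments import bernoulli_py_environment as bern_env
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.environments import tf_py_environment

flags.DEFINE_string('root_dir', '/tmp/bernoulli_bandit', 'Directory for checkpoints and summaries.')
flags.DEFINE_enum('agent', 'BernTS', ['BernTS'], 'Which agent to use.')
FLAGS = flags.FLAGS

BATCH_SIZE = 8        # placeholder values; tune for your experiment
TRAINING_LOOPS = 200
STEPS_PER_LOOP = 2

if __name__ == '__main__':
  app.run(main)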
Example #2
 def testInitializeAgent(self):
   agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
       self._time_step_spec,
       self._action_spec)
   init_op = agent.initialize()
   if not tf.executing_eagerly():
     with self.cached_session() as sess:
       common.initialize_uninitialized_variables(sess)
       self.assertIsNone(sess.run(init_op))
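
Examples #2 through #6 read self._time_step_spec and self._action_spec from the test fixture, which the snippets do not show. The sketch below is an assumption consistent with how the tests use those specs (a scalar float observation, which the Bernoulli agent ignores, and three arms to match the 3-entry action mask in Example #4); it is not the verbatim setUp from the original test file, and the class name is hypothetical.

import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts


class BernoulliThompsonSamplingAgentTest(tf.test.TestCase):

  def setUp(self):
    super(BernoulliThompsonSamplingAgentTest, self).setUp()
    tf.compat.v1.enable_resource_variables()
    # Scalar float observation; the Bernoulli TS agent does not use the context.
    self._obs_spec = tensor_spec.TensorSpec([], tf.float32)
    self._time_step_spec = ts.time_step_spec(self._obs_spec)
    # Three arms (0, 1, 2), matching the 3-entry mask used in Example #4.
    self._action_spec = tensor_spec.BoundedTensorSpec(
        dtype=tf.int32, shape=(), minimum=0, maximum=2)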
Example #3
 def testPolicy(self):
   agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
       self._time_step_spec,
       self._action_spec,
       batch_size=2)
   observations = tf.constant([1, 1], dtype=tf.float32)  # One scalar observation per batch element.
   time_steps = ts.restart(observations, batch_size=2)
   policy = agent.policy
   action_step = policy.action(time_steps)
   # Batch size 2.
   self.assertAllEqual([2], action_step.action.shape)
   self.evaluate(tf.compat.v1.initialize_all_variables())
   self.assertEqual(action_step.action.shape.as_list(), [2])
   self.assertEqual(action_step.action.dtype, tf.int32)
Example #4
 def testTrainAgentWithMask(self):
   time_step_spec = ts.time_step_spec((tensor_spec.TensorSpec([], tf.float32),
                                       tensor_spec.TensorSpec([3], tf.int32)))
   agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
       time_step_spec,
       self._action_spec,
       batch_size=2)
   observations = (np.array([1, 1], dtype=np.float32),
                   np.array([[0, 0, 1], [0, 0, 1]], dtype=np.int32))
   actions = np.array([0, 1], dtype=np.int32)
   rewards = np.array([1.0, 0.0], dtype=np.float32)
   initial_step, final_step = _get_initial_and_final_steps_with_action_mask(
       observations, rewards)
   action_step = _get_action_step(actions)
   experience = _get_experience(initial_step, action_step, final_step)
   loss, _ = agent.train(experience, None)
   self.evaluate(tf.compat.v1.initialize_all_variables())
   self.assertAllClose(self.evaluate(loss), -1.0)
Example #5
  def testTrainAgent(self):
    observations = np.array([1, 1], dtype=np.float32)  # Batch of 2 scalar observations, matching the 2 rewards.
    actions = np.array([0, 1], dtype=np.int32)
    rewards = np.array([0.0, 1.0], dtype=np.float32)
    initial_step, final_step = _get_initial_and_final_steps(
        observations, rewards)
    action_step = _get_action_step(actions)
    experience = _get_experience(initial_step, action_step, final_step)

    agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
        self._time_step_spec,
        self._action_spec,
        batch_size=2)
    init_op = agent.initialize()
    if not tf.executing_eagerly():
      with self.cached_session() as sess:
        common.initialize_uninitialized_variables(sess)
        self.assertIsNone(sess.run(init_op))
    loss, _ = agent._train(experience, weights=None)
    self.evaluate(tf.compat.v1.initialize_all_variables())
    # The loss is -sum(rewards).
    self.assertAllClose(self.evaluate(loss), -1.0)
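
Examples #4 and #5 also rely on helper functions (_get_initial_and_final_steps, _get_action_step, _get_experience) that the snippets omit. A hedged sketch of what such helpers typically look like in TF-Agents bandit agent tests follows: one bandit step is packed as a batched, single-step Trajectory, with a time dimension of length 1 added at the end. The exact field construction is an assumption, not the original helpers; the masked variant in Example #4 would pack the (observation, mask) tuple analogously.

import tensorflow as tf
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory


def _get_initial_and_final_steps(observations, rewards):
  batch_size = rewards.shape[0]
  initial_step = ts.TimeStep(
      step_type=tf.constant(ts.StepType.FIRST, dtype=tf.int32, shape=[batch_size]),
      reward=tf.zeros([batch_size], dtype=tf.float32),
      discount=tf.ones([batch_size], dtype=tf.float32),
      observation=tf.convert_to_tensor(observations, dtype=tf.float32))
  final_step = ts.TimeStep(
      step_type=tf.constant(ts.StepType.LAST, dtype=tf.int32, shape=[batch_size]),
      reward=tf.convert_to_tensor(rewards, dtype=tf.float32),
      discount=tf.zeros([batch_size], dtype=tf.float32),
      observation=tf.convert_to_tensor(observations, dtype=tf.float32))
  return initial_step, final_step


def _get_action_step(action):
  return policy_step.PolicyStep(
      action=tf.convert_to_tensor(action), state=(), info=())


def _get_experience(initial_step, action_step, final_step):
  single_experience = trajectory.from_transition(
      initial_step, action_step, final_step)
  # Add a time dimension of length 1: agent.train expects [batch, time, ...].
  return tf.nest.map_structure(
      lambda x: tf.expand_dims(tf.convert_to_tensor(x), 1),
      single_experience)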
Example #6
 def testCreateAgent(self):
   agent = bern_ts_agent.BernoulliThompsonSamplingAgent(
       self._time_step_spec,
       self._action_spec)
   self.assertIsNotNone(agent.policy)