Example No. 1
    def testShuffle(self):
        """Test that dataset is being shuffled when asked."""
        # Reward of 1 is given if action == (context % 3)
        context = tf.reshape(tf.range(128), shape=[128, 1])
        labels = tf.math.mod(context, 3)
        batch_size = 32
        dataset = (tf.data.Dataset.from_tensor_slices(
            (context, labels)).repeat().shuffle(4 * batch_size))
        reward_distribution = deterministic_reward_distribution(tf.eye(3))

        # Note - shuffle should happen *first* in the call chain, so this
        # test will fail if shuffle is called e.g. after batch or prefetch.
        dataset.shuffle = mock.Mock(spec=dataset.shuffle,
                                    side_effect=dataset.shuffle)
        ce.ClassificationBanditEnvironment(dataset, reward_distribution,
                                           batch_size)
        dataset.shuffle.assert_not_called()
        ce.ClassificationBanditEnvironment(dataset,
                                           reward_distribution,
                                           batch_size,
                                           shuffle_buffer_size=3,
                                           seed=7)
        dataset.shuffle.assert_called_with(buffer_size=3,
                                           reshuffle_each_iteration=True,
                                           seed=7)
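
The tests above call a deterministic_reward_distribution helper that is not part of the snippet. Judging from how the covertype training script (Example No. 3 below) builds its reward distribution inline with tfd.Independent(tfd.Deterministic(tf.eye(7)), reinterpreted_batch_ndims=2), a minimal sketch of that helper could look like the following; the exact definition used in the test file is an assumption.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def deterministic_reward_distribution(reward_matrix):
    """Returns a distribution that always yields `reward_matrix`."""
    # Wrapping in Independent treats both matrix dimensions as event
    # dimensions, so a sample is the full reward matrix, deterministically.
    return tfd.Independent(
        tfd.Deterministic(loc=reward_matrix), reinterpreted_batch_ndims=2)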
Example No. 2
    def testReturnsCorrectRewards(self):
        """Test that rewards are being returned correctly for a simple case."""
        # Reward of 1 is given if action == (context % 3)
        context = tf.reshape(tf.range(128), shape=[128, 1])
        labels = tf.math.mod(context, 3)
        batch_size = 32
        dataset = (tf.data.Dataset.from_tensor_slices(
            (context, labels)).repeat().shuffle(4 * batch_size))
        reward_distribution = deterministic_reward_distribution(tf.eye(3))
        env = ce.ClassificationBanditEnvironment(dataset, reward_distribution,
                                                 batch_size)
        self.evaluate(tf.compat.v1.global_variables_initializer())
        for _ in range(10):
            # Take the 'correct' action
            observation = env.reset().observation
            action = tf.math.mod(observation, 3)
            reward = env.step(action).reward
            np.testing.assert_almost_equal(self.evaluate(reward),
                                           self.evaluate(tf.ones_like(reward)))

        for _ in range(10):
            # Take the 'incorrect' action
            observation = env.reset().observation
            action = tf.math.mod(observation + 1, 3)
            reward = env.step(action).reward
            np.testing.assert_almost_equal(
                self.evaluate(reward), self.evaluate(tf.zeros_like(reward)))
Example No. 3
def main(unused_argv):
    tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

    with tf.device('/CPU:0'):  # due to b/128333994

        covertype_dataset = dataset_utilities.convert_covertype_dataset(
            FLAGS.covertype_csv)
        covertype_reward_distribution = tfd.Independent(
            tfd.Deterministic(tf.eye(7)), reinterpreted_batch_ndims=2)
        environment = ce.ClassificationBanditEnvironment(
            covertype_dataset, covertype_reward_distribution, BATCH_SIZE)

        optimal_reward_fn = functools.partial(
            env_util.compute_optimal_reward_with_classification_environment,
            environment=environment)

        optimal_action_fn = functools.partial(
            env_util.compute_optimal_action_with_classification_environment,
            environment=environment)

        if FLAGS.agent == 'LinUCB':
            agent = lin_ucb_agent.LinearUCBAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                alpha=AGENT_ALPHA,
                emit_log_probability=False,
                dtype=tf.float32)
        elif FLAGS.agent == 'LinTS':
            agent = lin_ts_agent.LinearThompsonSamplingAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                alpha=AGENT_ALPHA,
                dtype=tf.float32)
        elif FLAGS.agent == 'epsGreedy':
            network = q_network.QNetwork(
                input_tensor_spec=environment.time_step_spec().observation,
                action_spec=environment.action_spec(),
                fc_layer_params=LAYERS)
            agent = eps_greedy_agent.NeuralEpsilonGreedyAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                reward_network=network,
                optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
                epsilon=EPSILON)

        regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
        suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
            optimal_action_fn)

        trainer.train(
            root_dir=FLAGS.root_dir,
            agent=agent,
            environment=environment,
            training_loops=TRAINING_LOOPS,
            steps_per_loop=STEPS_PER_LOOP,
            additional_metrics=[regret_metric, suboptimal_arms_metric])
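
Both training scripts reference module-level flags and hyperparameter constants (FLAGS, BATCH_SIZE, TRAINING_LOOPS, STEPS_PER_LOOP, AGENT_ALPHA, EPSILON, LAYERS, LR) that the snippets do not show. A hedged sketch of what the covertype script's header and entry point might look like; the flag help strings and concrete values are placeholders, not copied from the source.

from absl import app
from absl import flags

flags.DEFINE_string('root_dir', '/tmp/covertype_bandit/',
                    'Root directory for checkpoints and summaries.')
flags.DEFINE_string('covertype_csv', '',
                    'Path to the covertype dataset CSV file.')
flags.DEFINE_enum('agent', 'LinUCB', ['LinUCB', 'LinTS', 'epsGreedy'],
                  'Which bandit agent to train.')
FLAGS = flags.FLAGS

BATCH_SIZE = 8
TRAINING_LOOPS = 2000
STEPS_PER_LOOP = 2
AGENT_ALPHA = 10.0

# Used only by the 'epsGreedy' agent.
EPSILON = 0.05
LAYERS = (300, 200, 100)
LR = 0.002


if __name__ == '__main__':
    app.run(main)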
Example No. 4
    def testPrefetch(self, mock_dataset_iterator):
        """Test that dataset is being prefetched when asked."""
        mock_dataset_iterator.return_value = 'mock_iterator_result'
        # Reward of 1 is given if action == (context % 3)
        context = tf.reshape(tf.range(128), shape=[128, 1])
        labels = tf.math.mod(context, 3)
        batch_size = 32
        dataset = tf.data.Dataset.from_tensor_slices((context, labels))
        reward_distribution = deterministic_reward_distribution(tf.eye(3))

        # Operation order should be batch() then prefetch(); we have to jump
        # through a couple of hoops to get this sequence tested correctly.

        # Build mock_prefetch with the spec of dataset.prefetch; it returns the
        # batched dataset so downstream logic works correctly with batch dimensions.
        batched_dataset = dataset.batch(batch_size)
        mock_prefetch = mock.Mock(spec=dataset.prefetch,
                                  return_value=batched_dataset)
        # Replace dataset.batch with a mock that returns the batched dataset,
        # in order to make mocking out its prefetch call easier.
        dataset.batch = mock.Mock(spec=batched_dataset,
                                  return_value=batched_dataset)
        # Replace dataset.prefetch with mock_prefetch.
        batched_dataset.prefetch = mock_prefetch
        env = ce.ClassificationBanditEnvironment(dataset,
                                                 reward_distribution,
                                                 batch_size,
                                                 repeat_dataset=False)
        dataset.batch.assert_called_with(batch_size, drop_remainder=True)
        batched_dataset.prefetch.assert_not_called()
        mock_dataset_iterator.assert_called_with(batched_dataset)
        self.assertEqual(env._data_iterator, 'mock_iterator_result')
        env = ce.ClassificationBanditEnvironment(dataset,
                                                 reward_distribution,
                                                 batch_size,
                                                 repeat_dataset=False,
                                                 prefetch_size=3)
        dataset.batch.assert_called_with(batch_size, drop_remainder=True)
        batched_dataset.prefetch.assert_called_with(3)
        mock_dataset_iterator.assert_called_with(batched_dataset)
        self.assertEqual(env._data_iterator, 'mock_iterator_result')
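
testPrefetch receives a mock_dataset_iterator argument, so a @mock.patch decorator was evidently trimmed from the snippet. The actual patch target is not shown; purely to illustrate the shape of the pattern (the 'dataset_iterator' attribute name here is hypothetical), it would look something like this.

# Hypothetical decorator shape only -- 'dataset_iterator' stands in for
# whichever helper in the classification_environment module actually turns
# the batched dataset into an iterator. mock.patch.object injects the mock
# as the test method's last positional argument.
@mock.patch.object(ce, 'dataset_iterator', autospec=True)
def testPrefetch(self, mock_dataset_iterator):
    ...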
Example No. 5
def main(unused_argv):
    tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

    with tf.device('/CPU:0'):  # due to b/128333994

        mushroom_reward_distribution = (
            dataset_utilities.mushroom_reward_distribution(
                r_noeat=0.0,
                r_eat_safe=5.0,
                r_eat_poison_bad=-35.0,
                r_eat_poison_good=5.0,
                prob_poison_bad=0.5))
        mushroom_dataset = (
            dataset_utilities.convert_mushroom_csv_to_tf_dataset(
                FLAGS.mushroom_csv))
        environment = ce.ClassificationBanditEnvironment(
            mushroom_dataset, mushroom_reward_distribution, BATCH_SIZE)

        optimal_reward_fn = functools.partial(
            env_util.compute_optimal_reward_with_classification_environment,
            environment=environment)

        optimal_action_fn = functools.partial(
            env_util.compute_optimal_action_with_classification_environment,
            environment=environment)

        if FLAGS.agent == 'LinUCB':
            agent = lin_ucb_agent.LinearUCBAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                alpha=AGENT_ALPHA,
                gamma=0.95,
                emit_log_probability=False,
                dtype=tf.float32)
        elif FLAGS.agent == 'LinTS':
            agent = lin_ts_agent.LinearThompsonSamplingAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                alpha=AGENT_ALPHA,
                gamma=0.95,
                dtype=tf.float32)

        regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
        suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
            optimal_action_fn)

        trainer.train(
            root_dir=FLAGS.root_dir,
            agent=agent,
            environment=environment,
            training_loops=TRAINING_LOOPS,
            steps_per_loop=STEPS_PER_LOOP,
            additional_metrics=[regret_metric, suboptimal_arms_metric])
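
The two main() examples share the same short module aliases (ce, env_util, tfd, lin_ucb_agent, lin_ts_agent, eps_greedy_agent, q_network, trainer, tf_bandit_metrics, dataset_utilities). A sketch of the import block they assume; the module paths follow the usual TF-Agents bandits layout and should be treated as assumptions rather than a verbatim copy of the scripts.

import functools

import tensorflow as tf
import tensorflow_probability as tfp

from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.bandits.agents import linear_thompson_sampling_agent as lin_ts_agent
from tf_agents.bandits.agents import neural_epsilon_greedy_agent as eps_greedy_agent
from tf_agents.bandits.agents.examples.v2 import trainer
from tf_agents.bandits.environments import classification_environment as ce
from tf_agents.bandits.environments import dataset_utilities
from tf_agents.bandits.environments import environment_utilities as env_util
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
from tf_agents.networks import q_network

tfd = tfp.distributions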
Example No. 6
    def testObservationShapeAndValue(self, context, labels, batch_size):
        """Test that observations have correct shape and values from `context`."""
        dataset = (tf.data.Dataset.from_tensor_slices(
            (context, labels)).repeat().shuffle(4 * batch_size))
        # Reward of 1 is given when action == label
        reward_distribution = deterministic_reward_distribution(
            tf.eye(len(set(labels))))
        env = ce.ClassificationBanditEnvironment(dataset, reward_distribution,
                                                 batch_size)
        expected_observation_shape = [batch_size] + list(context.shape[1:])
        self.evaluate(tf.compat.v1.global_variables_initializer())
        for _ in range(100):
            observation = self.evaluate(env.reset().observation)
            np.testing.assert_array_equal(observation.shape,
                                          expected_observation_shape)
            for o in observation:
                self.assertIn(o, context)
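
testObservationShapeAndValue takes context, labels, and batch_size as parameters, so its parameterization decorator was trimmed from the snippet. A hedged sketch of what such a decorator could look like with absl's parameterized test support; the test-case values are illustrative, not copied from the test file.

from absl.testing import parameterized
import numpy as np


@parameterized.named_parameters(
    dict(testcase_name='_1d_context',
         context=np.arange(12, dtype=np.int32).reshape([12, 1]),
         labels=np.mod(np.arange(12), 3),
         batch_size=4))
def testObservationShapeAndValue(self, context, labels, batch_size):
    ...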
Example No. 7
  def testPreviousLabelIsSetCorrectly(self):
    """Test that the previous label is set correctly for a simple case."""
    # Reward of 1 is given if action == (context % 3)
    context = tf.reshape(tf.range(128), shape=[128, 1])
    labels = tf.math.mod(context, 3)
    batch_size = 4
    dataset = (
        tf.data.Dataset.from_tensor_slices(
            (context, labels)).repeat().shuffle(4 * batch_size))
    reward_distribution = deterministic_reward_distribution(tf.eye(3))
    env = ce.ClassificationBanditEnvironment(
        dataset, reward_distribution, batch_size)
    self.evaluate(tf.compat.v1.global_variables_initializer())

    time_step = env.reset()
    time_step_label = tf.squeeze(tf.math.mod(time_step.observation, 3))
    action = tf.math.mod(time_step.observation, 3)
    next_time_step = env.step(action)
    next_time_step_label = tf.squeeze(
        tf.math.mod(next_time_step.observation, 3))

    if tf.executing_eagerly():
      np.testing.assert_almost_equal(
          self.evaluate(time_step_label),
          self.evaluate(env._previous_label))
      np.testing.assert_almost_equal(
          self.evaluate(next_time_step_label),
          self.evaluate(env._current_label))
    else:
      with self.cached_session() as sess:
        time_step_label_value, next_time_step_label_value = (
            sess.run([time_step_label, next_time_step_label]))

        previous_label_value = self.evaluate(env._previous_label)
        np.testing.assert_almost_equal(
            time_step_label_value, previous_label_value)
        current_label_value = self.evaluate(env._current_label)
        np.testing.assert_almost_equal(
            next_time_step_label_value,
            current_label_value)
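
All of the test methods in these examples rely on self.evaluate and cached_session, so they live inside a TensorFlow test case class. A minimal sketch of the scaffold the snippets assume (the class name is a guess), shown mainly to make the examples runnable end to end.

class ClassificationBanditEnvironmentTest(tf.test.TestCase,
                                          parameterized.TestCase):
    # The test methods from the examples above go here.
    pass


if __name__ == '__main__':
    tf.test.main()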