Example #1
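# Preamble sketch for this example (not shown in the original snippet). The
# module aliases match how the code below refers to them; the flag definitions
# and hyperparameter values are illustrative placeholders, not the original
# example's settings.
from absl import app
from absl import flags

import tensorflow as tf
import tensorflow_probability as tfp

from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.bandits.agents import linear_thompson_sampling_agent as lin_ts_agent
from tf_agents.bandits.environments import drifting_linear_environment as dle
from tf_agents.bandits.environments import non_stationary_stochastic_environment as nse
from tf_agents.bandits.metrics import tf_metrics as tf_bandit_metrics
# `trainer` is the training-loop helper shipped with the TF-Agents bandits
# examples; its location has changed between releases, so treat this import
# path as an assumption.
from tf_agents.bandits.agents.examples.v2 import trainer

tfd = tfp.distributions

flags.DEFINE_string('root_dir', '/tmp/drifting_linear', 'Output directory.')
flags.DEFINE_enum('agent', 'LinUCB', ['LinUCB', 'LinTS'], 'Bandit agent to train.')
FLAGS = flags.FLAGS

# Illustrative hyperparameters; tune these for a real experiment.
BATCH_SIZE = 8
CONTEXT_DIM = 15
NUM_ACTIONS = 5
REWARD_NOISE_VARIANCE = 0.01
DRIFT_MEAN = 0.0
DRIFT_VARIANCE = 0.01
AGENT_ALPHA = 10.0
TRAINING_LOOPS = 200
STEPS_PER_LOOP = 2
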
def main(unused_argv):
    tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

    with tf.device('/CPU:0'):  # due to b/128333994
        # Build a non-stationary linear bandit environment whose
        # observation-to-reward mapping drifts over time.
        observation_shape = [CONTEXT_DIM]
        overall_shape = [BATCH_SIZE] + observation_shape
        observation_distribution = tfd.Normal(loc=tf.zeros(overall_shape),
                                              scale=tf.ones(overall_shape))
        action_shape = [NUM_ACTIONS]
        observation_to_reward_shape = observation_shape + action_shape
        observation_to_reward_distribution = tfd.Normal(
            loc=tf.zeros(observation_to_reward_shape),
            scale=tf.ones(observation_to_reward_shape))
        drift_distribution = tfd.Normal(loc=DRIFT_MEAN, scale=DRIFT_VARIANCE)
        additive_reward_distribution = tfd.Normal(
            loc=tf.zeros(action_shape),
            scale=(REWARD_NOISE_VARIANCE * tf.ones(action_shape)))
        environment_dynamics = dle.DriftingLinearDynamics(
            observation_distribution, observation_to_reward_distribution,
            drift_distribution, additive_reward_distribution)
        environment = nse.NonStationaryStochasticEnvironment(
            environment_dynamics)

        if FLAGS.agent == 'LinUCB':
            agent = lin_ucb_agent.LinearUCBAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                alpha=AGENT_ALPHA,
                gamma=0.95,
                emit_log_probability=False,
                dtype=tf.float32)
        elif FLAGS.agent == 'LinTS':
            agent = lin_ts_agent.LinearThompsonSamplingAgent(
                time_step_spec=environment.time_step_spec(),
                action_spec=environment.action_spec(),
                alpha=AGENT_ALPHA,
                gamma=0.95,
                dtype=tf.float32)

        # Metrics comparing the agent's choices against the optimal reward and
        # action derived from the (known) environment dynamics.
        regret_metric = tf_bandit_metrics.RegretMetric(
            environment.environment_dynamics.compute_optimal_reward)
        suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
            environment.environment_dynamics.compute_optimal_action)

        trainer.train(
            root_dir=FLAGS.root_dir,
            agent=agent,
            environment=environment,
            training_loops=TRAINING_LOOPS,
            steps_per_loop=STEPS_PER_LOOP,
            additional_metrics=[regret_metric, suboptimal_arms_metric])
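
# A minimal, assumed absl entry point for the example above; invoke it as e.g.
# `python <your_script>.py --root_dir=/tmp/drifting_linear --agent=LinTS`
# (the script name and flag values are only illustrative).
if __name__ == '__main__':
    app.run(main)
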
Example #2
    def testObservationAndRewardsVary(self):
        """Ensure that observations and rewards change in consecutive calls."""
        dynamics = DummyDynamics()
        env = nsse.NonStationaryStochasticEnvironment(dynamics)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        env_time = env._env_time
        observation_samples = []
        reward_samples = []

        if tf.executing_eagerly():
            for t in range(0, 10):
                ts = env.reset()
                observation = ts.observation
                reward = env.step(tf.zeros([2])).reward

                [observation_sample, reward_sample, env_time_sample
                 ] = self.evaluate([observation, reward, env_time])
                observation_samples.append(observation_sample)
                reward_samples.append(reward_sample)
                self.assertEqual(env_time_sample,
                                 (t + 1) * dynamics.batch_size)

        else:
            ts = env.reset()
            observation = ts.observation
            reward = env.step(tf.zeros([2])).reward

            for t in range(0, 10):
                # The order of the evaluations below matters: first the observation
                # batch, then the reward, and finally the env_time tensor. Joining
                # them in a single `evaluate` call does not guarantee this order.

                observation_sample = self.evaluate(observation)
                reward_sample = self.evaluate(reward)
                env_time_sample = self.evaluate(env_time)
                observation_samples.append(observation_sample)
                reward_samples.append(reward_sample)
                self.assertEqual(env_time_sample,
                                 (t + 1) * dynamics.batch_size)

        for t in range(0, 10):
            t_b = t * dynamics.batch_size
            self.assertAllClose(observation_samples[t],
                                [[1.0 + t_b, 2.0 + t_b, 3.0 + t_b],
                                 [0.0 + t_b, 4.0 + t_b, 5.0 + t_b]])
            self.assertAllClose(reward_samples[t], [1, 0])
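
# Sketch of a dummy dynamics class consistent with the assertions in the test
# above. It assumes the `EnvironmentDynamics` interface from
# tf_agents.bandits.environments.non_stationary_stochastic_environment (aliased
# as `nsse` here): `batch_size`, `observation_spec`, and `action_spec`
# properties plus `observation(env_time)` and `reward(observation, env_time)`
# methods. Treat it as an illustration, not the test suite's actual helper.
import tensorflow as tf

from tf_agents.bandits.environments import non_stationary_stochastic_environment as nsse
from tf_agents.specs import tensor_spec


class DummyDynamics(nsse.EnvironmentDynamics):
    """Deterministic dynamics: observations grow with env_time, rewards are fixed."""

    @property
    def batch_size(self):
        return 2

    @property
    def observation_spec(self):
        return tensor_spec.TensorSpec(shape=[3], dtype=tf.float32)

    @property
    def action_spec(self):
        return tensor_spec.BoundedTensorSpec(
            shape=(), dtype=tf.float32, minimum=0.0, maximum=1.0)

    def observation(self, env_time):
        # Two 3-dimensional contexts, shifted by the current environment time.
        base = tf.constant([[1.0, 2.0, 3.0], [0.0, 4.0, 5.0]])
        return base + tf.cast(env_time, tf.float32)

    def reward(self, observation, env_time):
        # Constant per-row reward, matching `assertAllClose(reward_samples[t], [1, 0])`.
        return tf.constant([1.0, 0.0])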