def testTrainerExportsCheckpoints(self,
                                  num_actions,
                                  observation_shape,
                                  action_shape,
                                  batch_size,
                                  training_loops,
                                  steps_per_loop,
                                  learning_rate):
    """Exercises trainer code, checks that expected checkpoints are exported.

    Calls `trainer.train` three times with the same `root_dir` and agent. After
    the i-th call, the latest checkpoint in `root_dir` is expected to end in
    `-{i * training_loops}`, i.e. the checkpoint number accumulates across
    calls that share `root_dir`.
    """
    root_dir = tempfile.mkdtemp(dir=os.getenv('TEST_TMPDIR'))
    environment = get_bounded_reward_random_environment(
        observation_shape, action_shape, batch_size, num_actions)
    agent = exp3_agent.Exp3Agent(
        learning_rate=learning_rate,
        time_step_spec=environment.time_step_spec(),
        action_spec=environment.action_spec())
    for i in range(1, 4):
        trainer.train(
            root_dir=root_dir,
            agent=agent,
            environment=environment,
            training_loops=training_loops,
            steps_per_loop=steps_per_loop)
        latest_checkpoint = tf.train.latest_checkpoint(root_dir)
        # latest_checkpoint is None when nothing was written; fail clearly
        # here instead of with a TypeError inside assertRegex.
        self.assertIsNotNone(latest_checkpoint)
        # Anchor at the end: assertRegex uses re.search, so an unanchored
        # '.*-3' would also accept '...-30' and hide a wrong step count.
        expected_checkpoint_regex = '.*-{}$'.format(i * training_loops)
        self.assertRegex(latest_checkpoint, expected_checkpoint_regex)
 def testInitializeAgent(self, observation_shape, num_actions,
                         learning_rate):
     """Builds an Exp3Agent from explicit specs and runs its initializer.

     The observation spec is a float32 `TensorSpec` of `observation_shape`;
     the action spec is a scalar int32 `BoundedTensorSpec` spanning
     `[0, num_actions - 1]`.
     """
     observation_spec = tensor_spec.TensorSpec(observation_shape, tf.float32)
     step_spec = time_step.time_step_spec(observation_spec)
     max_action = num_actions - 1
     bounded_action_spec = tensor_spec.BoundedTensorSpec(
         dtype=tf.int32, shape=(), minimum=0, maximum=max_action)
     exp3 = exp3_agent.Exp3Agent(
         time_step_spec=step_spec,
         action_spec=bounded_action_spec,
         learning_rate=learning_rate)
     init_op = exp3.initialize()
     self.evaluate(init_op)
# Example #3
# 0
    def testTrainerTF1ExportsCheckpoints(self, num_actions, observation_shape,
                                         action_shape, batch_size,
                                         training_loops, steps_per_loop,
                                         learning_rate):
        """Tests TF1 trainer code, checks that expected checkpoints are exported.

        After one `trainer.train` call, the latest checkpoint under
        `root_dir/train` is expected to end in
        `ckpt-{training_loops * batch_size * steps_per_loop}`.
        """
        root_dir = tempfile.mkdtemp(dir=os.getenv('TEST_TMPDIR'))
        environment = get_bounded_reward_random_environment(
            observation_shape, action_shape, batch_size, num_actions)
        agent = exp3_agent.Exp3Agent(
            learning_rate=learning_rate,
            time_step_spec=environment.time_step_spec(),
            action_spec=environment.action_spec())

        trainer.train(root_dir, agent, environment, training_loops,
                      steps_per_loop)
        latest_checkpoint = tf.train.latest_checkpoint(
            os.path.join(root_dir, 'train'))
        # latest_checkpoint is None when nothing was written; fail clearly
        # here instead of with a TypeError inside assertRegex.
        self.assertIsNotNone(latest_checkpoint)
        # Anchor at the end: assertRegex uses re.search, so an unanchored
        # 'ckpt-1' would also accept 'ckpt-12' and hide a wrong step count.
        expected_checkpoint_regex = '.*ckpt-{}$'.format(
            training_loops * batch_size * steps_per_loop)
        self.assertRegex(latest_checkpoint, expected_checkpoint_regex)
    def testExp3Update(self, observation_shape, num_actions, action, log_prob,
                       reward, learning_rate):
        """Checks EXP3 weight updates for the given actions and rewards.

        The expected per-action change sums
        `exp3_agent.exp3_update_value(reward, log_prob)` into the slot named
        by the matching entry of `action`. It is compared with the change in
        `agent.weights` produced by a single `agent.train` call.
        """
        # Expected change: add each sample's update value to the slot of the
        # action it was taken for (repeated actions accumulate).
        update_values = self.evaluate(
            exp3_agent.exp3_update_value(reward, log_prob))
        expected_update = np.zeros(num_actions)
        for action_index, value in zip(action, update_values):
            expected_update[action_index] += value

        # Build a `Trajectory` from the given action, log prob and reward.
        observation_spec = tensor_spec.TensorSpec(observation_shape, tf.float32)
        step_spec = time_step.time_step_spec(observation_spec)
        bounded_action_spec = tensor_spec.BoundedTensorSpec(
            dtype=tf.int32, shape=(), minimum=0, maximum=num_actions - 1)
        first_step, last_step = _get_initial_and_final_steps(
            observation_shape, reward)
        experience = _get_experience(
            first_step, _get_action_step(action, log_prob), last_step)

        # Train once and measure how far the weights moved.
        exp3 = exp3_agent.Exp3Agent(
            time_step_spec=step_spec,
            action_spec=bounded_action_spec,
            learning_rate=learning_rate)
        self.evaluate(exp3.initialize())
        weights_before = self.evaluate(exp3.weights)
        self.evaluate(exp3.train(experience))
        weights_after = self.evaluate(exp3.weights)

        self.assertAllClose(expected_update, weights_after - weights_before)