Example #1
  def testTrain(self):
    # Emits trajectories shaped (batch=1, time=6, ...)
    traj, time_step_spec, action_spec = (
        driver_test_utils.make_random_trajectory())
    # Convert to shapes (batch=6, 1, ...) so this works with a non-RNN model.
    traj = nest.map_structure(tf.contrib.rnn.transpose_batch_time, traj)
    cloning_net = q_network.QNetwork(time_step_spec.observation,
                                     action_spec)
    agent = behavioral_cloning_agent.BehavioralCloningAgent(
        time_step_spec,
        action_spec,
        cloning_network=cloning_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.01))
    # Remove policy_info, as BehavioralCloningAgent expects none.
    traj = traj.replace(policy_info=())
    train_and_loss = agent.train(traj)
    replay = trajectory_replay.TrajectoryReplay(agent.policy)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    initial_actions = self.evaluate(replay.run(traj)[0])
    for _ in range(TRAIN_ITERATIONS):
      self.evaluate(train_and_loss)
    post_training_actions = self.evaluate(replay.run(traj)[0])
    # We don't necessarily converge to the same actions as in the trajectory
    # after 10 steps of an untuned optimizer, but the policy does change.
    self.assertFalse(np.all(initial_actions == post_training_actions))
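For reference, the batch/time conversion in testTrain above simply swaps the leading batch and time axes of every tensor in the trajectory. A minimal standalone sketch of the same operation with plain TF ops (the dummy tensor and its shape are illustrative only, not part of the test):

import tensorflow as tf

dummy = tf.zeros([1, 6, 3])                    # (batch=1, time=6, feature=3)
swapped = tf.transpose(dummy, perm=[1, 0, 2])  # -> (batch=6, time=1, feature=3)
assert swapped.shape.as_list() == [6, 1, 3]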
  def testTrainWithRNN(self):
    # Emits trajectories shaped (batch=1, time=6, ...)
    traj, time_step_spec, action_spec = (
        driver_test_utils.make_random_trajectory())
    cloning_net = q_rnn_network.QRnnNetwork(
        time_step_spec.observation, action_spec)
    agent = behavioral_cloning_agent.BehavioralCloningAgent(
        time_step_spec,
        action_spec,
        cloning_network=cloning_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.01),
        num_outer_dims=2)
    # Disable clipping to make sure we can see the difference in behavior.
    agent.policy._clip = False
    # Remove policy_info, as BehavioralCloningAgent expects none.
    traj = traj.replace(policy_info=())
    # TODO(b/123883319)
    if tf.executing_eagerly():
      train_and_loss = lambda: agent.train(traj)
    else:
      train_and_loss = agent.train(traj)
    replay = trajectory_replay.TrajectoryReplay(agent.policy)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    initial_actions = self.evaluate(replay.run(traj)[0])

    for _ in range(TRAIN_ITERATIONS):
      self.evaluate(train_and_loss)
    post_training_actions = self.evaluate(replay.run(traj)[0])
    # We don't necessarily converge to the same actions as in trajectory after
    # 10 steps of an untuned optimizer, but the policy does change.
    self.assertFalse(np.all(initial_actions == post_training_actions))
  def testReplayBufferObserversWithInitialState(self):
    traj, time_step_spec, action_spec = make_random_trajectory()
    policy = driver_test_utils.TFPolicyMock(time_step_spec, action_spec)
    policy_state = policy.get_initial_state(1)
    replay = trajectory_replay.TrajectoryReplay(policy)
    output_actions, output_policy_info, _ = replay.run(
        traj, policy_state=policy_state)
    new_traj = traj._replace(action=output_actions,
                             policy_info=output_policy_info)
    repeat_output_actions, repeat_output_policy_info, _ = replay.run(
        new_traj, policy_state=policy_state)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    (output_actions, output_policy_info, traj, repeat_output_actions,
     repeat_output_policy_info) = self.evaluate(
         (output_actions, output_policy_info, traj, repeat_output_actions,
          repeat_output_policy_info))

    # Ensure output actions & policy info don't match the original trajectory.
    self._compare_to_original(output_actions, output_policy_info, traj)

    # Ensure a repeated run with the same deterministic policy recreates the
    # same actions & policy info.
    nest.map_structure(self.assertAllEqual, output_actions,
                       repeat_output_actions)
    nest.map_structure(self.assertAllEqual, output_policy_info,
                       repeat_output_policy_info)
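Taken together, the pattern these tests exercise is: wrap a policy in trajectory_replay.TrajectoryReplay, then call run() on a recorded trajectory to obtain the actions, policy info, and final policy state that policy would produce for the trajectory's observations. A minimal sketch, reusing the same TF-Agents test utilities as the examples above (driver_test_utils.make_random_trajectory and TFPolicyMock); in real use you would substitute your own policy and trajectory:

traj, time_step_spec, action_spec = driver_test_utils.make_random_trajectory()
policy = driver_test_utils.TFPolicyMock(time_step_spec, action_spec)
replay = trajectory_replay.TrajectoryReplay(policy)
# run() replays the trajectory's observations through the policy and returns
# the actions, policy info, and final policy state it produced.
actions, policy_info, final_policy_state = replay.run(
    traj, policy_state=policy.get_initial_state(1))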