def testRenamedSignatures(self):
        """Check PolicySaver input-signature names with a replaced observation spec.

        Builds a QPolicy over a time step spec whose observation has been
        swapped for a bounded float32 spec of shape (4,), wraps it in a
        PolicySaver and asserts the names found in the 'action' and
        'get_initial_state' input signatures. The test is skipped when
        TensorFlow is not executing eagerly.
        """
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        replaced_spec = self._time_step_spec._replace(
            observation=tensor_spec.BoundedTensorSpec(
                dtype=tf.float32, shape=(4, ), minimum=-10.0, maximum=10.0))

        q_net = q_network.QNetwork(
            input_tensor_spec=replaced_spec.observation,
            action_spec=self._action_spec)
        policy = q_policy.QPolicy(
            time_step_spec=replaced_spec,
            action_spec=self._action_spec,
            q_network=q_net)

        saver = policy_saver.PolicySaver(policy, batch_size=None)

        def input_names(signature_key):
            # Collect the spec names of one saved signature, in order.
            names = []
            for spec in saver._signatures[signature_key].input_signature:
                names.append(spec.name)
            return names

        self.assertAllEqual(
            ['0/step_type', '0/reward', '0/discount', '0/observation'],
            input_names('action'))
        self.assertAllEqual(['batch_size'], input_names('get_initial_state'))
 def testTrain(self):
     """Train a BehavioralCloningAgent briefly and check replayed actions change.

     A random trajectory is transposed so a non-RNN network can consume it,
     the agent is trained on it for TRAIN_ITERATIONS steps, and the actions
     replayed through the agent's policy before and after training must
     differ somewhere. Summary recording is disabled for the whole test.
     """
     with tf.compat.v2.summary.record_if(False):
         # The helper emits trajectories shaped (batch=1, time=6, ...).
         raw_traj, time_step_spec, action_spec = (
             driver_test_utils.make_random_trajectory())
         # Swap to shapes (batch=6, 1, ...) so a non-RNN model works.
         experience = tf.nest.map_structure(
             common.transpose_batch_time, raw_traj)

         bc_network = q_network.QNetwork(
             time_step_spec.observation, action_spec)
         adam = tf.compat.v1.train.AdamOptimizer(learning_rate=0.01)
         bc_agent = behavioral_cloning_agent.BehavioralCloningAgent(
             time_step_spec,
             action_spec,
             cloning_network=bc_network,
             optimizer=adam)
         # Clipping is switched off so the change in behavior is observable.
         bc_agent.policy._clip = False
         # BehavioralCloningAgent expects no policy_info, so drop it.
         experience = experience.replace(policy_info=())

         # TODO(b/123883319)
         if tf.executing_eagerly():
             def train_step():
                 return bc_agent.train(experience)
         else:
             train_step = bc_agent.train(experience)

         replay = trajectory_replay.TrajectoryReplay(bc_agent.policy)
         self.evaluate(tf.compat.v1.global_variables_initializer())
         actions_before = self.evaluate(replay.run(experience)[0])
         for _ in range(TRAIN_ITERATIONS):
             self.evaluate(train_step)
         actions_after = self.evaluate(replay.run(experience)[0])
         # A few steps of an untuned optimizer need not reproduce the
         # trajectory's actions, but the policy must have changed.
         self.assertFalse(np.all(actions_before == actions_after))
    def testUniqueSignatures(self):
        """Check the input-signature names PolicySaver assigns to a QPolicy.

        Builds a QPolicy directly on the fixture's time step and action
        specs, wraps it in a PolicySaver and asserts the names found in the
        'action' and 'get_initial_state' input signatures. The test is
        skipped when TensorFlow is not executing eagerly.
        """
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        q_net = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)
        policy = q_policy.QPolicy(
            time_step_spec=self._time_step_spec,
            action_spec=self._action_spec,
            q_network=q_net)

        saver = policy_saver.PolicySaver(policy, batch_size=None)

        def input_names(signature_key):
            # Collect the spec names of one saved signature, in order.
            names = []
            for spec in saver._signatures[signature_key].input_signature:
                names.append(spec.name)
            return names

        self.assertAllEqual(
            ['0/step_type', '0/reward', '0/discount', '0/observation'],
            input_names('action'))
        self.assertAllEqual(['batch_size'], input_names('get_initial_state'))
    def __init__(self, env):
        """Build a DQN agent for `env` plus its replay buffer and evaluator.

        Args:
            env: Environment exposing observation_spec(), action_spec() and
                time_step_spec(); it is stored as `self.env`.

        Relies on `fc_layer_params`, `learning_rate` and `n_eval_episodes`
        being defined outside this method.
        """
        # Agent initialization.
        self.env = env

        network = q_network.QNetwork(
            env.observation_spec(),
            env.action_spec(),
            fc_layer_params=fc_layer_params)
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate, beta1=0.8, epsilon=1)
        step_counter = tf.compat.v2.Variable(0)

        agent_kwargs = dict(
            q_network=network,
            optimizer=optimizer,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            train_step_counter=step_counter)
        self.agent = dqn_agent.DqnAgent(
            env.time_step_spec(), env.action_spec(), **agent_kwargs)
        self.agent.initialize()

        self._create_replay_buffer()

        # Evaluation runs against a separately created BlackJack environment.
        evaluation_env = BlackJackEnv.tf_env()
        self.evaluator = PolicyEvaluator(evaluation_env, n_eval_episodes)
Exemple #5
0
 def testTrain(self):
     """Train a BehavioralCloningAgent briefly and check replayed actions change.

     A random trajectory is transposed so a non-RNN network can consume it,
     a train op is built once and run TRAIN_ITERATIONS times, and the
     actions replayed through the agent's policy before and after training
     must differ somewhere.
     """
     # The helper emits trajectories shaped (batch=1, time=6, ...).
     raw_traj, time_step_spec, action_spec = (
         driver_test_utils.make_random_trajectory())
     # Swap to shapes (batch=6, 1, ...) so a non-RNN model works.
     experience = nest.map_structure(
         tf.contrib.rnn.transpose_batch_time, raw_traj)

     bc_network = q_network.QNetwork(
         time_step_spec.observation, action_spec)
     adam = tf.train.AdamOptimizer(learning_rate=0.01)
     bc_agent = behavioral_cloning_agent.BehavioralCloningAgent(
         time_step_spec,
         action_spec,
         cloning_network=bc_network,
         optimizer=adam)
     # BehavioralCloningAgent expects no policy_info, so drop it.
     experience = experience.replace(policy_info=())

     train_op = bc_agent.train(experience)
     replay = trajectory_replay.TrajectoryReplay(bc_agent.policy())
     self.evaluate(tf.global_variables_initializer())

     actions_before = self.evaluate(replay.run(experience)[0])
     for _ in range(TRAIN_ITERATIONS):
         self.evaluate(train_op)
     actions_after = self.evaluate(replay.run(experience)[0])
     # A few steps of an untuned optimizer need not reproduce the
     # trajectory's actions, but the policy must have changed.
     self.assertFalse(np.all(actions_before == actions_after))
 def testVariablesBuild(self):
   """Reading `variables` on an unbuilt QNetwork leaves it built and non-empty."""
   state_dims = 5
   observation_spec = tensor_spec.TensorSpec([state_dims], tf.float32)
   binary_action_spec = tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1)
   q_net = q_network.QNetwork(
       input_tensor_spec=observation_spec,
       action_spec=binary_action_spec)

   self.assertFalse(q_net.built)
   # The test expects this attribute access to flip `built` to True.
   net_variables = q_net.variables
   self.assertTrue(q_net.built)
   self.assertGreater(len(net_variables), 0)
 def testCorrectOutputShape(self):
   """Q-values for a batch of states have shape [batch_size, num_actions]."""
   batch_size, num_state_dims, num_actions = 3, 5, 2
   states = tf.random.uniform([batch_size, num_state_dims])

   observation_spec = tensor_spec.TensorSpec([num_state_dims], tf.float32)
   binary_action_spec = tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1)
   q_net = q_network.QNetwork(
       input_tensor_spec=observation_spec,
       action_spec=binary_action_spec)

   # The network returns a pair; only the Q-values are checked here.
   q_values, _ = q_net(states)
   expected_shape = [batch_size, num_actions]
   self.assertAllEqual(q_values.shape.as_list(), expected_shape)
 def testNetworkVariablesAreReused(self):
   """Two calls of one QNetwork on identical inputs yield matching Q-values."""
   batch_size, num_state_dims = 3, 5
   input_shape = [batch_size, num_state_dims]
   states = tf.ones(input_shape)
   next_states = tf.ones(input_shape)

   observation_spec = tensor_spec.TensorSpec([num_state_dims], tf.float32)
   binary_action_spec = tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1)
   q_net = q_network.QNetwork(
       input_tensor_spec=observation_spec,
       action_spec=binary_action_spec)

   current_q, _ = q_net(states)
   following_q, _ = q_net(next_states)
   self.evaluate(tf.compat.v1.global_variables_initializer())
   self.assertAllClose(current_q, following_q)
 def testChangeHiddenLayers(self):
   """A QNetwork with fc_layer_params=(40,) has 4 trainable variables."""
   batch_size, num_state_dims, num_actions = 3, 5, 2
   states = tf.random.uniform([batch_size, num_state_dims])

   observation_spec = tensor_spec.TensorSpec([num_state_dims], tf.float32)
   binary_action_spec = tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1)
   q_net = q_network.QNetwork(
       input_tensor_spec=observation_spec,
       action_spec=binary_action_spec,
       fc_layer_params=(40,))

   q_values, _ = q_net(states)
   expected_shape = [batch_size, num_actions]
   self.assertAllEqual(q_values.shape.as_list(), expected_shape)
   self.assertEqual(len(q_net.trainable_variables), 4)
Exemple #10
0
 def testBuild(self):
     """A default QNetwork outputs [batch, actions] Q-values and has 6 trainable weights."""
     batch_size, num_state_dims, num_actions = 3, 5, 2
     states = tf.random_uniform([batch_size, num_state_dims])

     observation_spec = tensor_spec.TensorSpec([num_state_dims], tf.float32)
     binary_action_spec = tensor_spec.BoundedTensorSpec(
         [1], tf.int32, 0, 1)
     q_net = q_network.QNetwork(
         input_tensor_spec=observation_spec,
         action_spec=binary_action_spec)

     q_values, _ = q_net(states)
     expected_shape = [batch_size, num_actions]
     self.assertAllEqual(q_values.shape.as_list(), expected_shape)
     self.assertEqual(len(q_net.trainable_weights), 6)
Exemple #11
0
 def testAddConvLayers(self):
     """A QNetwork with one conv layer outputs [batch, actions] and has 8 trainable variables."""
     batch_size, num_state_dims, num_actions = 3, 5, 2
     states = tf.random_uniform([batch_size, 5, 5, num_state_dims])

     image_spec = tensor_spec.TensorSpec(
         [5, 5, num_state_dims], tf.float32)
     binary_action_spec = tensor_spec.BoundedTensorSpec(
         [1], tf.int32, 0, 1)
     q_net = q_network.QNetwork(
         observation_spec=image_spec,
         action_spec=binary_action_spec,
         conv_layer_params=((16, 3, 2), ))

     q_values, _ = q_net(states)
     expected_shape = [batch_size, num_actions]
     self.assertAllEqual(q_values.shape.as_list(), expected_shape)
     self.assertEqual(len(q_net.trainable_variables), 8)
Exemple #12
0
  def testAgentFollowsActionSpec(self, agent_class):
    """Check that the agent's policy emits actions shaped like the action spec.

    Args:
      agent_class: Agent constructor under test; it is called with the
        fixture's time step spec, action spec, a QNetwork and optimizer=None.
    """
    agent = agent_class(
        self._time_step_spec,
        self._action_spec,
        q_network=q_network.QNetwork(self._observation_spec, self._action_spec),
        optimizer=None)
    # assertIsNotNone gives a clearer failure message than
    # assertTrue(x is not None).
    self.assertIsNotNone(agent.policy())
    policy = agent.policy()
    # Sample one batched (outer_dims=(1,)) time step matching the spec.
    observation = tensor_spec.sample_spec_nest(
        self._time_step_spec, seed=42, outer_dims=(1,))
    action_op = policy.action(observation).action
    # global_variables_initializer replaces the deprecated
    # initialize_all_variables alias.
    self.evaluate(tf.global_variables_initializer())

    action = self.evaluate(action_op)
    # Expected shape is the batch dimension of 1 followed by the spec's shape.
    self.assertEqual([1] + self._action_spec[0].shape.as_list(),
                     list(action[0].shape))
Exemple #13
0
    def testAgentFollowsActionSpecWithScalarAction(self, agent_class):
        """Check the policy's action shape when the action spec is a scalar.

        Args:
            agent_class: Agent constructor under test; it is called with the
                fixture's time step spec, a scalar int32 action spec bounded
                to [0, 1], a QNetwork and optimizer=None.
        """
        action_spec = [tensor_spec.BoundedTensorSpec((), tf.int32, 0, 1)]
        agent = agent_class(self._time_step_spec,
                            action_spec,
                            q_network=q_network.QNetwork(
                                self._observation_spec, action_spec),
                            optimizer=None)
        self.assertIsNotNone(agent.policy)
        policy = agent.policy
        # Sample one batched (outer_dims=(1,)) time step matching the spec.
        observation = tensor_spec.sample_spec_nest(self._time_step_spec,
                                                   seed=42,
                                                   outer_dims=(1, ))

        action_op = policy.action(observation).action
        # global_variables_initializer replaces the deprecated
        # initialize_all_variables alias.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        action = self.evaluate(action_op)
        # Expected shape is the batch dimension of 1 followed by the
        # (empty) scalar spec shape.
        self.assertEqual([1] + action_spec[0].shape.as_list(),
                         list(action[0].shape))
Exemple #14
0
def create_agent(train_env):
  """Build and initialize a DqnAgent for `train_env`.

  Args:
    train_env: Environment exposing observation_spec(), action_spec() and
      time_step_spec().

  Returns:
    The DqnAgent after its initialize() method has been called.

  Relies on `fc_layer_params` and `learning_rate` being defined outside this
  function.
  """
  network = q_network.QNetwork(
      train_env.observation_spec(),
      train_env.action_spec(),
      fc_layer_params=fc_layer_params)
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
  step_counter = tf.compat.v2.Variable(0)

  agent_kwargs = dict(
      q_network=network,
      optimizer=optimizer,
      td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
      train_step_counter=step_counter)
  agent = dqn_agent.DqnAgent(
      train_env.time_step_spec(), train_env.action_spec(), **agent_kwargs)
  agent.initialize()
  return agent
 def testAddPreprocessingLayers(self):
   """A QNetwork with per-input preprocessing layers yields [batch, actions] Q-values."""
   batch_size, num_actions = 3, 2

   # Two inputs: one of shape [batch, 1] and one of shape [batch].
   vector_input = tf.random.uniform([batch_size, 1])
   scalar_input = tf.random.uniform([batch_size])
   states = (vector_input, scalar_input)

   # One preprocessing layer per input; the second reshapes before its Dense.
   vector_layer = tf.keras.layers.Dense(4)
   scalar_layer = tf.keras.Sequential([
       tf.keras.layers.Reshape((1,)),
       tf.keras.layers.Dense(4)])
   preprocessing_layers = (vector_layer, scalar_layer)

   input_specs = (
       tensor_spec.TensorSpec([1], tf.float32),
       tensor_spec.TensorSpec([], tf.float32))
   combiner = tf.keras.layers.Add()
   bounded_action_spec = tensor_spec.BoundedTensorSpec(
       [1], tf.int32, 0, num_actions - 1)
   q_net = q_network.QNetwork(
       input_tensor_spec=input_specs,
       preprocessing_layers=preprocessing_layers,
       preprocessing_combiner=combiner,
       action_spec=bounded_action_spec)

   q_values, _ = q_net(states)
   self.assertAllEqual(q_values.shape.as_list(), [batch_size, num_actions])
   # At least 2 variables each for the preprocessing layers.
   self.assertGreater(len(q_net.trainable_variables), 4)
Exemple #16
0
# Hyperparameters for training and evaluation. The "# @param" markers are left
# untouched (presumably notebook form annotations — TODO confirm).
batch_size = 128  # @param
learning_rate = 1e-5  # @param
log_interval = 200  # @param

num_eval_episodes = 2  # @param
eval_interval = 1000  # @param

# Separate Python environments for training and evaluation, each wrapped in a
# TimeLimit with duration=100.
train_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)
eval_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)

# TensorFlow wrappers around the Python environments.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# Q-network built from the training environment's specs.
# NOTE(review): fc_layer_params is not defined in this snippet; it is expected
# to exist already at this point — confirm against the full file.
q_net = q_network.QNetwork(train_env.observation_spec(),
                           train_env.action_spec(),
                           fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

# Counter variable handed to the agent as its train step counter.
train_step_counter = tf.compat.v2.Variable(0)

# DQN agent using the element-wise squared loss on TD errors.
tf_agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
    train_step_counter=train_step_counter)

tf_agent.initialize()
Exemple #17
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        agent_class=dqn_agent.DqnAgent,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN, run inside a `tf.Session`.

    Builds the environment, Q-network, agent, replay buffer, collect drivers
    and train op as graph ops, then runs a session loop that alternates
    collection and training and periodically logs, checkpoints and evaluates.

    Args:
        root_dir: Base output directory ('~' is expanded); 'train' and 'eval'
            subdirectories are used for summaries and checkpoints.
        env_name: Name passed to `suite_gym.load`.
        num_iterations: Number of collect/train iterations of the main loop.
        fc_layer_params: Forwarded to `q_network.QNetwork`.
        initial_collect_steps: `num_steps` of the initial random-policy driver.
        collect_steps_per_iteration: `num_steps` of the per-iteration driver.
        epsilon_greedy: Forwarded to the agent constructor.
        replay_buffer_capacity: `max_length` of the replay buffer.
        target_update_tau: Forwarded to the agent constructor.
        target_update_period: Forwarded to the agent constructor.
        train_steps_per_iteration: Train-op runs per loop iteration.
        batch_size: `sample_batch_size` of the replay dataset.
        learning_rate: Learning rate of the Adam optimizer.
        gamma: Forwarded to the agent constructor.
        reward_scale_factor: Forwarded to the agent constructor.
        gradient_clipping: Forwarded to the agent constructor.
        num_eval_episodes: Buffer size of the eval metrics and number of
            episodes per evaluation.
        eval_interval: Evaluate when the global step is a multiple of this.
        train_checkpoint_interval: Save the train checkpoint when the global
            step is a multiple of this.
        policy_checkpoint_interval: Same, for the policy checkpoint.
        rb_checkpoint_interval: Same, for the replay buffer checkpoint.
        log_interval: Log loss and throughput when the global step is a
            multiple of this.
        summary_interval: Passed to `record_summaries_every_n_global_steps`.
        summaries_flush_secs: Flush period of the summary writers, in seconds
            (converted to milliseconds).
        agent_class: Constructor used to build the agent.
        debug_summaries: Forwarded to the agent constructor.
        summarize_grads_and_vars: Forwarded to the agent constructor.
        eval_metrics_callback: Passed as `callback` to
            `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The train writer becomes the default; the eval writer is only used
    # inside explicit `as_default()` scopes.
    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):

        # Training uses a TF-wrapped environment; evaluation uses a separate
        # Python environment instance.
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
            # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Python-side wrapper of the agent's policy, used for evaluation.
        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        # Initial collection uses a random policy and only feeds the buffer.
        replay_observer = [replay_buffer.add_batch]
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps).run()

        # Per-iteration collection uses the agent's collect policy and also
        # updates the train metrics.
        collect_policy = tf_agent.collect_policy()
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)

        iterator = dataset.make_initializable_iterator()
        trajectories, _ = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories,
                                  train_step_counter=global_step)

        # Three checkpointers: full training state, policy only, and the
        # replay buffer (keeping a single replay buffer checkpoint).
        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        # The first two train metrics (episodes, environment steps) are passed
        # as step metrics for every train metric's summaries.
        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        summary_op = tf.contrib.summary.all_summary_ops()

        # Eval metric summaries are created under the eval writer, after
        # summary_op was collected above.
        with eval_summary_writer.as_default(), \
             tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            # Evaluate once before any training step of this run.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])

            # Bookkeeping for the steps/sec summary, fed via a placeholder.
            timed_at_step = sess.run(global_step)
            collect_time = 0
            train_time = 0
            steps_per_second_ph = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                collect_time += time.time() - start_time
                start_time = time.time()
                # NOTE(review): loss_info_value is only bound inside this
                # loop; with train_steps_per_iteration == 0 the logging
                # branch below would hit an unbound name.
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                train_time += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    tf.logging.info('%.3f steps/sec' % steps_per_sec)
                    tf.logging.info(
                        'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    # Reset the timing window.
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
Exemple #18
0
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN that steps the drivers directly.

  Builds the environments, replay buffer, Q-network and DqnAgent, fills the
  replay buffer with a random policy, then loops: collect, train, and
  periodically log and evaluate. The global step is read with `.numpy()`.

  Args:
    root_dir: Base output directory ('~' is expanded); 'train' and 'eval'
      subdirectories are used for summaries.
    env_name: Name passed to `suite_gym.load`.
    num_iterations: Number of collect/train iterations of the main loop.
    fc_layer_params: Forwarded to `q_network.QNetwork`.
    initial_collect_steps: `num_steps` of the initial random-policy driver.
    collect_steps_per_iteration: `num_steps` of the per-iteration driver.
    epsilon_greedy: Forwarded to `DqnAgent`.
    replay_buffer_capacity: `max_length` of the replay buffer.
    target_update_tau: Forwarded to `DqnAgent`.
    target_update_period: Forwarded to `DqnAgent`.
    train_steps_per_iteration: Calls to `tf_agent.train` per loop iteration.
    batch_size: `sample_batch_size` of the replay dataset.
    learning_rate: Learning rate of the Adam optimizer.
    gamma: Forwarded to `DqnAgent`.
    reward_scale_factor: Forwarded to `DqnAgent`.
    gradient_clipping: Forwarded to `DqnAgent`.
    num_eval_episodes: Buffer size of the eval metrics and number of episodes
      per evaluation.
    eval_interval: Evaluate when the global step is a multiple of this.
    log_interval: Log loss and throughput when the global step is a multiple
      of this.
    summary_interval: Passed to `record_summaries_every_n_global_steps`.
    summaries_flush_secs: Flush period of the summary writers, in seconds
      (converted to milliseconds).
    debug_summaries: Forwarded to `DqnAgent`.
    summarize_grads_and_vars: Forwarded to `DqnAgent`.
    eval_metrics_callback: If not None, called with the computed eval metrics
      and the current global step value after each evaluation.

  Returns:
    The value returned by the last `tf_agent.train` call.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  # The train writer becomes the default; the eval writer is handed to
  # metric_utils.eager_compute explicitly.
  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  # TODO(kbanoop): Figure out if it is possible to avoid the with block.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):

    # Separate TF-wrapped environments for training and evaluation.
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    # The replay buffer's data spec is derived from the environment specs
    # rather than from the agent.
    trajectory_spec = trajectory.from_transition(
        time_step=tf_env.time_step_spec(),
        action_step=policy_step.PolicyStep(action=tf_env.action_spec()),
        next_time_step=tf_env.time_step_spec())
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=trajectory_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)

    tf_agent.initialize()
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy()
    collect_policy = tf_agent.collect_policy()

    # Per-iteration driver: feeds the replay buffer and the train metrics.
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    global_step = tf.train.get_or_create_global_step()

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    # Collect initial replay data.
    tf.logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.' % initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps).run()

    # Evaluate once before any training step.
    metrics = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(metrics, global_step.numpy())

    # Driver state carried from one collect_driver.run call to the next.
    time_step = None
    policy_state = ()

    # Bookkeeping for the steps/sec log line.
    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      # NOTE(review): train_loss is only bound inside this loop; with
      # num_iterations == 0 or train_steps_per_iteration == 0 the logging
      # branch and the final return would hit an unbound name.
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience, train_step_counter=global_step)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        tf.logging.info('step = %d, loss = %f', global_step.numpy(), train_loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        tf.logging.info('%.3f steps/sec' % steps_per_sec)
        tf.contrib.summary.scalar(name='global_steps/sec', tensor=steps_per_sec)
        # Reset the timing window.
        timed_at_step = global_step.numpy()
        time_acc = 0

      # Train metric summaries are emitted on every iteration, with the first
      # two metrics (episodes, environment steps) as step metrics.
      for train_metric in train_metrics:
        train_metric.tf_summaries(step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(metrics, global_step.numpy())
    return train_loss
    def testSaveAction(self, seeded, has_state):
        """Save a QPolicy with PolicySaver, reload it and compare action outputs.

        The reloaded model is exercised both through its flat 'action'
        signature and through its structured `action` function, and each
        result is compared against the in-memory policy's output.

        Args:
            seeded: If true, outputs are compared by value (assertAllClose);
                otherwise only shapes and dtypes are compared.
            has_state: If true, the policy is built on a QRnnNetwork;
                otherwise on a QNetwork.
        """
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        if has_state:
            network = q_rnn_network.QRnnNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)
        else:
            network = q_network.QNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        action_seed = 98723

        # NOTE(review): the saver is always given seed=action_seed; `seeded`
        # only selects how outputs are compared further down.
        saver = policy_saver.PolicySaver(policy,
                                         batch_size=None,
                                         use_nest_path_signatures=False,
                                         seed=action_seed)
        path = os.path.join(self.get_temp_dir(), 'save_model_action')
        saver.save(path)

        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('action', reloaded.signatures)
        reloaded_action = reloaded.signatures['action']
        self._compare_input_output_specs(
            reloaded_action,
            expected_input_specs=(self._time_step_spec,
                                  policy.policy_state_spec),
            expected_output_spec=policy.policy_step_spec,
            batch_input=True)

        batch_size = 3

        # Sample a batch of (time_step, policy_state) inputs from the specs.
        action_inputs = tensor_spec.sample_spec_nest(
            (self._time_step_spec, policy.policy_state_spec),
            outer_dims=(batch_size, ),
            seed=4)

        # Keyword arguments for the flat signature: spec name -> sampled value.
        function_action_input_dict = dict(
            (spec.name, value) for (spec, value) in zip(
                tf.nest.flatten((self._time_step_spec, policy.policy_state_spec
                                 )), tf.nest.flatten(action_inputs)))

        # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded model
        # are equal, which in addition to seeding the call to action() and
        # PolicySaver helps ensure equality of the output of action() in both cases.
        self.assertEqual(reloaded_action.graph.seed, self._global_seed)
        action_output = policy.action(*action_inputs, seed=action_seed)

        # The seed= argument for the SavedModel action call was given at creation of
        # the PolicySaver.

        # This is the flat-signature function.
        reloaded_action_output_dict = reloaded_action(
            **function_action_input_dict)

        def match_dtype_shape(x, y, msg=None):
            # Structural comparison only: same shape and dtype, values ignored.
            self.assertEqual(x.shape, y.shape, msg=msg)
            self.assertEqual(x.dtype, y.dtype, msg=msg)

        # This is the non-flat function.
        if has_state:
            reloaded_action_output = reloaded.action(*action_inputs)
        else:
            # Try both cases: one with an empty policy_state and one with no
            # policy_state.  Compare them.

            # NOTE(ebrevdo): The first call to .action() must be stored in
            # reloaded_action_output because this is the version being compared later
            # against the true action_output and the values will change after the
            # first call due to randomness.
            reloaded_action_output = reloaded.action(*action_inputs)
            reloaded_action_output_no_input_state = reloaded.action(
                action_inputs[0])
            # Even with a seed, multiple calls to action will get different values,
            # so here we just check the signature matches.
            tf.nest.map_structure(match_dtype_shape,
                                  reloaded_action_output_no_input_state,
                                  reloaded_action_output)

        # Flatten the in-memory policy's output into a spec name -> value dict
        # so it can be compared with the flat signature's output dict.
        action_output_dict = dict(
            ((spec.name, value)
             for (spec, value) in zip(tf.nest.flatten(policy.policy_step_spec),
                                      tf.nest.flatten(action_output))))

        # Check output of the flattened signature call.
        action_output_dict = self.evaluate(action_output_dict)
        reloaded_action_output_dict = self.evaluate(
            reloaded_action_output_dict)
        self.assertAllEqual(action_output_dict.keys(),
                            reloaded_action_output_dict.keys())

        for k in action_output_dict:
            if seeded:
                self.assertAllClose(action_output_dict[k],
                                    reloaded_action_output_dict[k],
                                    msg='\nMismatched dict key: %s.' % k)
            else:
                match_dtype_shape(action_output_dict[k],
                                  reloaded_action_output_dict[k],
                                  msg='\nMismatch dict key: %s.' % k)

        # Check output of the proper structured call.
        action_output = self.evaluate(action_output)
        reloaded_action_output = self.evaluate(reloaded_action_output)
        # With non-signature functions, we can check that passing a seed does the
        # right thing the second time.
        if seeded:
            tf.nest.map_structure(self.assertAllClose, action_output,
                                  reloaded_action_output)
        else:
            tf.nest.map_structure(match_dtype_shape, action_output,
                                  reloaded_action_output)
Exemple #20
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        log_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN using a Python env and a TF1 session.

    Builds a DQN agent on top of a Python (gym) environment, fills a Python
    replay buffer with random-policy transitions, then alternates between
    collecting experience and running the training op inside a
    `tf.compat.v1.Session`. Evaluation, logging and checkpointing are
    triggered whenever the global step is a multiple of the matching interval.

    Args:
      root_dir: Base output directory ('~' is expanded). Summaries and
        checkpoints go to its 'train' and 'eval' subdirectories.
      env_name: Environment name passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations of the main loop.
      fc_layer_params: `fc_layer_params` forwarded to `QNetwork`.
      initial_collect_steps: Number of random-policy steps collected into the
        replay buffer before training starts.
      collect_steps_per_iteration: Collect steps taken per iteration.
      epsilon_greedy: Forwarded to `DqnAgent`.
      replay_buffer_capacity: Capacity of the Python replay buffer.
      target_update_tau: Forwarded to `DqnAgent`.
      target_update_period: Forwarded to `DqnAgent`.
      train_steps_per_iteration: Number of train op runs per iteration.
      batch_size: `sample_batch_size` used when sampling the replay buffer.
      learning_rate: Learning rate of the Adam optimizer.
      gamma: Forwarded to `DqnAgent`.
      reward_scale_factor: Forwarded to `DqnAgent`.
      gradient_clipping: Forwarded to `DqnAgent`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        eval metrics.
      eval_interval: Evaluate when `global_step % eval_interval == 0`.
      train_checkpoint_interval: Save the train checkpoint when
        `global_step % train_checkpoint_interval == 0`.
      policy_checkpoint_interval: Save the policy checkpoint when
        `global_step % policy_checkpoint_interval == 0`.
      log_interval: Log loss and throughput when
        `global_step % log_interval == 0`.
      summaries_flush_secs: Flush period (seconds) of both summary writers.
      debug_summaries: Forwarded to `DqnAgent`.
      summarize_grads_and_vars: Forwarded to `DqnAgent`.
      eval_metrics_callback: Forwarded as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Note this is a python environment.
    env = batched_py_environment.BatchedPyEnvironment(
        [suite_gym.load(env_name)])
    eval_py_env = suite_gym.load(env_name)

    # Convert specs to BoundedTensorSpec.
    action_spec = tensor_spec.from_spec(env.action_spec())
    observation_spec = tensor_spec.from_spec(env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)

    # Reuse the converted specs instead of converting the env specs again.
    q_net = q_network.QNetwork(observation_spec,
                               action_spec,
                               fc_layer_params=fc_layer_params)

    # The agent must be in graph.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
    greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
    random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                    env.action_spec())

    # Python replay buffer.
    replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=replay_buffer_capacity,
        data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

    time_step = env.reset()

    # Initialize the replay buffer with some transitions. We use the random
    # policy to initialize the replay buffer to make sure we get a good
    # distribution of actions.
    for _ in range(initial_collect_steps):
        time_step = collect_step(env, time_step, random_policy, replay_buffer)

    # TODO(b/112041045) Use global_step as counter.
    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=agent,
                                             global_step=global_step)

    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=agent.policy,
                                              global_step=global_step)

    # Each sampled element holds two consecutive steps (num_steps=2).
    ds = replay_buffer.as_dataset(sample_batch_size=batch_size, num_steps=2)
    ds = ds.prefetch(4)
    itr = tf.compat.v1.data.make_initializable_iterator(ds)

    experience = itr.get_next()

    train_op = common.function(agent.train)(experience)

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
        for eval_metric in eval_metrics:
            eval_metric.tf_summaries()

    with tf.compat.v1.Session() as session:
        train_checkpointer.initialize_or_restore(session)
        common.initialize_uninitialized_variables(session)
        session.run(itr.initializer)
        # Run the agent's initialization op.
        # NOTE(review): for DQN this presumably syncs the target Q-network
        # with the Q-network -- confirm against DqnAgent.initialize.
        session.run(agent.initialize())
        train = session.make_callable(train_op)
        global_step_call = session.make_callable(global_step)
        session.run(train_summary_writer.init())
        session.run(eval_summary_writer.init())

        # Compute initial evaluation metrics.
        global_step_val = global_step_call()
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0
        steps_per_second_ph = tf.compat.v1.placeholder(tf.float32,
                                                       shape=(),
                                                       name='steps_per_sec_ph')
        # tf.contrib no longer exists in TF 2.x; use the v2 summary API like
        # the rest of this function. The step is passed explicitly because the
        # v2 API does not pick up the global step implicitly.
        steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps/sec',
            data=steps_per_second_ph,
            step=global_step)
        for _ in range(num_iterations):
            start_time = time.time()
            for _ in range(collect_steps_per_iteration):
                time_step = collect_step(env, time_step, collect_policy,
                                         replay_buffer)
            collect_time += time.time() - start_time
            start_time = time.time()
            for _ in range(train_steps_per_iteration):
                loss = train()
            train_time += time.time() - start_time
            global_step_val = global_step_call()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             loss.loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                session.run(steps_per_second_summary,
                            feed_dict={steps_per_second_ph: steps_per_sec})
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info(
                    '%s', 'collect_time = {}, train_time = {}'.format(
                        collect_time, train_time))
                # Start a fresh throughput measurement window.
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    greedy_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    log=True,
                    callback=eval_metrics_callback,
                )
                # Restart the step window after evaluation. Eval wall time is
                # never added to collect_time/train_time, and start_time is
                # re-read at the top of the next iteration.
                timed_at_step = global_step_val
Exemple #21
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN running eagerly on TF environments.

    Builds a DQN agent (feed-forward or RNN Q-network), seeds a TF replay
    buffer with random-policy experience, then alternates between running the
    collect driver and training on batches sampled from the replay buffer.
    Logging, checkpointing and evaluation are triggered whenever the global
    step is a multiple of the matching interval.

    Args:
      root_dir: Base output directory ('~' is expanded). Summaries and
        checkpoints go to its 'train' and 'eval' subdirectories.
      env_name: Environment name passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations of the main loop.
      train_sequence_length: If greater than 1 a `QRnnNetwork` is used,
        otherwise a `QNetwork`. Sampled trajectories span
        `train_sequence_length + 1` steps.
      fc_layer_params: `fc_layer_params` forwarded to `QNetwork`.
      input_fc_layer_params: Forwarded to `QRnnNetwork`.
      lstm_size: Forwarded to `QRnnNetwork`.
      output_fc_layer_params: Forwarded to `QRnnNetwork`.
      initial_collect_steps: Number of random-policy steps collected into the
        replay buffer before training starts.
      collect_steps_per_iteration: `num_steps` of the collect driver.
      epsilon_greedy: Forwarded to `DqnAgent`.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Forwarded to `DqnAgent`.
      target_update_period: Forwarded to `DqnAgent`.
      train_steps_per_iteration: Number of `tf_agent.train` calls per
        iteration.
      batch_size: `sample_batch_size` used when sampling the replay buffer.
      learning_rate: Learning rate of the Adam optimizer.
      gamma: Forwarded to `DqnAgent`.
      reward_scale_factor: Forwarded to `DqnAgent`.
      gradient_clipping: Forwarded to `DqnAgent`.
      use_tf_functions: If true, wrap `collect_driver.run` and
        `tf_agent.train` with `common.function`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        eval metrics.
      eval_interval: Evaluate when `global_step % eval_interval == 0`.
      train_checkpoint_interval: Save the train checkpoint when
        `global_step % train_checkpoint_interval == 0`.
      policy_checkpoint_interval: Save the policy checkpoint when
        `global_step % policy_checkpoint_interval == 0`.
      rb_checkpoint_interval: Save the replay buffer checkpoint when
        `global_step % rb_checkpoint_interval == 0`.
      log_interval: Log loss and throughput when
        `global_step % log_interval == 0`.
      summary_interval: Summaries are recorded only when
        `global_step % summary_interval == 0`.
      summaries_flush_secs: Flush period (seconds) of both summary writers.
      debug_summaries: Forwarded to `DqnAgent`.
      summarize_grads_and_vars: Forwarded to `DqnAgent`.
      eval_metrics_callback: Optional callable invoked as
        `eval_metrics_callback(results, global_step_value)` after each eval.

    Returns:
      The value returned by the last `tf_agent.train` call, or `None` if no
      training step was executed.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(tf_env.observation_spec(),
                                       tf_env.action_spec(),
                                       fc_layer_params=fc_layer_params)

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Each sampled element spans train_sequence_length + 1 consecutive
        # steps and is batched by batch_size.
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        # Stays None when no training step runs (num_iterations == 0 or
        # train_steps_per_iteration == 0) so the final return cannot hit an
        # unbound local.
        train_loss = None

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience)
            time_acc += time.time() - start_time

            # Read the step counter once per iteration; nothing below this
            # line advances it.
            step = global_step.numpy()

            if step % log_interval == 0:
                if train_loss is not None:
                    logging.info('step = %d, loss = %f', step,
                                 train_loss.loss)
                steps_per_sec = (step - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                # tf.contrib no longer exists in TF 2.x; use the v2 summary
                # API like the rest of this function, with an explicit step.
                tf.compat.v2.summary.scalar(name='global_steps/sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = step
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(step_metrics=train_metrics[:2])

            if step % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=step)

            if step % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=step)

            if step % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=step)

            if step % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, step)
                metric_utils.log_metrics(eval_metrics)
        return train_loss