def testRenamedSignatures(self):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')

  time_step_spec = self._time_step_spec._replace(
      observation=tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=(4,), minimum=-10.0, maximum=10.0))
  network = q_network.QNetwork(
      input_tensor_spec=time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  saver = policy_saver.PolicySaver(policy, batch_size=None)
  action_signature_names = [
      s.name for s in saver._signatures['action'].input_signature
  ]
  self.assertAllEqual(
      ['0/step_type', '0/reward', '0/discount', '0/observation'],
      action_signature_names)
  initial_state_signature_names = [
      s.name for s in saver._signatures['get_initial_state'].input_signature
  ]
  self.assertAllEqual(['batch_size'], initial_state_signature_names)
def testTrain(self):
  with tf.compat.v2.summary.record_if(False):
    # Emits trajectories shaped (batch=1, time=6, ...)
    traj, time_step_spec, action_spec = (
        driver_test_utils.make_random_trajectory())
    # Convert to shapes (batch=6, 1, ...) so this works with a non-RNN model.
    traj = tf.nest.map_structure(common.transpose_batch_time, traj)
    cloning_net = q_network.QNetwork(time_step_spec.observation, action_spec)
    agent = behavioral_cloning_agent.BehavioralCloningAgent(
        time_step_spec,
        action_spec,
        cloning_network=cloning_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.01))
    # Disable clipping to make sure we can see the difference in behavior.
    agent.policy._clip = False
    # Remove policy_info, as BehavioralCloningAgent expects none.
    traj = traj.replace(policy_info=())
    # TODO(b/123883319)
    if tf.executing_eagerly():
      train_and_loss = lambda: agent.train(traj)
    else:
      train_and_loss = agent.train(traj)
    replay = trajectory_replay.TrajectoryReplay(agent.policy)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    initial_actions = self.evaluate(replay.run(traj)[0])
    for _ in range(TRAIN_ITERATIONS):
      self.evaluate(train_and_loss)
    post_training_actions = self.evaluate(replay.run(traj)[0])
    # We don't necessarily converge to the same actions as in trajectory after
    # 10 steps of an untuned optimizer, but the policy does change.
    self.assertFalse(np.all(initial_actions == post_training_actions))
def testUniqueSignatures(self):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')

  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  saver = policy_saver.PolicySaver(policy, batch_size=None)
  action_signature_names = [
      s.name for s in saver._signatures['action'].input_signature
  ]
  self.assertAllEqual(
      ['0/step_type', '0/reward', '0/discount', '0/observation'],
      action_signature_names)
  initial_state_signature_names = [
      s.name for s in saver._signatures['get_initial_state'].input_signature
  ]
  self.assertAllEqual(['batch_size'], initial_state_signature_names)
def __init__(self, env):
  # Initialize the agent.
  self.env = env
  q_net = q_network.QNetwork(
      env.observation_spec(),
      env.action_spec(),
      fc_layer_params=fc_layer_params,
  )
  adam = tf.compat.v1.train.AdamOptimizer(
      learning_rate=learning_rate, beta1=0.8, epsilon=1)
  train_step_counter = tf.compat.v2.Variable(0)
  self.agent = dqn_agent.DqnAgent(
      env.time_step_spec(),
      env.action_spec(),
      q_network=q_net,
      optimizer=adam,
      td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
      train_step_counter=train_step_counter,
  )
  self.agent.initialize()
  self._create_replay_buffer()
  eval_env = BlackJackEnv.tf_env()
  self.evaluator = PolicyEvaluator(eval_env, n_eval_episodes)
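The `_create_replay_buffer` helper called above is not shown in this collection. A minimal sketch of what it might look like, assuming a TFUniformReplayBuffer keyed to the agent's collect data spec; the capacity value is a hypothetical default, not taken from the original code:

def _create_replay_buffer(self, capacity=10000):
  # Assumption: store collected trajectories matching the agent's
  # collect_data_spec, one buffer row per environment in the batch.
  self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self.agent.collect_data_spec,
      batch_size=self.env.batch_size,
      max_length=capacity)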
def testTrain(self):
  # Emits trajectories shaped (batch=1, time=6, ...)
  traj, time_step_spec, action_spec = (
      driver_test_utils.make_random_trajectory())
  # Convert to shapes (batch=6, 1, ...) so this works with a non-RNN model.
  traj = nest.map_structure(tf.contrib.rnn.transpose_batch_time, traj)
  cloning_net = q_network.QNetwork(time_step_spec.observation, action_spec)
  agent = behavioral_cloning_agent.BehavioralCloningAgent(
      time_step_spec,
      action_spec,
      cloning_network=cloning_net,
      optimizer=tf.train.AdamOptimizer(learning_rate=0.01))
  # Remove policy_info, as BehavioralCloningAgent expects none.
  traj = traj.replace(policy_info=())
  train_and_loss = agent.train(traj)
  replay = trajectory_replay.TrajectoryReplay(agent.policy())
  self.evaluate(tf.global_variables_initializer())
  initial_actions = self.evaluate(replay.run(traj)[0])
  for _ in range(TRAIN_ITERATIONS):
    self.evaluate(train_and_loss)
  post_training_actions = self.evaluate(replay.run(traj)[0])
  # We don't necessarily converge to the same actions as in trajectory after
  # 10 steps of an untuned optimizer, but the policy does change.
  self.assertFalse(np.all(initial_actions == post_training_actions))
def testVariablesBuild(self):
  num_state_dims = 5
  network = q_network.QNetwork(
      input_tensor_spec=tensor_spec.TensorSpec([num_state_dims], tf.float32),
      action_spec=tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1))
  self.assertFalse(network.built)
  variables = network.variables
  self.assertTrue(network.built)
  self.assertGreater(len(variables), 0)
def testCorrectOutputShape(self):
  batch_size = 3
  num_state_dims = 5
  num_actions = 2
  states = tf.random.uniform([batch_size, num_state_dims])
  network = q_network.QNetwork(
      input_tensor_spec=tensor_spec.TensorSpec([num_state_dims], tf.float32),
      action_spec=tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1))
  q_values, _ = network(states)
  self.assertAllEqual(q_values.shape.as_list(), [batch_size, num_actions])
def testNetworkVariablesAreReused(self):
  batch_size = 3
  num_state_dims = 5
  states = tf.ones([batch_size, num_state_dims])
  next_states = tf.ones([batch_size, num_state_dims])
  network = q_network.QNetwork(
      input_tensor_spec=tensor_spec.TensorSpec([num_state_dims], tf.float32),
      action_spec=tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1))
  q_values, _ = network(states)
  next_q_values, _ = network(next_states)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertAllClose(q_values, next_q_values)
def testChangeHiddenLayers(self):
  batch_size = 3
  num_state_dims = 5
  num_actions = 2
  states = tf.random.uniform([batch_size, num_state_dims])
  network = q_network.QNetwork(
      input_tensor_spec=tensor_spec.TensorSpec([num_state_dims], tf.float32),
      action_spec=tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1),
      fc_layer_params=(40,))
  q_values, _ = network(states)
  self.assertAllEqual(q_values.shape.as_list(), [batch_size, num_actions])
  self.assertEqual(len(network.trainable_variables), 4)
def testBuild(self):
  batch_size = 3
  num_state_dims = 5
  num_actions = 2
  states = tf.random_uniform([batch_size, num_state_dims])
  network = q_network.QNetwork(
      input_tensor_spec=tensor_spec.TensorSpec([num_state_dims], tf.float32),
      action_spec=tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1))
  q_values, _ = network(states)
  self.assertAllEqual(q_values.shape.as_list(), [batch_size, num_actions])
  self.assertEqual(len(network.trainable_weights), 6)
def testAddConvLayers(self):
  batch_size = 3
  num_state_dims = 5
  num_actions = 2
  states = tf.random_uniform([batch_size, 5, 5, num_state_dims])
  network = q_network.QNetwork(
      observation_spec=tensor_spec.TensorSpec([5, 5, num_state_dims],
                                              tf.float32),
      action_spec=tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1),
      conv_layer_params=((16, 3, 2),))
  q_values, _ = network(states)
  self.assertAllEqual(q_values.shape.as_list(), [batch_size, num_actions])
  self.assertEqual(len(network.trainable_variables), 8)
def testAgentFollowsActionSpec(self, agent_class):
  agent = agent_class(
      self._time_step_spec,
      self._action_spec,
      q_network=q_network.QNetwork(self._observation_spec, self._action_spec),
      optimizer=None)
  self.assertTrue(agent.policy() is not None)
  policy = agent.policy()
  observation = tensor_spec.sample_spec_nest(
      self._time_step_spec, seed=42, outer_dims=(1,))
  action_op = policy.action(observation).action
  self.evaluate(tf.initialize_all_variables())
  action = self.evaluate(action_op)
  self.assertEqual([1] + self._action_spec[0].shape.as_list(),
                   list(action[0].shape))
def testAgentFollowsActionSpecWithScalarAction(self, agent_class):
  action_spec = [tensor_spec.BoundedTensorSpec((), tf.int32, 0, 1)]
  agent = agent_class(
      self._time_step_spec,
      action_spec,
      q_network=q_network.QNetwork(self._observation_spec, action_spec),
      optimizer=None)
  self.assertIsNotNone(agent.policy)
  policy = agent.policy
  observation = tensor_spec.sample_spec_nest(
      self._time_step_spec, seed=42, outer_dims=(1,))
  action_op = policy.action(observation).action
  self.evaluate(tf.compat.v1.initialize_all_variables())
  action = self.evaluate(action_op)
  self.assertEqual([1] + action_spec[0].shape.as_list(), list(action[0].shape))
def create_agent(train_env):
  q_net = q_network.QNetwork(
      train_env.observation_spec(),
      train_env.action_spec(),
      fc_layer_params=fc_layer_params,
  )
  adam = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
  train_step_counter = tf.compat.v2.Variable(0)
  tf_agent = dqn_agent.DqnAgent(
      train_env.time_step_spec(),
      train_env.action_spec(),
      q_network=q_net,
      optimizer=adam,
      td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
      train_step_counter=train_step_counter,
  )
  tf_agent.initialize()
  return tf_agent
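A usage sketch for create_agent, assuming a Gym CartPole environment wrapped in a TFPyEnvironment; the environment choice and the module-level fc_layer_params / learning_rate values are assumptions, not part of the original snippet:

# Hypothetical usage of create_agent; assumes suite_gym and tf_py_environment
# are imported and fc_layer_params / learning_rate are defined at module level.
train_py_env = suite_gym.load('CartPole-v0')
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
tf_agent = create_agent(train_env)
# The collect policy can then drive data collection into a replay buffer.
collect_policy = tf_agent.collect_policy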
def testAddPreprocessingLayers(self):
  batch_size = 3
  num_actions = 2
  states = (tf.random.uniform([batch_size, 1]),
            tf.random.uniform([batch_size]))
  preprocessing_layers = (
      tf.keras.layers.Dense(4),
      tf.keras.Sequential([
          tf.keras.layers.Reshape((1,)),
          tf.keras.layers.Dense(4)
      ]))
  network = q_network.QNetwork(
      input_tensor_spec=(
          tensor_spec.TensorSpec([1], tf.float32),
          tensor_spec.TensorSpec([], tf.float32)),
      preprocessing_layers=preprocessing_layers,
      preprocessing_combiner=tf.keras.layers.Add(),
      action_spec=tensor_spec.BoundedTensorSpec(
          [1], tf.int32, 0, num_actions - 1))
  q_values, _ = network(states)
  self.assertAllEqual(q_values.shape.as_list(), [batch_size, num_actions])
  # At least 2 variables each for the preprocessing layers.
  self.assertGreater(len(network.trainable_variables), 4)
batch_size = 128  # @param
learning_rate = 1e-5  # @param
log_interval = 200  # @param
num_eval_episodes = 2  # @param
eval_interval = 1000  # @param

train_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)
eval_py_env = wrappers.TimeLimit(GridWorldEnv(), duration=100)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)

tf_agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
    train_step_counter=train_step_counter)
tf_agent.initialize()
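The eval_env and num_eval_episodes defined above imply an evaluation step elsewhere in the notebook. A minimal sketch of the standard average-return helper used with such setups; the function name compute_avg_return is hypothetical and not taken from the original code:

def compute_avg_return(environment, policy, num_episodes=10):
  # Run the policy for num_episodes full episodes and average the
  # undiscounted returns.
  total_return = 0.0
  for _ in range(num_episodes):
    time_step = environment.reset()
    episode_return = 0.0
    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return
  return total_return / num_episodes

# Example: compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)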
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    agent_class=dqn_agent.DqnAgent,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_py_env = suite_gym.load(env_name)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    tf_agent = agent_class(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
        # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec(),
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    global_step = tf.train.get_or_create_global_step()

    replay_observer = [replay_buffer.add_batch]
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps).run()

    collect_policy = tf_agent.collect_policy()
    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = dataset.make_initializable_iterator()
    trajectories, _ = iterator.get_next()
    train_op = tf_agent.train(
        experience=trajectories, train_step_counter=global_step)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=tf.contrib.checkpoint.List(train_metrics))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy(),
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])
    summary_op = tf.contrib.summary.all_summary_ops()

    with eval_summary_writer.as_default(), \
         tf.contrib.summary.always_record_summaries():
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(sguada) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      tf.contrib.summary.initialize(session=sess)
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_op, global_step])

      timed_at_step = sess.run(global_step)
      collect_time = 0
      train_time = 0
      steps_per_second_ph = tf.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        # Train/collect/eval.
        start_time = time.time()
        collect_call()
        collect_time += time.time() - start_time
        start_time = time.time()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        train_time += time.time() - start_time

        if global_step_val % log_interval == 0:
          tf.logging.info('step = %d, loss = %f', global_step_val,
                          loss_info_value.loss)
          steps_per_sec = ((global_step_val - timed_at_step) /
                           (collect_time + train_time))
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          tf.logging.info('%.3f steps/sec' % steps_per_sec)
          tf.logging.info('collect_time = {}, train_time = {}'.format(
              collect_time, train_time))
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  # TODO(kbanoop): Figure out if it is possible to avoid the with block.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    trajectory_spec = trajectory.from_transition(
        time_step=tf_env.time_step_spec(),
        action_step=policy_step.PolicyStep(action=tf_env.action_spec()),
        next_time_step=tf_env.time_step_spec())
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=trajectory_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy()
    collect_policy = tf_agent.collect_policy()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    global_step = tf.train.get_or_create_global_step()

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    # Collect initial replay data.
    tf.logging.info(
        'Initializing replay buffer by collecting experience for %d steps '
        'with a random policy.' % initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps).run()

    metrics = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(metrics, global_step.numpy())

    time_step = None
    policy_state = ()

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience, train_step_counter=global_step)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        tf.logging.info('step = %d, loss = %f', global_step.numpy(),
                        train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        tf.logging.info('%.3f steps/sec' % steps_per_sec)
        tf.contrib.summary.scalar(
            name='global_steps/sec', tensor=steps_per_sec)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(metrics, global_step.numpy())

    return train_loss
def testSaveAction(self, seeded, has_state):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')

  if has_state:
    network = q_rnn_network.QRnnNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)
  else:
    network = q_network.QNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)

  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  action_seed = 98723
  saver = policy_saver.PolicySaver(
      policy,
      batch_size=None,
      use_nest_path_signatures=False,
      seed=action_seed)
  path = os.path.join(self.get_temp_dir(), 'save_model_action')
  saver.save(path)

  reloaded = tf.compat.v2.saved_model.load(path)
  self.assertIn('action', reloaded.signatures)
  reloaded_action = reloaded.signatures['action']
  self._compare_input_output_specs(
      reloaded_action,
      expected_input_specs=(self._time_step_spec, policy.policy_state_spec),
      expected_output_spec=policy.policy_step_spec,
      batch_input=True)

  batch_size = 3
  action_inputs = tensor_spec.sample_spec_nest(
      (self._time_step_spec, policy.policy_state_spec),
      outer_dims=(batch_size,),
      seed=4)
  function_action_input_dict = dict(
      (spec.name, value) for (spec, value) in zip(
          tf.nest.flatten((self._time_step_spec, policy.policy_state_spec)),
          tf.nest.flatten(action_inputs)))

  # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded model
  # are equal, which in addition to seeding the call to action() and
  # PolicySaver helps ensure equality of the output of action() in both cases.
  self.assertEqual(reloaded_action.graph.seed, self._global_seed)
  action_output = policy.action(*action_inputs, seed=action_seed)

  # The seed= argument for the SavedModel action call was given at creation of
  # the PolicySaver.

  # This is the flat-signature function.
  reloaded_action_output_dict = reloaded_action(**function_action_input_dict)

  def match_dtype_shape(x, y, msg=None):
    self.assertEqual(x.shape, y.shape, msg=msg)
    self.assertEqual(x.dtype, y.dtype, msg=msg)

  # This is the non-flat function.
  if has_state:
    reloaded_action_output = reloaded.action(*action_inputs)
  else:
    # Try both cases: one with an empty policy_state and one with no
    # policy_state.  Compare them.

    # NOTE(ebrevdo): The first call to .action() must be stored in
    # reloaded_action_output because this is the version being compared later
    # against the true action_output and the values will change after the
    # first call due to randomness.
    reloaded_action_output = reloaded.action(*action_inputs)
    reloaded_action_output_no_input_state = reloaded.action(action_inputs[0])
    # Even with a seed, multiple calls to action will get different values,
    # so here we just check the signature matches.
    tf.nest.map_structure(match_dtype_shape,
                          reloaded_action_output_no_input_state,
                          reloaded_action_output)

  action_output_dict = dict(
      ((spec.name, value) for (spec, value) in zip(
          tf.nest.flatten(policy.policy_step_spec),
          tf.nest.flatten(action_output))))

  # Check output of the flattened signature call.
  action_output_dict = self.evaluate(action_output_dict)
  reloaded_action_output_dict = self.evaluate(reloaded_action_output_dict)
  self.assertAllEqual(action_output_dict.keys(),
                      reloaded_action_output_dict.keys())

  for k in action_output_dict:
    if seeded:
      self.assertAllClose(
          action_output_dict[k],
          reloaded_action_output_dict[k],
          msg='\nMismatched dict key: %s.' % k)
    else:
      match_dtype_shape(
          action_output_dict[k],
          reloaded_action_output_dict[k],
          msg='\nMismatch dict key: %s.' % k)

  # Check output of the proper structured call.
  action_output = self.evaluate(action_output)
  reloaded_action_output = self.evaluate(reloaded_action_output)
  # With non-signature functions, we can check that passing a seed does the
  # right thing the second time.
  if seeded:
    tf.nest.map_structure(self.assertAllClose, action_output,
                          reloaded_action_output)
  else:
    tf.nest.map_structure(match_dtype_shape, action_output,
                          reloaded_action_output)
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    log_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # Note this is a python environment.
  env = batched_py_environment.BatchedPyEnvironment([suite_gym.load(env_name)])
  eval_py_env = suite_gym.load(env_name)

  # Convert specs to BoundedTensorSpec.
  action_spec = tensor_spec.from_spec(env.action_spec())
  observation_spec = tensor_spec.from_spec(env.observation_spec())
  time_step_spec = ts.time_step_spec(observation_spec)

  q_net = q_network.QNetwork(
      tensor_spec.from_spec(env.observation_spec()),
      tensor_spec.from_spec(env.action_spec()),
      fc_layer_params=fc_layer_params)

  # The agent must be in graph.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  agent = dqn_agent.DqnAgent(
      time_step_spec,
      action_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
      td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=global_step)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
  greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
  random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                  env.action_spec())

  # Python replay buffer.
  replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
      capacity=replay_buffer_capacity,
      data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

  time_step = env.reset()

  # Initialize the replay buffer with some transitions. We use the random
  # policy to initialize the replay buffer to make sure we get a good
  # distribution of actions.
  for _ in range(initial_collect_steps):
    time_step = collect_step(env, time_step, random_policy, replay_buffer)

  # TODO(b/112041045) Use global_step as counter.
  train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir, agent=agent, global_step=global_step)
  policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=agent.policy,
      global_step=global_step)

  ds = replay_buffer.as_dataset(sample_batch_size=batch_size, num_steps=2)
  ds = ds.prefetch(4)
  itr = tf.compat.v1.data.make_initializable_iterator(ds)
  experience = itr.get_next()

  train_op = common.function(agent.train)(experience)

  with eval_summary_writer.as_default(), \
       tf.compat.v2.summary.record_if(True):
    for eval_metric in eval_metrics:
      eval_metric.tf_summaries()

  with tf.compat.v1.Session() as session:
    train_checkpointer.initialize_or_restore(session)
    common.initialize_uninitialized_variables(session)
    session.run(itr.initializer)
    # Copy critic network values to the target critic network.
    session.run(agent.initialize())
    train = session.make_callable(train_op)
    global_step_call = session.make_callable(global_step)
    session.run(train_summary_writer.init())
    session.run(eval_summary_writer.init())

    # Compute initial evaluation metrics.
    global_step_val = global_step_call()
    metric_utils.compute_summaries(
        eval_metrics,
        eval_py_env,
        greedy_policy,
        num_episodes=num_eval_episodes,
        global_step=global_step_val,
        log=True,
        callback=eval_metrics_callback,
    )

    timed_at_step = global_step_val
    collect_time = 0
    train_time = 0
    steps_per_second_ph = tf.compat.v1.placeholder(
        tf.float32, shape=(), name='steps_per_sec_ph')
    steps_per_second_summary = tf.contrib.summary.scalar(
        name='global_steps/sec', tensor=steps_per_second_ph)

    for _ in range(num_iterations):
      start_time = time.time()
      for _ in range(collect_steps_per_iteration):
        time_step = collect_step(env, time_step, collect_policy, replay_buffer)
      collect_time += time.time() - start_time
      start_time = time.time()
      for _ in range(train_steps_per_iteration):
        loss = train()
      train_time += time.time() - start_time

      global_step_val = global_step_call()
      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, loss.loss)
        steps_per_sec = ((global_step_val - timed_at_step) /
                         (collect_time + train_time))
        session.run(
            steps_per_second_summary,
            feed_dict={steps_per_second_ph: steps_per_sec})
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('%s', 'collect_time = {}, train_time = {}'.format(
            collect_time, train_time))
        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % eval_interval == 0:
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )
        # Reset timing to avoid counting eval time.
        timed_at_step = global_step_val
        start_time = time.time()
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    train_sequence_length=1,
    # Params for QNetwork
    fc_layer_params=(100,),
    # Params for QRnnNetwork
    input_fc_layer_params=(50,),
    lstm_size=(20,),
    output_fc_layer_params=(20,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    if train_sequence_length > 1:
      q_net = q_rnn_network.QRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=input_fc_layer_params,
          lstm_size=lstm_size,
          output_fc_layer_params=output_fc_layer_params)
    else:
      q_net = q_network.QNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=fc_layer_params)

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    if use_tf_functions:
      # To speed up collect use common.function.
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps '
        'with a random policy.', initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.contrib.summary.scalar(
            name='global_steps/sec', tensor=steps_per_sec)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(step_metrics=train_metrics[:2])

      if global_step.numpy() % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

    return train_loss