def test_network_can_preprocess_and_combine(self):
  batch_size = 3
  frames = 5
  num_actions = 2
  lstm_size = 6
  states = (tf.random.uniform([batch_size, frames, 1]),
            tf.random.uniform([batch_size, frames]))
  preprocessing_layers = (
      tf.keras.layers.Dense(4),
      tf.keras.Sequential([
          expand_dims_layer.ExpandDims(-1),  # Convert to vec size (1,).
          tf.keras.layers.Dense(4)
      ]))
  network = q_rnn_network.QRnnNetwork(
      input_tensor_spec=(tensor_spec.TensorSpec([1], tf.float32),
                         tensor_spec.TensorSpec([], tf.float32)),
      preprocessing_layers=preprocessing_layers,
      preprocessing_combiner=tf.keras.layers.Add(),
      lstm_size=(lstm_size,),
      action_spec=tensor_spec.BoundedTensorSpec(
          [1], tf.int32, 0, num_actions - 1))
  empty_step_type = tf.constant(
      [[time_step.StepType.FIRST] * frames] * batch_size)
  q_values, _ = network(states, empty_step_type)
  self.assertAllEqual(q_values.shape.as_list(),
                      [batch_size, frames, num_actions])
  # At least 2 variables each for the preprocessing layers.
  self.assertGreater(len(network.trainable_variables), 4)
def testTrainWithRNN(self):
  # Emits trajectories shaped (batch=1, time=6, ...)
  traj, time_step_spec, action_spec = (
      driver_test_utils.make_random_trajectory())
  cloning_net = q_rnn_network.QRnnNetwork(
      time_step_spec.observation, action_spec)
  agent = behavioral_cloning_agent.BehavioralCloningAgent(
      time_step_spec,
      action_spec,
      cloning_network=cloning_net,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.01),
      num_outer_dims=2)
  # Disable clipping to make sure we can see the difference in behavior.
  agent.policy._clip = False
  # Remove policy_info, as BehavioralCloningAgent expects none.
  traj = traj.replace(policy_info=())
  # TODO(b/123883319)
  if tf.executing_eagerly():
    train_and_loss = lambda: agent.train(traj)
  else:
    train_and_loss = agent.train(traj)
  replay = trajectory_replay.TrajectoryReplay(agent.policy)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  initial_actions = self.evaluate(replay.run(traj)[0])
  for _ in range(TRAIN_ITERATIONS):
    self.evaluate(train_and_loss)
  post_training_actions = self.evaluate(replay.run(traj)[0])
  # We don't necessarily converge to the same actions as in trajectory after
  # 10 steps of an untuned optimizer, but the policy does change.
  self.assertFalse(np.all(initial_actions == post_training_actions))
def testSaveGetInitialState(self):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')
  network = q_rnn_network.QRnnNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  saver_nobatch = policy_saver.PolicySaver(policy, batch_size=None)
  path = os.path.join(self.get_temp_dir(),
                      'save_model_initial_state_nobatch')
  saver_nobatch.save(path)
  reloaded_nobatch = tf.compat.v2.saved_model.load(path)
  self.assertIn('get_initial_state', reloaded_nobatch.signatures)
  reloaded_get_initial_state = (
      reloaded_nobatch.signatures['get_initial_state'])
  self._compare_input_output_specs(
      reloaded_get_initial_state,
      expected_input_specs=(
          tf.TensorSpec(dtype=tf.int32, shape=(), name='batch_size'),),
      expected_output_spec=policy.policy_state_spec,
      batch_input=False,
      batch_size=None)

  initial_state = policy.get_initial_state(batch_size=3)
  initial_state = self.evaluate(initial_state)
  reloaded_nobatch_initial_state = reloaded_nobatch.get_initial_state(
      batch_size=3)
  reloaded_nobatch_initial_state = self.evaluate(
      reloaded_nobatch_initial_state)
  tf.nest.map_structure(self.assertAllClose, initial_state,
                        reloaded_nobatch_initial_state)

  saver_batch = policy_saver.PolicySaver(policy, batch_size=3)
  path = os.path.join(self.get_temp_dir(), 'save_model_initial_state_batch')
  saver_batch.save(path)
  reloaded_batch = tf.compat.v2.saved_model.load(path)
  self.assertIn('get_initial_state', reloaded_batch.signatures)
  reloaded_get_initial_state = reloaded_batch.signatures['get_initial_state']
  self._compare_input_output_specs(
      reloaded_get_initial_state,
      expected_input_specs=(),
      expected_output_spec=policy.policy_state_spec,
      batch_input=False,
      batch_size=3)

  reloaded_batch_initial_state = reloaded_batch.get_initial_state()
  reloaded_batch_initial_state = self.evaluate(reloaded_batch_initial_state)
  tf.nest.map_structure(self.assertAllClose, initial_state,
                        reloaded_batch_initial_state)
def testTrainWithRNN(self):
  # Hard code a trajectory shaped (time=6, batch=1, ...).
  traj, time_step_spec, action_spec = create_arbitrary_trajectory()
  cloning_net = q_rnn_network.QRnnNetwork(
      time_step_spec.observation, action_spec)
  agent = behavioral_cloning_agent.BehavioralCloningAgent(
      time_step_spec,
      action_spec,
      cloning_network=cloning_net,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.01),
      num_outer_dims=2)
  # Disable clipping to make sure we can see the difference in behavior.
  agent.policy._clip = False
  # Remove policy_info, as BehavioralCloningAgent expects none.
  traj = traj.replace(policy_info=())
  # TODO(b/123883319)
  if tf.executing_eagerly():
    train_and_loss = lambda: agent.train(traj)
  else:
    train_and_loss = agent.train(traj)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  initial_loss = self.evaluate(train_and_loss).loss
  for _ in range(TRAIN_ITERATIONS - 1):
    loss = self.evaluate(train_and_loss).loss
  # We don't necessarily converge to the same actions as in trajectory after
  # 10 steps of an untuned optimizer, but the loss should go down.
  self.assertGreater(initial_loss, loss)
def test_network_can_preprocess_and_combine_no_time_dim(self):
  if tf.executing_eagerly():
    self.skipTest('b/123776211')
  batch_size = 3
  num_actions = 2
  lstm_size = 5
  states = (tf.random.uniform([batch_size, 1]),
            tf.random.uniform([batch_size]))
  preprocessing_layers = (
      tf.keras.layers.Dense(4),
      tf.keras.Sequential([
          expand_dims_layer.ExpandDims(-1),  # Convert to vec size (1,).
          tf.keras.layers.Dense(4)
      ]))
  network = q_rnn_network.QRnnNetwork(
      input_tensor_spec=(tensor_spec.TensorSpec([1], tf.float32),
                         tensor_spec.TensorSpec([], tf.float32)),
      preprocessing_layers=preprocessing_layers,
      preprocessing_combiner=tf.keras.layers.Add(),
      lstm_size=(lstm_size,),
      action_spec=tensor_spec.BoundedTensorSpec(
          [1], tf.int32, 0, num_actions - 1))
  empty_step_type = tf.constant([time_step.StepType.FIRST] * batch_size)
  q_values, _ = network(states, empty_step_type)
  # Processed 1 time step and the time axis was squeezed back.
  self.assertAllEqual(q_values.shape.as_list(), [batch_size, num_actions])
  # At least 2 variables each for the preprocessing layers.
  self.assertGreater(len(network.trainable_variables), 4)
def test_network_builds(self):
  env = suite_gym.load('CartPole-v0')
  tf_env = tf_py_environment.TFPyEnvironment(env)
  rnn_network = q_rnn_network.QRnnNetwork(tf_env.observation_spec(),
                                          tf_env.action_spec())
  time_step = tf_env.current_time_step()
  q_values, state = rnn_network(time_step.observation, time_step.step_type)
  self.assertEqual((1, 2), q_values.shape)
  self.assertEqual((1, 40), state[0].shape)
  self.assertEqual((1, 40), state[1].shape)
def test_network_builds(self):
  env = suite_gym.load('CartPole-v0')
  tf_env = tf_py_environment.TFPyEnvironment(env)
  rnn_network = q_rnn_network.QRnnNetwork(
      tf_env.observation_spec(), tf_env.action_spec(), lstm_size=(40,))
  first_time_step = tf_env.current_time_step()
  q_values, state = rnn_network(
      first_time_step.observation,
      first_time_step.step_type,
      network_state=rnn_network.get_initial_state(batch_size=1))
  self.assertEqual((1, 2), q_values.shape)
  self.assertEqual((1, 40), state[0].shape)
  self.assertEqual((1, 40), state[1].shape)
def test_network_builds_rnn_construction_fn(self):
  env = suite_gym.load('CartPole-v0')
  tf_env = tf_py_environment.TFPyEnvironment(env)
  rnn_network = q_rnn_network.QRnnNetwork(
      tf_env.observation_spec(),
      tf_env.action_spec(),
      rnn_construction_fn=rnn_keras_fn,
      rnn_construction_kwargs={'lstm_size': 3})
  first_time_step = tf_env.current_time_step()
  q_values, state = rnn_network(
      first_time_step.observation,
      first_time_step.step_type,
      network_state=rnn_network.get_initial_state(batch_size=1),
  )
  self.assertEqual((1, 2), q_values.shape)
  self.assertEqual((3,), state[0].shape)
def test_network_builds_stacked_cells(self):
  env = suite_gym.load('CartPole-v0')
  tf_env = tf_py_environment.TFPyEnvironment(env)
  rnn_network = q_rnn_network.QRnnNetwork(
      tf_env.observation_spec(), tf_env.action_spec(), lstm_size=(10, 5))
  time_step = tf_env.current_time_step()
  q_values, state = rnn_network(time_step.observation, time_step.step_type)
  self.assertTrue(isinstance(state, tuple))
  self.assertEqual(2, len(state))
  self.assertEqual((1, 2), q_values.shape)
  self.assertEqual((1, 10), state[0][0].shape)
  self.assertEqual((1, 10), state[0][1].shape)
  self.assertEqual((1, 5), state[1][0].shape)
  self.assertEqual((1, 5), state[1][1].shape)
def setUp(self):
  super(PolicySaverTest, self).setUp()
  observation_spec = tf.TensorSpec(
      dtype=tf.int64, shape=(), name='callee_users')
  self._time_step_spec = time_step.time_step_spec(observation_spec)
  self._action_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int64, shape=(), minimum=0, maximum=1,
      name='inlining_decision')
  self._network = q_rnn_network.QRnnNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec,
      lstm_size=(40,))
def __init__(self, config, *args, **kwargs):
  """Initialize the agent."""
  self.observation_spec = config['observation_spec']
  self.time_step_spec = ts.time_step_spec(self.observation_spec)
  self.action_spec = config['action_spec']
  self.environment_batch_size = config['environment_batch_size']
  self.q_net = q_rnn_network.QRnnNetwork(
      self.observation_spec,
      self.action_spec,
      input_fc_layer_params=(256,),
      lstm_size=(256, 256))
  self.optimizer = tf.keras.optimizers.Adam()
  self.agent = dqn_agent.DqnAgent(
      self.time_step_spec,
      self.action_spec,
      q_network=self.q_net,
      optimizer=self.optimizer,
      epsilon_greedy=0.3)
  self.agent.initialize()
  self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self.agent.collect_data_spec,
      batch_size=1,
      dataset_drop_remainder=True)
  self.dataset = self.replay_buffer.as_dataset(
      sample_batch_size=1, num_steps=2, single_deterministic_pass=True)
  self.intention_classifier = tf.keras.models.Sequential([
      tf.keras.layers.Dense(
          1, input_shape=self.observation_spec.shape, activation='sigmoid')
  ])
  self.intention_classifier.compile(
      optimizer=self.optimizer,
      loss=tf.keras.losses.BinaryCrossentropy())
  self.intention_dataset_input = []
  self.intention_dataset_true = []
  self.num_hinted = 0
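# A minimal usage sketch, not part of the original snippet: it assumes the
# __init__ above belongs to a wrapper class, hypothetically named MyAgent
# here, and that trajectories have already been pushed into
# self.replay_buffer via self.replay_buffer.add_batch(...). The spec values
# in `config` are illustrative assumptions.
import tensorflow as tf
from tf_agents.specs import tensor_spec

config = {
    'observation_spec': tf.TensorSpec([4], tf.float32),
    'action_spec': tensor_spec.BoundedTensorSpec([], tf.int32, 0, 1),
    'environment_batch_size': 1,
}
my_agent = MyAgent(config)  # MyAgent: hypothetical enclosing class.

# single_deterministic_pass=True means each (batch=1, time=2) window of the
# buffered experience is visited exactly once, in order.
for experience, _ in my_agent.dataset:
  loss_info = my_agent.agent.train(experience)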
def test_network_builds_stacked_cells(self):
  env = suite_gym.load('CartPole-v0')
  tf_env = tf_py_environment.TFPyEnvironment(env)
  rnn_network = q_rnn_network.QRnnNetwork(
      tf_env.observation_spec(), tf_env.action_spec(), lstm_size=(10, 5))
  first_time_step = tf_env.current_time_step()
  q_values, state = rnn_network(
      first_time_step.observation,
      first_time_step.step_type,
      network_state=rnn_network.get_initial_state(batch_size=1))
  tf.nest.assert_same_structure(rnn_network.state_spec, state)
  self.assertEqual(2, len(state))
  self.assertEqual((1, 2), q_values.shape)
  self.assertEqual((1, 10), state[0][0].shape)
  self.assertEqual((1, 10), state[0][1].shape)
  self.assertEqual((1, 5), state[1][0].shape)
  self.assertEqual((1, 5), state[1][1].shape)
def setUp(self):
  observation_spec = {
      'inlining_default':
          tf.TensorSpec(dtype=tf.int64, shape=(), name='inlining_default')
  }
  self._time_step_spec = time_step.time_step_spec(observation_spec)
  self._action_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int64, shape=(), minimum=0, maximum=1,
      name='inlining_decision')
  self._network = q_rnn_network.QRnnNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec,
      lstm_size=(40,),
      preprocessing_layers={
          'inlining_default': tf.keras.layers.Lambda(lambda x: x)
      })
  super(TrainerTest, self).setUp()
def testTrainWithRNN(self):
  # Emits trajectories shaped (batch=1, time=6, ...)
  traj, time_step_spec, action_spec = (
      driver_test_utils.make_random_trajectory())
  cloning_net = q_rnn_network.QRnnNetwork(
      time_step_spec.observation, action_spec)
  agent = behavioral_cloning_agent.BehavioralCloningAgent(
      time_step_spec,
      action_spec,
      cloning_network=cloning_net,
      optimizer=tf.train.AdamOptimizer(learning_rate=0.01))
  # Remove policy_info, as BehavioralCloningAgent expects none.
  traj = traj.replace(policy_info=())
  train_and_loss = agent.train(traj)
  replay = trajectory_replay.TrajectoryReplay(agent.policy())
  self.evaluate(tf.global_variables_initializer())
  initial_actions = self.evaluate(replay.run(traj)[0])
  for _ in range(TRAIN_ITERATIONS):
    self.evaluate(train_and_loss)
  post_training_actions = self.evaluate(replay.run(traj)[0])
  # We don't necessarily converge to the same actions as in trajectory after
  # 10 steps of an untuned optimizer, but the policy does change.
  self.assertFalse(np.all(initial_actions == post_training_actions))
def testSaveAction(self, seeded, has_state):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')
  if has_state:
    network = q_rnn_network.QRnnNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)
  else:
    network = q_network.QNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  action_seed = 98723
  saver = policy_saver.PolicySaver(
      policy,
      batch_size=None,
      use_nest_path_signatures=False,
      seed=action_seed)
  path = os.path.join(self.get_temp_dir(), 'save_model_action')
  saver.save(path)

  reloaded = tf.compat.v2.saved_model.load(path)
  self.assertIn('action', reloaded.signatures)
  reloaded_action = reloaded.signatures['action']
  self._compare_input_output_specs(
      reloaded_action,
      expected_input_specs=(self._time_step_spec, policy.policy_state_spec),
      expected_output_spec=policy.policy_step_spec,
      batch_input=True)

  batch_size = 3
  action_inputs = tensor_spec.sample_spec_nest(
      (self._time_step_spec, policy.policy_state_spec),
      outer_dims=(batch_size,),
      seed=4)
  function_action_input_dict = dict(
      (spec.name, value) for (spec, value) in zip(
          tf.nest.flatten((self._time_step_spec, policy.policy_state_spec)),
          tf.nest.flatten(action_inputs)))

  # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded
  # model are equal, which in addition to seeding the call to action() and
  # PolicySaver helps ensure equality of the output of action() in both
  # cases.
  self.assertEqual(reloaded_action.graph.seed, self._global_seed)
  action_output = policy.action(*action_inputs, seed=action_seed)
  # The seed= argument for the SavedModel action call was given at creation
  # of the PolicySaver.

  # This is the flat-signature function.
  reloaded_action_output_dict = reloaded_action(**function_action_input_dict)

  def match_dtype_shape(x, y, msg=None):
    self.assertEqual(x.shape, y.shape, msg=msg)
    self.assertEqual(x.dtype, y.dtype, msg=msg)

  # This is the non-flat function.
  if has_state:
    reloaded_action_output = reloaded.action(*action_inputs)
  else:
    # Try both cases: one with an empty policy_state and one with no
    # policy_state. Compare them.
    # NOTE(ebrevdo): The first call to .action() must be stored in
    # reloaded_action_output because this is the version being compared
    # later against the true action_output and the values will change after
    # the first call due to randomness.
    reloaded_action_output = reloaded.action(*action_inputs)
    reloaded_action_output_no_input_state = reloaded.action(action_inputs[0])
    # Even with a seed, multiple calls to action will get different values,
    # so here we just check the signature matches.
    tf.nest.map_structure(match_dtype_shape,
                          reloaded_action_output_no_input_state,
                          reloaded_action_output)

  action_output_dict = dict(
      (spec.name, value) for (spec, value) in zip(
          tf.nest.flatten(policy.policy_step_spec),
          tf.nest.flatten(action_output)))

  # Check output of the flattened signature call.
  action_output_dict = self.evaluate(action_output_dict)
  reloaded_action_output_dict = self.evaluate(reloaded_action_output_dict)
  self.assertAllEqual(action_output_dict.keys(),
                      reloaded_action_output_dict.keys())
  for k in action_output_dict:
    if seeded:
      self.assertAllClose(
          action_output_dict[k],
          reloaded_action_output_dict[k],
          msg='\nMismatched dict key: %s.' % k)
    else:
      match_dtype_shape(
          action_output_dict[k],
          reloaded_action_output_dict[k],
          msg='\nMismatched dict key: %s.' % k)

  # Check output of the proper structured call.
  action_output = self.evaluate(action_output)
  reloaded_action_output = self.evaluate(reloaded_action_output)
  # With non-signature functions, we can check that passing a seed does the
  # right thing the second time.
  if seeded:
    tf.nest.map_structure(self.assertAllClose, action_output,
                          reloaded_action_output)
  else:
    tf.nest.map_structure(match_dtype_shape, action_output,
                          reloaded_action_output)
def train():
  summary_interval = 1000
  summaries_flush_secs = 10
  num_eval_episodes = 5
  root_dir = '/tmp/tensorflow/logs/tfenv01'
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()
  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)

  # maybe py_metrics?
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  environment = TradeEnvironment()
  # utils.validate_py_environment(environment, episodes=5)

  # Environments
  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    train_env = tf_py_environment.TFPyEnvironment(environment)
    eval_env = tf_py_environment.TFPyEnvironment(environment)

    num_iterations = 50
    fc_layer_params = (512,)  # ~ (17 + 1001) / 2
    input_fc_layer_params = (50,)
    output_fc_layer_params = (20,)
    lstm_size = (30,)

    initial_collect_steps = 20
    collect_steps_per_iteration = 1
    collect_episodes_per_iteration = 1  # the same as above
    batch_size = 64
    replay_buffer_capacity = 10000
    train_sequence_length = 10

    gamma = 0.99  # check if 1.0 works as well
    target_update_tau = 0.05
    target_update_period = 5
    epsilon_greedy = 0.1
    gradient_clipping = None
    reward_scale_factor = 1.0
    learning_rate = 1e-2
    log_interval = 30
    eval_interval = 15

    # These names were referenced but never defined in the original snippet;
    # reasonable defaults are assumed here.
    train_steps_per_iteration = 1
    train_checkpoint_interval = 10
    policy_checkpoint_interval = 10
    rb_checkpoint_interval = 10
    eval_metrics_callback = None

    q_net = q_rnn_network.QRnnNetwork(
        # train_env.observation_spec(),
        train_env.time_step_spec().observation,
        train_env.action_spec(),
        input_fc_layer_params=input_fc_layer_params,
        lstm_size=lstm_size,
        output_fc_layer_params=output_fc_layer_params,
    )

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    tf_agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        train_step_counter=global_step,
    )

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_capacity,
    )

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Policy which does not allow some actions in certain states
    q_policy = FilteredQPolicy(
        tf_agent._time_step_spec,
        tf_agent._action_spec,
        q_network=tf_agent._q_network,
    )

    # Valid policy to pre-fill replay buffer
    initial_collect_policy = DummyTradePolicy(
        train_env.time_step_spec(),
        train_env.action_spec(),
    )

    print('Initial collecting...')
    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        train_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_steps,
    ).run()

    # Main agent's policy; greedy one
    policy = greedy_policy.GreedyPolicy(q_policy)
    # Policy used for evaluation, the same as above
    eval_policy = greedy_policy.GreedyPolicy(q_policy)
    tf_agent._policy = policy

    collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
        q_policy, epsilon=tf_agent._epsilon_greedy)
    # Patch random policy for epsilon greedy collect policy
    filtered_random_tf_policy = FilteredRandomTFPolicy(
        time_step_spec=policy.time_step_spec,
        action_spec=policy.action_spec,
    )
    collect_policy._random_policy = filtered_random_tf_policy
    tf_agent._collect_policy = collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        train_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration,
    ).run()

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1,
    ).prefetch(3)
    iterator = iter(dataset)
    experience, _ = next(iterator)
    loss_info = common.function(tf_agent.train)(experience=experience)

    # Checkpoints
    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
    )
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step,
    )
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer,
    )

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step,
          step_metrics=train_metrics[:2],
      ))
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # sess.run(train_summary_writer.init())
      # sess.run(eval_summary_writer.init())

      # Initialize the graph
      # tfe.Saver().restore()
      # train_checkpointer.initialize_or_restore()
      # rb_checkpointer.initialize_or_restore()
      # sess.run(iterator.initializer)
      common.initialize_uninitialized_variables(sess)
      sess.run(init_agent_op)

      print('Collecting initial experience...')
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([loss_info, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step,
      )

      # Train
      for i in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:  # was `log_inerval` (typo)
          print('step=%d, loss=%f' % (global_step_val, loss_info_value.loss))
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          print('%.3f steps/sec' % steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec},
          )
          timed_at_step = global_step_val
          time_acc = 0

        # Save checkpoints
        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)
        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        # Evaluate
        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              log=True,
              callback=eval_metrics_callback,
          )
  print('Done!')
def load_agents_and_create_videos(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=NUM_ITERATIONS,
    max_ep_steps=1000,
    train_sequence_length=1,
    # Params for QNetwork
    fc_layer_params=(128, 64, 32),
    # Params for QRnnNetwork
    input_fc_layer_params=(50,),
    lstm_size=(20,),
    output_fc_layer_params=(20,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=10000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    num_random_episodes=1,
    eval_interval=1000,
    # Params for checkpoints
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None,
    random_metrics_callback=None):

  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  random_dir = os.path.join(root_dir, 'random')

  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()

  # Match the environments used in training.
  tf_env = tf_py_environment.TFPyEnvironment(
      suite_gym.load(env_name, max_episode_steps=max_ep_steps))
  eval_py_env = suite_gym.load(env_name, max_episode_steps=max_ep_steps)
  eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

  if train_sequence_length != 1 and n_step_update != 1:
    raise NotImplementedError(
        'train_eval does not currently support n-step updates with stateful '
        'networks (i.e., RNNs)')

  if train_sequence_length > 1:
    q_net = q_rnn_network.QRnnNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        input_fc_layer_params=input_fc_layer_params,
        lstm_size=lstm_size,
        output_fc_layer_params=output_fc_layer_params)
  else:
    q_net = q_network.QNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)
    train_sequence_length = n_step_update

  # Match the agents used in training.
  tf_agent = dqn_agent.DqnAgent(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=global_step)
  tf_agent.initialize()

  train_metrics = [
      tf_metrics.NumberOfEpisodes(),
      tf_metrics.EnvironmentSteps(),
      tf_metrics.AverageReturnMetric(),
      tf_metrics.AverageEpisodeLengthMetric(),
  ]

  eval_policy = tf_agent.policy

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=tf_agent.collect_data_spec,
      batch_size=tf_env.batch_size,
      max_length=replay_buffer_capacity)

  train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir,
      agent=tf_agent,
      global_step=global_step,
      metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
  policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=eval_policy,
      global_step=global_step)
  rb_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
      max_to_keep=1,
      replay_buffer=replay_buffer)

  # Load the data from training.
  train_checkpointer.initialize_or_restore()
  rb_checkpointer.initialize_or_restore()

  # Define a random policy for comparison.
  random_policy = random_tf_policy.RandomTFPolicy(
      eval_tf_env.time_step_spec(), eval_tf_env.action_spec())

  # Make movies of the trained agent and a random agent.
  date_string = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
  trained_filename = 'trained-agent' + date_string
  create_policy_eval_video(
      eval_tf_env, eval_py_env, tf_agent.policy, trained_filename)
  random_filename = 'random-agent' + date_string
  create_policy_eval_video(
      eval_tf_env, eval_py_env, random_policy, random_filename)
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    train_sequence_length=1,
    # Params for QNetwork
    fc_layer_params=(100,),
    # Params for QRnnNetwork
    input_fc_layer_params=(50,),
    lstm_size=(20,),
    output_fc_layer_params=(20,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    if train_sequence_length != 1 and n_step_update != 1:
      raise NotImplementedError(
          'train_eval does not currently support n-step updates with stateful '
          'networks (i.e., RNNs)')

    if train_sequence_length > 1:
      q_net = q_rnn_network.QRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=input_fc_layer_params,
          lstm_size=lstm_size,
          output_fc_layer_params=output_fc_layer_params)
    else:
      q_net = q_network.QNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=fc_layer_params)
      train_sequence_length = n_step_update

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    if use_tf_functions:
      # To speed up collect use common.function.
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)
    return train_loss
def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    agent_name='sac',  # agent's name
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor_type='non-hierarchical',  # model net
    input_names=['camera', 'lidar'],  # names for inputs
    mask_names=['birdeye'],  # names for masks
    # takes a flat list of tensors and combines them
    preprocessing_combiner=tf.keras.layers.Add(),
    actor_lstm_size=(40,),  # lstm size for actor
    critic_lstm_size=(40,),  # lstm size for critic
    actor_output_fc_layers=(100,),  # lstm output
    critic_output_fc_layers=(100,),  # lstm output
    epsilon_greedy=0.1,  # exploration parameter for DQN
    q_learning_rate=1e-3,  # q learning rate for DQN
    ou_stddev=0.2,  # exploration parameter for DDPG
    ou_damping=0.15,  # exploration parameter for DDPG
    dqda_clipping=None,  # for DDPG
    exploration_noise_std=0.1,  # exploration parameter for TD3
    actor_update_period=2,  # for TD3
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    initial_model_train_steps=100000,  # initial model training
    batch_size=256,
    model_batch_size=32,  # model training batch size
    sequence_length=4,  # number of timesteps to train model
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,  # learning rate for model training
    td_errors_loss_fn=tf.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1):
  # Name of single observation channel, ['camera', 'lidar', 'birdeye']

  # Setup GPU
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpu_allow_growth:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  if gpu_memory_limit:
    for gpu in gpus:
      tf.config.experimental.set_virtual_device_configuration(
          gpu,
          [tf.config.experimental.VirtualDeviceConfiguration(
              memory_limit=gpu_memory_limit)])

  # Get train and eval directories
  root_dir = os.path.expanduser(root_dir)
  root_dir = os.path.join(root_dir, env_name, experiment_name)

  # Get summary writers
  summary_writer = tf.summary.create_file_writer(
      root_dir, flush_millis=summaries_flush_secs * 1000)
  summary_writer.set_as_default()

  # Eval metrics
  eval_metrics = [
      tf_metrics.AverageReturnMetric(
          name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(
          name='AverageEpisodeLengthEvalPolicy',
          buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()

  # Whether to record for summary
  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Create Carla environment
    if agent_name == 'latent_sac':
      py_env, eval_py_env = load_carla_env(
          env_name='carla-v0',
          obs_channels=input_names + mask_names,
          action_repeat=action_repeat)
    elif agent_name == 'dqn':
      py_env, eval_py_env = load_carla_env(
          env_name='carla-v0',
          discrete=True,
          obs_channels=input_names,
          action_repeat=action_repeat)
    else:
      py_env, eval_py_env = load_carla_env(
          env_name='carla-v0',
          obs_channels=input_names,
          action_repeat=action_repeat)

    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
    fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

    # Specs
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    ## Make tf agent
    if agent_name == 'latent_sac':
      # Get model network for latent sac
      if model_network_ctor_type == 'hierarchical':
        model_network_ctor = (
            sequential_latent_network.SequentialLatentModelHierarchical)
      elif model_network_ctor_type == 'non-hierarchical':
        model_network_ctor = (
            sequential_latent_network.SequentialLatentModelNonHierarchical)
      else:
        raise NotImplementedError
      model_net = model_network_ctor(input_names, input_names + mask_names)

      # Get the latent spec
      latent_size = model_net.latent_size
      latent_observation_spec = tensor_spec.TensorSpec(
          (latent_size,), dtype=tf.float32)
      latent_time_step_spec = ts.time_step_spec(
          observation_spec=latent_observation_spec)

      # Get actor and critic net
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          latent_observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=normal_projection_net)
      critic_net = critic_network.CriticNetwork(
          (latent_observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers)

      # Build the inner SAC agent based on latent space
      inner_agent = sac_agent.SacAgent(
          latent_time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step)
      inner_agent.initialize()

      # Build the latent sac agent
      tf_agent = latent_sac_agent.LatentSACAgent(
          time_step_spec,
          action_spec,
          inner_agent=inner_agent,
          model_network=model_net,
          model_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=model_learning_rate),
          model_batch_size=model_batch_size,
          num_images_per_summary=num_images_per_summary,
          sequence_length=sequence_length,
          gradient_clipping=gradient_clipping,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          fps=fps)

    else:
      # Set up preprocessing layers for dictionary observation inputs
      preprocessing_layers = collections.OrderedDict()
      for name in input_names:
        preprocessing_layers[name] = Preprocessing_Layer(32, 256)
      if len(input_names) < 2:
        preprocessing_combiner = None

      if agent_name == 'dqn':
        q_rnn_net = q_rnn_network.QRnnNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            input_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers)

        tf_agent = dqn_agent.DqnAgent(
            time_step_spec,
            action_spec,
            q_network=q_rnn_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=1,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=q_learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

      elif agent_name == 'ddpg' or agent_name == 'td3':
        actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
            (observation_spec, action_spec),
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers)

        if agent_name == 'ddpg':
          tf_agent = ddpg_agent.DdpgAgent(
              time_step_spec,
              action_spec,
              actor_network=actor_rnn_net,
              critic_network=critic_rnn_net,
              actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=actor_learning_rate),
              critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=critic_learning_rate),
              ou_stddev=ou_stddev,
              ou_damping=ou_damping,
              target_update_tau=target_update_tau,
              target_update_period=target_update_period,
              dqda_clipping=dqda_clipping,
              td_errors_loss_fn=None,
              gamma=gamma,
              reward_scale_factor=reward_scale_factor,
              gradient_clipping=gradient_clipping,
              debug_summaries=debug_summaries,
              summarize_grads_and_vars=summarize_grads_and_vars)
        elif agent_name == 'td3':
          tf_agent = td3_agent.Td3Agent(
              time_step_spec,
              action_spec,
              actor_network=actor_rnn_net,
              critic_network=critic_rnn_net,
              actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=actor_learning_rate),
              critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=critic_learning_rate),
              exploration_noise_std=exploration_noise_std,
              target_update_tau=target_update_tau,
              target_update_period=target_update_period,
              actor_update_period=actor_update_period,
              dqda_clipping=dqda_clipping,
              td_errors_loss_fn=None,
              gamma=gamma,
              reward_scale_factor=reward_scale_factor,
              gradient_clipping=gradient_clipping,
              debug_summaries=debug_summaries,
              summarize_grads_and_vars=summarize_grads_and_vars,
              train_step_counter=global_step)

      elif agent_name == 'sac':
        actor_distribution_rnn_net = (
            actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                observation_spec,
                action_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=actor_fc_layers,
                lstm_size=actor_lstm_size,
                output_fc_layer_params=actor_output_fc_layers,
                continuous_projection_net=normal_projection_net))

        critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
            (observation_spec, action_spec),
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_distribution_rnn_net,
            critic_network=critic_rnn_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            # make critic loss dimension compatible
            td_errors_loss_fn=tf.math.squared_difference,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
      else:
        raise NotImplementedError

    tf_agent.initialize()

    # Get replay buffer
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,  # No parallel environments
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    # Train metrics
    env_steps = tf_metrics.EnvironmentSteps()
    average_return = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        env_steps,
        average_return,
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]

    # Get policies
    # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    eval_policy = tf_agent.policy
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        time_step_spec, action_spec)
    collect_policy = tf_agent.collect_policy

    # Checkpointers
    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'train'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        max_to_keep=2)
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step,
        max_to_keep=2)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    # Collect driver
    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    # Optimize the performance by using tf functions
    initial_collect_driver.run = common.function(initial_collect_driver.run)
    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    if agent_name == 'latent_sac':
      compute_summaries(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          train_step=global_step,
          summary_writer=summary_writer,
          num_episodes=1,
          num_episodes_to_render=1,
          model_net=model_net,
          fps=10,
          image_keys=input_names + mask_names)
    else:
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=1,
          train_step=env_steps.result(),
          summary_writer=summary_writer,
          summary_prefix='Eval',
      )
      metric_utils.log_metrics(eval_metrics)

    # Dataset generates trajectories with shape [Bxslx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    # Get train step
    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    train_step = common.function(train_step)

    if agent_name == 'latent_sac':
      def train_model_step():
        experience, _ = next(iterator)
        return tf_agent.train_model(experience)

      train_model_step = common.function(train_model_step)

    # Training initializations
    time_step = None
    time_acc = 0
    env_steps_before = env_steps.result().numpy()

    # Start training
    for iteration in range(num_iterations):
      start_time = time.time()

      if agent_name == 'latent_sac' and iteration < initial_model_train_steps:
        train_model_step()
      else:
        # Run collect
        time_step, _ = collect_driver.run(time_step=time_step)
        # Train an iteration
        for _ in range(train_steps_per_iteration):
          train_step()
      time_acc += time.time() - start_time

      # Log training information
      if global_step.numpy() % log_interval == 0:
        logging.info('env steps = %d, average return = %f',
                     env_steps.result(), average_return.result())
        env_steps_per_sec = (
            env_steps.result().numpy() - env_steps_before) / time_acc
        logging.info('%.3f env steps/sec', env_steps_per_sec)
        tf.summary.scalar(
            name='env_steps_per_sec',
            data=env_steps_per_sec,
            step=env_steps.result())
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

      # Get training metrics
      for train_metric in train_metrics:
        train_metric.tf_summaries(train_step=env_steps.result())

      # Evaluation
      if global_step.numpy() % eval_interval == 0:
        # Log evaluation metrics
        if agent_name == 'latent_sac':
          compute_summaries(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              train_step=global_step,
              summary_writer=summary_writer,
              num_episodes=num_eval_episodes,
              num_episodes_to_render=num_images_per_summary,
              model_net=model_net,
              fps=10,
              image_keys=input_names + mask_names)
        else:
          results = metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=env_steps.result(),
              summary_writer=summary_writer,
              summary_prefix='Eval',
          )
          metric_utils.log_metrics(eval_metrics)

      # Save checkpoints
      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)
      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
def testSaveAction(self):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')
  q_network = q_rnn_network.QRnnNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=q_network)
  action_seed = 98723
  saver = policy_saver.PolicySaver(policy, batch_size=None, seed=action_seed)
  path = os.path.join(tf.compat.v1.test.get_temp_dir(), 'save_model_action')
  saver.save(path)

  reloaded = tf.compat.v2.saved_model.load(path)
  self.assertIn('action', reloaded.signatures)
  reloaded_action = reloaded.signatures['action']
  self._compare_input_output_specs(
      reloaded_action,
      expected_input_specs=(self._time_step_spec, policy.policy_state_spec),
      expected_output_spec=policy.policy_step_spec,
      batch_input=True)

  batch_size = 3
  action_inputs = tensor_spec.sample_spec_nest(
      (self._time_step_spec, policy.policy_state_spec),
      outer_dims=(batch_size,),
      seed=4)
  function_action_input_dict = dict(
      (spec.name, value) for (spec, value) in zip(
          tf.nest.flatten((self._time_step_spec, policy.policy_state_spec)),
          tf.nest.flatten(action_inputs)))

  # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded
  # model are equal, which in addition to seeding the call to action() and
  # PolicySaver helps ensure equality of the output of action() in both
  # cases.
  self.assertEqual(reloaded_action.graph.seed, self._global_seed)
  action_output = policy.action(*action_inputs, seed=action_seed)
  # The seed= argument for the SavedModel action call was given at creation
  # of the PolicySaver.
  reloaded_action_output_dict = reloaded_action(**function_action_input_dict)

  action_output_dict = dict(
      (spec.name, value) for (spec, value) in zip(
          tf.nest.flatten(policy.policy_step_spec),
          tf.nest.flatten(action_output)))

  action_output_dict = self.evaluate(action_output_dict)
  reloaded_action_output_dict = self.evaluate(reloaded_action_output_dict)
  self.assertAllEqual(action_output_dict.keys(),
                      reloaded_action_output_dict.keys())
  for k in action_output_dict:
    self.assertAllClose(
        action_output_dict[k],
        reloaded_action_output_dict[k],
        msg='\nMismatched dict key: %s.' % k)
env_name = 'Breakout-v0'
train_py_env = BreakoutEnv(suite_gym.load(env_name))
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_py_env = BreakoutEnv(suite_gym.load(env_name))
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

conv_layer_params = [(32, 8, 4), (64, 4, 2)]
input_fc_layer_params = (256,)
output_fc_layer_params = (256,)
lstm_size = [256]
learning_rate = 1e-3  # Not defined in the original snippet; assumed here.

q_net = q_rnn_network.QRnnNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    input_fc_layer_params=input_fc_layer_params,
    output_fc_layer_params=output_fc_layer_params,
    conv_layer_params=conv_layer_params,
    lstm_size=lstm_size)

optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
loss_func = common.element_wise_huber_loss
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=loss_func,
    train_step_counter=train_step_counter)
agent.initialize()
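# For context, a minimal sketch (not in the original snippet) of driving the
# agent above with a replay buffer, mirroring the train_eval pattern shown
# earlier in this collection. Buffer capacity, batch size, window length, and
# iteration count are illustrative assumptions, and the buffer is assumed to
# be pre-filled (e.g. by an initial collect loop) before sampling starts.
from tf_agents.drivers import dynamic_step_driver
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=10000)  # illustrative capacity

collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=1)

# An RNN q-network trains on windows longer than one transition, so the LSTM
# gets temporal context; num_steps here sets the window length.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=64, num_steps=10 + 1).prefetch(3)
iterator = iter(dataset)

time_step = None
policy_state = agent.collect_policy.get_initial_state(train_env.batch_size)
for _ in range(1000):  # illustrative iteration count
  time_step, policy_state = collect_driver.run(
      time_step=time_step, policy_state=policy_state)
  experience, _ = next(iterator)
  train_loss = agent.train(experience).loss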
def testSaveAction(self, seeded, has_state, distribution_net, has_input_fn_and_spec): with tf.compat.v1.Graph().as_default(): tf.compat.v1.set_random_seed(self._global_seed) with tf.compat.v1.Session().as_default(): global_step = common.create_variable('train_step', initial_value=0) if distribution_net: network = actor_distribution_network.ActorDistributionNetwork( self._time_step_spec.observation, self._action_spec) policy = actor_policy.ActorPolicy( time_step_spec=self._time_step_spec, action_spec=self._action_spec, actor_network=network) else: if has_state: network = q_rnn_network.QRnnNetwork( input_tensor_spec=self._time_step_spec.observation, action_spec=self._action_spec, lstm_size=(40,)) else: network = q_network.QNetwork( input_tensor_spec=self._time_step_spec.observation, action_spec=self._action_spec) policy = q_policy.QPolicy( time_step_spec=self._time_step_spec, action_spec=self._action_spec, q_network=network) action_seed = 98723 batch_size = 3 action_inputs = tensor_spec.sample_spec_nest( (self._time_step_spec, policy.policy_state_spec), outer_dims=(batch_size,), seed=4) action_input_values = self.evaluate(action_inputs) action_input_tensors = tf.nest.map_structure(tf.convert_to_tensor, action_input_values) action_output = policy.action(*action_input_tensors, seed=action_seed) distribution_output = policy.distribution(*action_input_tensors) self.assertIsInstance( distribution_output.action, tfp.distributions.Distribution) self.evaluate(tf.compat.v1.global_variables_initializer()) action_output_dict = collections.OrderedDict( ((spec.name, value) for (spec, value) in zip( tf.nest.flatten(policy.policy_step_spec), tf.nest.flatten(action_output)))) # Check output of the flattened signature call. (action_output_value, action_output_dict) = self.evaluate( (action_output, action_output_dict)) distribution_output_value = self.evaluate(_sample_from_distributions( distribution_output)) input_fn_and_spec = None if has_input_fn_and_spec: input_fn_and_spec = (_convert_string_vector_to_action_input, tf.TensorSpec((7,), tf.string, name='example')) saver = policy_saver.PolicySaver( policy, batch_size=None, use_nest_path_signatures=False, seed=action_seed, input_fn_and_spec=input_fn_and_spec, train_step=global_step) path = os.path.join(self.get_temp_dir(), 'save_model_action') saver.save(path) with tf.compat.v1.Graph().as_default(): tf.compat.v1.set_random_seed(self._global_seed) with tf.compat.v1.Session().as_default(): reloaded = tf.compat.v2.saved_model.load(path) self.assertIn('action', reloaded.signatures) reloaded_action = reloaded.signatures['action'] if has_input_fn_and_spec: self._compare_input_output_specs( reloaded_action, expected_input_specs=input_fn_and_spec[1], expected_output_spec=policy.policy_step_spec, batch_input=True) else: self._compare_input_output_specs( reloaded_action, expected_input_specs=(self._time_step_spec, policy.policy_state_spec), expected_output_spec=policy.policy_step_spec, batch_input=True) # Reload action_input_values as tensors in the new graph. 
      action_input_tensors = tf.nest.map_structure(tf.convert_to_tensor,
                                                   action_input_values)
      action_input_spec = (self._time_step_spec, policy.policy_state_spec)
      function_action_input_dict = collections.OrderedDict(
          (spec.name, value) for (spec, value) in zip(
              tf.nest.flatten(action_input_spec),
              tf.nest.flatten(action_input_tensors)))

      # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded
      # model are equal, which in addition to seeding the call to action()
      # and PolicySaver helps ensure equality of the output of action() in
      # both cases.
      self.assertEqual(reloaded_action.graph.seed, self._global_seed)

      # The seed= argument for the SavedModel action call was given at
      # creation of the PolicySaver.
      if has_input_fn_and_spec:
        action_string_vector = _convert_action_input_to_string_vector(
            action_input_tensors)
        action_string_vector_values = self.evaluate(action_string_vector)
        reloaded_action_output_dict = reloaded_action(action_string_vector)
        reloaded_action_output = reloaded.action(action_string_vector)
        reloaded_distribution_output = reloaded.distribution(
            action_string_vector)
        self.assertIsInstance(reloaded_distribution_output.action,
                              tfp.distributions.Distribution)
      else:
        # This is the flat-signature function.
        reloaded_action_output_dict = reloaded_action(
            **function_action_input_dict)
        # This is the non-flat function.
        reloaded_action_output = reloaded.action(*action_input_tensors)
        reloaded_distribution_output = reloaded.distribution(
            *action_input_tensors)
        self.assertIsInstance(reloaded_distribution_output.action,
                              tfp.distributions.Distribution)

        if not has_state:
          # Try both cases: one with an empty policy_state and one with no
          # policy_state.  Compare them.

          # NOTE(ebrevdo): The first call to .action() must be stored in
          # reloaded_action_output because this is the version being
          # compared later against the true action_output, and the values
          # will change after the first call due to randomness.
          reloaded_action_output_no_input_state = reloaded.action(
              action_input_tensors[0])
          reloaded_distribution_output_no_input_state = reloaded.distribution(
              action_input_tensors[0])
          # Even with a seed, multiple calls to action will get different
          # values, so here we just check that the signature matches.
          self.assertIsInstance(
              reloaded_distribution_output_no_input_state.action,
              tfp.distributions.Distribution)
          tf.nest.map_structure(self.match_dtype_shape,
                                reloaded_action_output_no_input_state,
                                reloaded_action_output)
          tf.nest.map_structure(
              self.match_dtype_shape,
              _sample_from_distributions(
                  reloaded_distribution_output_no_input_state),
              _sample_from_distributions(reloaded_distribution_output))

      self.evaluate(tf.compat.v1.global_variables_initializer())
      (reloaded_action_output_dict,
       reloaded_action_output_value) = self.evaluate(
           (reloaded_action_output_dict, reloaded_action_output))
      reloaded_distribution_output_value = self.evaluate(
          _sample_from_distributions(reloaded_distribution_output))

      self.assertAllEqual(action_output_dict.keys(),
                          reloaded_action_output_dict.keys())
      for k in action_output_dict:
        if seeded:
          self.assertAllClose(
              action_output_dict[k],
              reloaded_action_output_dict[k],
              msg='\nMismatched dict key: %s.' % k)
        else:
          self.match_dtype_shape(
              action_output_dict[k],
              reloaded_action_output_dict[k],
              msg='\nMismatched dict key: %s.' % k)

      # With non-signature functions, we can check that passing a seed does
      # the right thing the second time.
      if seeded:
        tf.nest.map_structure(self.assertAllClose, action_output_value,
                              reloaded_action_output_value)
      else:
        tf.nest.map_structure(self.match_dtype_shape, action_output_value,
                              reloaded_action_output_value)

      tf.nest.map_structure(self.assertAllClose, distribution_output_value,
                            reloaded_distribution_output_value)

  ## TFLite tests.

  # The converter must run outside of a TF1 graph context, even in
  # eager mode, to ensure the TF2 path is being executed.  Only
  # works in TF2.
  if tf.compat.v1.executing_eagerly_outside_functions():
    tflite_converter = tf.lite.TFLiteConverter.from_saved_model(
        path, signature_keys=['action'])
    tflite_converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        # TODO(b/111309333): Remove this when `has_input_fn_and_spec`
        # is `False` once TFLite has native support for RNG ops, atan, etc.
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_serialized_model = tflite_converter.convert()

    tflite_interpreter = tf.lite.Interpreter(
        model_content=tflite_serialized_model)
    tflite_runner = tflite_interpreter.get_signature_runner('action')
    tflite_signature = tflite_interpreter.get_signature_list()['action']

    if has_input_fn_and_spec:
      tflite_action_input_dict = {
          'example': action_string_vector_values,
      }
    else:
      tflite_action_input_dict = collections.OrderedDict(
          (spec.name, value) for (spec, value) in zip(
              tf.nest.flatten(action_input_spec),
              tf.nest.flatten(action_input_values)))

    self.assertEqual(
        set(tflite_signature['inputs']), set(tflite_action_input_dict))
    self.assertEqual(
        set(tflite_signature['outputs']), set(action_output_dict))

    tflite_output = tflite_runner(**tflite_action_input_dict)
    self.assertAllClose(tflite_output, action_output_dict)
def testSaveAction(self, seeded, has_state, distribution_net,
                   has_input_fn_and_spec):
  with tf.compat.v1.Graph().as_default():
    tf.compat.v1.set_random_seed(self._global_seed)
    with tf.compat.v1.Session().as_default():
      global_step = common.create_variable('train_step', initial_value=0)

      if distribution_net:
        network = actor_distribution_network.ActorDistributionNetwork(
            self._time_step_spec.observation, self._action_spec)
        policy = actor_policy.ActorPolicy(
            time_step_spec=self._time_step_spec,
            action_spec=self._action_spec,
            actor_network=network)
      else:
        if has_state:
          network = q_rnn_network.QRnnNetwork(
              input_tensor_spec=self._time_step_spec.observation,
              action_spec=self._action_spec)
        else:
          network = q_network.QNetwork(
              input_tensor_spec=self._time_step_spec.observation,
              action_spec=self._action_spec)
        policy = q_policy.QPolicy(
            time_step_spec=self._time_step_spec,
            action_spec=self._action_spec,
            q_network=network)

      action_seed = 98723
      batch_size = 3
      action_inputs = tensor_spec.sample_spec_nest(
          (self._time_step_spec, policy.policy_state_spec),
          outer_dims=(batch_size,),
          seed=4)
      action_input_values = self.evaluate(action_inputs)
      action_input_tensors = tf.nest.map_structure(tf.convert_to_tensor,
                                                   action_input_values)

      action_output = policy.action(*action_input_tensors, seed=action_seed)

      self.evaluate(tf.compat.v1.global_variables_initializer())

      action_output_dict = dict(
          (spec.name, value) for (spec, value) in zip(
              tf.nest.flatten(policy.policy_step_spec),
              tf.nest.flatten(action_output)))

      # Check output of the flattened signature call.
      (action_output_value, action_output_dict) = self.evaluate(
          (action_output, action_output_dict))

      input_fn_and_spec = None
      if has_input_fn_and_spec:
        input_fn_and_spec = (self._convert_string_vector_to_action_input,
                             tf.TensorSpec((7,), tf.string, name='example'))

      saver = policy_saver.PolicySaver(
          policy,
          batch_size=None,
          use_nest_path_signatures=False,
          seed=action_seed,
          input_fn_and_spec=input_fn_and_spec,
          train_step=global_step)
      path = os.path.join(self.get_temp_dir(), 'save_model_action')
      saver.save(path)

  with tf.compat.v1.Graph().as_default():
    tf.compat.v1.set_random_seed(self._global_seed)
    with tf.compat.v1.Session().as_default():
      reloaded = tf.compat.v2.saved_model.load(path)

      self.assertIn('action', reloaded.signatures)
      reloaded_action = reloaded.signatures['action']
      if has_input_fn_and_spec:
        self._compare_input_output_specs(
            reloaded_action,
            expected_input_specs=input_fn_and_spec[1],
            expected_output_spec=policy.policy_step_spec,
            batch_input=True)
      else:
        self._compare_input_output_specs(
            reloaded_action,
            expected_input_specs=(self._time_step_spec,
                                  policy.policy_state_spec),
            expected_output_spec=policy.policy_step_spec,
            batch_input=True)

      # Reload action_input_values as tensors in the new graph.
      action_input_tensors = tf.nest.map_structure(tf.convert_to_tensor,
                                                   action_input_values)
      action_input_spec = (self._time_step_spec, policy.policy_state_spec)
      function_action_input_dict = dict(
          (spec.name, value) for (spec, value) in zip(
              tf.nest.flatten(action_input_spec),
              tf.nest.flatten(action_input_tensors)))

      # NOTE(ebrevdo): The graph-level seeds for the policy and the reloaded
      # model are equal, which in addition to seeding the call to action()
      # and PolicySaver helps ensure equality of the output of action() in
      # both cases.
      self.assertEqual(reloaded_action.graph.seed, self._global_seed)

      def match_dtype_shape(x, y, msg=None):
        self.assertEqual(x.shape, y.shape, msg=msg)
        self.assertEqual(x.dtype, y.dtype, msg=msg)

      # The seed= argument for the SavedModel action call was given at
      # creation of the PolicySaver.
      if has_input_fn_and_spec:
        action_string_vector = self._convert_action_input_to_string_vector(
            action_input_tensors)
        reloaded_action_output_dict = reloaded_action(action_string_vector)
        reloaded_action_output = reloaded.action(action_string_vector)
      else:
        # This is the flat-signature function.
        reloaded_action_output_dict = reloaded_action(
            **function_action_input_dict)
        # This is the non-flat function.
        reloaded_action_output = reloaded.action(*action_input_tensors)

        if not has_state:
          # Try both cases: one with an empty policy_state and one with no
          # policy_state.  Compare them.

          # NOTE(ebrevdo): The first call to .action() must be stored in
          # reloaded_action_output because this is the version being
          # compared later against the true action_output, and the values
          # will change after the first call due to randomness.
          reloaded_action_output_no_input_state = reloaded.action(
              action_input_tensors[0])
          # Even with a seed, multiple calls to action will get different
          # values, so here we just check that the signature matches.
          tf.nest.map_structure(match_dtype_shape,
                                reloaded_action_output_no_input_state,
                                reloaded_action_output)

      self.evaluate(tf.compat.v1.global_variables_initializer())
      (reloaded_action_output_dict,
       reloaded_action_output_value) = self.evaluate(
           (reloaded_action_output_dict, reloaded_action_output))

      self.assertAllEqual(action_output_dict.keys(),
                          reloaded_action_output_dict.keys())
      for k in action_output_dict:
        if seeded:
          self.assertAllClose(
              action_output_dict[k],
              reloaded_action_output_dict[k],
              msg='\nMismatched dict key: %s.' % k)
        else:
          match_dtype_shape(
              action_output_dict[k],
              reloaded_action_output_dict[k],
              msg='\nMismatched dict key: %s.' % k)

      # With non-signature functions, we can check that passing a seed does
      # the right thing the second time.
      if seeded:
        tf.nest.map_structure(self.assertAllClose, action_output_value,
                              reloaded_action_output_value)
      else:
        tf.nest.map_structure(match_dtype_shape, action_output_value,
                              reloaded_action_output_value)
def train():
  global VERBOSE

  environment = TradeEnvironment()
  # utils.validate_py_environment(environment, episodes=5)

  # Environments.
  train_env = tf_py_environment.TFPyEnvironment(environment)
  eval_env = tf_py_environment.TFPyEnvironment(environment)

  num_iterations = 50

  fc_layer_params = (512,)  # ~ (17 + 1001) / 2
  input_fc_layer_params = (17,)
  output_fc_layer_params = (20,)
  lstm_size = (17,)

  initial_collect_steps = 20
  collect_steps_per_iteration = 1
  batch_size = 64
  replay_buffer_capacity = 10000

  gamma = 0.99  # Check whether 1.0 will work here.
  target_update_tau = 0.05
  target_update_period = 5
  epsilon_greedy = 0.1
  reward_scale_factor = 1.0
  learning_rate = 1e-2

  log_interval = 30
  num_eval_episodes = 5
  eval_interval = 15

  # q_net = q_network.QNetwork(
  #     train_env.observation_spec(),
  #     train_env.action_spec(),
  #     fc_layer_params=fc_layer_params,
  # )
  q_net = q_rnn_network.QRnnNetwork(
      train_env.observation_spec(),
      train_env.action_spec(),
      input_fc_layer_params=input_fc_layer_params,
      lstm_size=lstm_size,
      output_fc_layer_params=output_fc_layer_params,
  )

  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
  train_step_counter = tf.compat.v2.Variable(0)

  tf_agent = dqn_agent.DqnAgent(
      train_env.time_step_spec(),
      train_env.action_spec(),
      q_network=q_net,
      optimizer=optimizer,
      epsilon_greedy=epsilon_greedy,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      td_errors_loss_fn=common.element_wise_squared_loss,
      train_step_counter=train_step_counter,
      gradient_clipping=None,
      debug_summaries=False,
      summarize_grads_and_vars=False,
  )

  q_policy = FilteredQPolicy(
      tf_agent._time_step_spec,
      tf_agent._action_spec,
      q_network=tf_agent._q_network,
  )

  # Valid policy used to pre-fill the replay buffer.
  dummy_policy = DummyTradePolicy(
      train_env.time_step_spec(),
      train_env.action_spec(),
  )

  # Main agent's policy: the greedy one.
  policy = greedy_policy.GreedyPolicy(q_policy)
  filtered_random_py_policy = FilteredRandomPyPolicy(
      time_step_spec=policy.time_step_spec,
      action_spec=policy.action_spec,
  )
  filtered_random_tf_policy = tf_py_policy.TFPyPolicy(
      filtered_random_py_policy)
  collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
      q_policy, epsilon=tf_agent._epsilon_greedy)
  # Patch the random policy used by the epsilon-greedy collect policy.
  filtered_random_tf_policy = FilteredRandomTFPolicy(
      time_step_spec=policy.time_step_spec,
      action_spec=policy.action_spec,
  )
  collect_policy._random_policy = filtered_random_tf_policy

  tf_agent._policy = policy
  tf_agent._collect_policy = collect_policy
  tf_agent.initialize()

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=tf_agent.collect_data_spec,
      batch_size=train_env.batch_size,
      max_length=replay_buffer_capacity,
  )

  print('Pre-filling replay buffer in {} steps'.format(initial_collect_steps))
  for _ in range(initial_collect_steps):
    traj = collect_step(train_env, dummy_policy)
    replay_buffer.add_batch(traj)

  dataset = replay_buffer.as_dataset(
      num_parallel_calls=3,
      sample_batch_size=batch_size,
      num_steps=2,
  ).prefetch(3)
  iterator = iter(dataset)

  # Train.
  tf_agent.train = common.function(tf_agent.train)
  tf_agent.train_step_counter.assign(0)

  avg_return = compute_avg_return(eval_env, tf_agent.policy,
                                  num_eval_episodes)
  returns = [avg_return]

  print('Starting iterations...')
  for i in range(num_iterations):
    # Fill the replay buffer.
    for _ in range(collect_steps_per_iteration):
      traj = collect_step(train_env, tf_agent.collect_policy)
      # Add the trajectory to the replay buffer.
      replay_buffer.add_batch(traj)
    experience, _ = next(iterator)
    train_loss = tf_agent.train(experience)

    step = tf_agent.train_step_counter.numpy()
    if step % log_interval == 0:
      print('step = {0}: loss = {1}'.format(step, train_loss.loss))
    if step % eval_interval == 0:
      avg_return = compute_avg_return(eval_env, tf_agent.policy,
                                      num_eval_episodes)
      print('step = {0}: avg return = {1}'.format(step, avg_return))
      returns.append(avg_return)

  print('Finished {} iterations!'.format(num_iterations))

  print('Playing with resulting policy')
  VERBOSE = True
  r = compute_avg_return(eval_env, tf_agent.policy, 1)
  print('Result: {}'.format(r))

  steps = range(0, num_iterations + 1, eval_interval)
  # merged = tf.summary.merge_all()
  # writer = tf.summary.FileWriter(FLAGS.log_dir)
  # writer.close()
  print('Check out chart for learning')
  plt.plot(steps, returns)
  plt.ylabel('Average Return')
  plt.xlabel('Step')
  plt.ylim(top=1000)
  plt.show()
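After training, persisting the policy is a natural next step. Below is a minimal sketch using PolicySaver; the output directory name is an assumption, and `tf_agent.policy` here is the patched greedy policy from the train() function above.

# A minimal sketch of persisting the trained recurrent policy.
from tf_agents.policies import policy_saver

saver = policy_saver.PolicySaver(tf_agent.policy, batch_size=None)
saver.save('saved_trade_policy')  # Hypothetical output directory.

# The SavedModel can later be restored without the original Python classes:
# reloaded = tf.compat.v2.saved_model.load('saved_trade_policy')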
def train_eval(
    root_dir,
    env_name='MaskedCartPole-v0',
    num_iterations=100000,
    input_fc_layer_params=(50,),
    lstm_size=(20,),
    output_fc_layer_params=(20,),
    train_sequence_length=10,
    # Params for collect
    initial_collect_steps=50,
    collect_episodes_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=10,
    batch_size=128,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=100,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    eval_py_env = suite_gym.load(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    q_net = q_rnn_network.QRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=input_fc_layer_params,
        lstm_size=lstm_size,
        output_fc_layer_params=output_fc_layer_params)

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_steps).run()

    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    # Need an extra step to generate transitions of train_sequence_length.
    # Dataset generates trajectories with shape [BxTx...].
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    experience, _ = iterator.get_next()
    loss_info = common.function(tf_agent.train)(experience=experience)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)

      logging.info('Collecting initial experience.')
      sess.run(initial_collect_op)

      # Compute evaluation metrics.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([loss_info, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        # Train/collect/eval.
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              log=True,
              callback=eval_metrics_callback,
          )
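Below is a hedged sketch of how train_eval might be wired into a main() entry point in the usual absl style; the flag name and default root_dir are assumptions, not part of the original function.

# A minimal entry-point sketch for the train_eval function above.
from absl import app, flags

flags.DEFINE_string('root_dir', '/tmp/masked_cartpole_rnn',
                    'Root directory for checkpoints and summaries.')
FLAGS = flags.FLAGS

def main(_):
  tf.compat.v1.enable_resource_variables()
  train_eval(FLAGS.root_dir)

if __name__ == '__main__':
  app.run(main)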
def testSaveGetInitialState(self):
  network = q_rnn_network.QRnnNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec,
      lstm_size=(40,))

  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  train_step = common.create_variable('train_step', initial_value=0)
  saver_nobatch = policy_saver.PolicySaver(
      policy,
      train_step=train_step,
      batch_size=None,
      use_nest_path_signatures=False)
  path = os.path.join(self.get_temp_dir(),
                      'save_model_initial_state_nobatch')

  self.evaluate(tf.compat.v1.global_variables_initializer())

  with self.cached_session():
    saver_nobatch.save(path)
    reloaded_nobatch = tf.compat.v2.saved_model.load(path)
    self.evaluate(
        tf.compat.v1.initializers.variables(
            reloaded_nobatch.model_variables))

  self.assertIn('get_initial_state', reloaded_nobatch.signatures)
  reloaded_get_initial_state = (
      reloaded_nobatch.signatures['get_initial_state'])
  self._compare_input_output_specs(
      reloaded_get_initial_state,
      expected_input_specs=(tf.TensorSpec(
          dtype=tf.int32, shape=(), name='batch_size'),),
      expected_output_spec=policy.policy_state_spec,
      batch_input=False,
      batch_size=None)

  initial_state = policy.get_initial_state(batch_size=3)
  initial_state = self.evaluate(initial_state)

  reloaded_nobatch_initial_state = reloaded_nobatch.get_initial_state(
      batch_size=3)
  reloaded_nobatch_initial_state = self.evaluate(
      reloaded_nobatch_initial_state)
  tf.nest.map_structure(self.assertAllClose, initial_state,
                        reloaded_nobatch_initial_state)

  saver_batch = policy_saver.PolicySaver(
      policy,
      train_step=train_step,
      batch_size=3,
      use_nest_path_signatures=False)
  path = os.path.join(self.get_temp_dir(), 'save_model_initial_state_batch')
  with self.cached_session():
    saver_batch.save(path)
    reloaded_batch = tf.compat.v2.saved_model.load(path)
    self.evaluate(
        tf.compat.v1.initializers.variables(reloaded_batch.model_variables))

  self.assertIn('get_initial_state', reloaded_batch.signatures)
  reloaded_get_initial_state = (
      reloaded_batch.signatures['get_initial_state'])
  self._compare_input_output_specs(
      reloaded_get_initial_state,
      expected_input_specs=(),
      expected_output_spec=policy.policy_state_spec,
      batch_input=False,
      batch_size=3)

  reloaded_batch_initial_state = reloaded_batch.get_initial_state()
  reloaded_batch_initial_state = self.evaluate(reloaded_batch_initial_state)
  tf.nest.map_structure(self.assertAllClose, initial_state,
                        reloaded_batch_initial_state)
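For a recurrent policy like the one above, get_initial_state matters at inference time because the LSTM state must be threaded through successive action() calls. A minimal eager-mode sketch follows, assuming a reloaded SavedModel `reloaded` (as produced by PolicySaver above) and a batch-1 TFPyEnvironment `env`:

# A minimal inference sketch: the state returned by one action() call is
# fed back into the next, so the LSTM retains context across steps.
policy_state = reloaded.get_initial_state(batch_size=1)
time_step = env.reset()
while not time_step.is_last():
  policy_step = reloaded.action(time_step, policy_state)
  policy_state = policy_step.state  # Carry the LSTM state forward.
  time_step = env.step(policy_step.action)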