def testTrainWithRnn(self):
  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      self._obs_spec,
      self._action_spec,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      conv_layer_params=None,
      lstm_size=(40,),
  )
  critic_net = critic_rnn_network.CriticRnnNetwork(
      (self._obs_spec, self._action_spec),
      observation_fc_layer_params=(16,),
      action_fc_layer_params=(16,),
      joint_fc_layer_params=(16,),
      lstm_size=(16,),
      output_fc_layer_params=None,
  )
  counter = common.create_variable('test_train_counter')
  optimizer_fn = tf.compat.v1.train.AdamOptimizer
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_net,
      actor_network=actor_net,
      actor_optimizer=optimizer_fn(1e-3),
      critic_optimizer=optimizer_fn(1e-3),
      alpha_optimizer=optimizer_fn(1e-3),
      train_step_counter=counter,
  )
  batch_size = 5
  observations = tf.constant(
      [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
  actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)
  time_steps = ts.TimeStep(
      step_type=tf.constant([[1] * 3] * batch_size, dtype=tf.int32),
      reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      observation=[observations])
  experience = trajectory.Trajectory(
      time_steps.step_type, [observations], actions, (),
      time_steps.step_type, time_steps.reward, time_steps.discount)

  # Force variable creation.
  agent.policy.variables()
  if tf.executing_eagerly():
    loss = lambda: agent.train(experience)
  else:
    loss = agent.train(experience)
  self.evaluate(tf.compat.v1.initialize_all_variables())
  self.assertEqual(self.evaluate(counter), 0)
  self.evaluate(loss)
  self.assertEqual(self.evaluate(counter), 1)
def testInit(self, lstm_size, rnn_construction_fn):
  observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0, 1)
  action_spec = tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3)
  network_input_spec = (observation_spec, action_spec)
  critic_rnn_network.CriticRnnNetwork(
      network_input_spec,
      lstm_size=lstm_size,
      rnn_construction_fn=rnn_construction_fn,
      rnn_construction_kwargs={'lstm_size': 3})
def testInit(self):
  observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0, 1)
  action_spec = tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3)
  network_input_spec = (observation_spec, action_spec)
  critic_rnn_network.CriticRnnNetwork(network_input_spec)
def testBuilds(self):
  observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0, 1)
  time_step_spec = ts.time_step_spec(observation_spec)
  time_step = tensor_spec.sample_spec_nest(time_step_spec, outer_dims=(1,))
  action_spec = tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3)
  network_input_spec = (observation_spec, action_spec)
  action = tensor_spec.sample_spec_nest(action_spec, outer_dims=(1,))
  net = critic_rnn_network.CriticRnnNetwork(
      network_input_spec,
      observation_conv_layer_params=[(4, 2, 2)],
      observation_fc_layer_params=(4,),
      action_fc_layer_params=(5,),
      joint_fc_layer_params=(5,),
      lstm_size=(3,),
      output_fc_layer_params=(5,),
  )
  network_input = (time_step.observation, action)
  q_values, network_state = net(network_input, time_step.step_type,
                                net.get_initial_state(batch_size=1))
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertEqual([1], q_values.shape.as_list())
  self.assertEqual(15, len(net.variables))
  # Obs Conv Net Kernel
  self.assertEqual((2, 2, 3, 4), net.variables[0].shape)
  # Obs Conv Net bias
  self.assertEqual((4,), net.variables[1].shape)
  # Obs Fc Kernel
  self.assertEqual((64, 4), net.variables[2].shape)
  # Obs Fc Bias
  self.assertEqual((4,), net.variables[3].shape)
  # Action Fc Kernel
  self.assertEqual((2, 5), net.variables[4].shape)
  # Action Fc Bias
  self.assertEqual((5,), net.variables[5].shape)
  # Joint Fc Kernel
  self.assertEqual((9, 5), net.variables[6].shape)
  # Joint Fc Bias
  self.assertEqual((5,), net.variables[7].shape)
  # LSTM Cell Kernel
  self.assertEqual((5, 12), net.variables[8].shape)
  # LSTM Cell Recurrent Kernel
  self.assertEqual((3, 12), net.variables[9].shape)
  # LSTM Cell Bias
  self.assertEqual((12,), net.variables[10].shape)
  # Output Fc Kernel
  self.assertEqual((3, 5), net.variables[11].shape)
  # Output Fc Bias
  self.assertEqual((5,), net.variables[12].shape)
  # Q Value Kernel
  self.assertEqual((5, 1), net.variables[13].shape)
  # Q Value Bias
  self.assertEqual((1,), net.variables[14].shape)
  # Assert LSTM cell is created.
  self.assertEqual((1, 3), network_state[0].shape)
  self.assertEqual((1, 3), network_state[1].shape)
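# The shape assertions in testBuilds follow directly from the layer params; a
# minimal standalone sketch of the same forward pass (assuming the tf_agents
# imports used throughout these snippets) makes the arithmetic visible: the
# 8x8x3 observation goes through a (4, 2, 2) conv to 4x4x4 = 64 features, the
# obs FC maps 64 -> 4, the action FC maps 2 -> 5, so the joint FC sees
# 4 + 5 = 9 inputs, matching the (9, 5) kernel checked above.
def critic_rnn_forward_pass_sketch():
  observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0, 1)
  action_spec = tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3)
  net = critic_rnn_network.CriticRnnNetwork(
      (observation_spec, action_spec),
      observation_conv_layer_params=[(4, 2, 2)],
      observation_fc_layer_params=(4,),
      action_fc_layer_params=(5,),
      joint_fc_layer_params=(5,),
      lstm_size=(3,),
      output_fc_layer_params=(5,))
  time_step = tensor_spec.sample_spec_nest(
      ts.time_step_spec(observation_spec), outer_dims=(1,))
  action = tensor_spec.sample_spec_nest(action_spec, outer_dims=(1,))
  # q_values has shape [batch_size]; network_state carries the LSTM state.
  q_values, network_state = net(
      (time_step.observation, action),
      time_step.step_type,
      net.get_initial_state(batch_size=1))
  return q_values, network_state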
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_whitelist='position',
    num_iterations=100000,
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100,),
    actor_lstm_size=(40,),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    critic_output_fc_layers=(100,),
    critic_lstm_size=(40,),
    # Params for collect
    initial_collect_steps=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=200,
    batch_size=64,
    train_sequence_length=10,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    gamma=0.995,
    reward_scale_factor=1.0,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=10000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    eval_metrics_callback=None):
  """A simple train and eval for TD3."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # TODO(kbanoop): Figure out if it is possible to avoid the with block.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):
    if observations_whitelist is not None:
      env_wrappers = [
          functools.partial(
              wrappers.FlattenObservationsWrapper,
              observations_whitelist=[observations_whitelist])
      ]
    else:
      env_wrappers = []
    environment = suite_dm_control.load(
        env_name, task_name, env_wrappers=env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(environment)
    eval_py_env = suite_dm_control.load(
        env_name, task_name, env_wrappers=env_wrappers)

    actor_net = actor_rnn_network.ActorRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers)

    critic_net = critic_rnn_network.CriticRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
    )

    tf_agent = td3_agent.Td3Agent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        debug_summaries=debug_summaries)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec(),
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    global_step = tf.train.get_or_create_global_step()

    # TODO(oars): Refactor drivers to better handle policy states. Remove the
    # policy reset and passing down an empty policy state to the driver.
    collect_policy = tf_agent.collect_policy()
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch],
        num_episodes=initial_collect_steps).run(policy_state=policy_state)

    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run(
            policy_state=policy_state)

    # Need extra step to generate transitions of train_sequence_length.
    # Dataset generates trajectories with shape [BxTx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = dataset.make_initializable_iterator()
    trajectories, unused_info = iterator.get_next()
    train_op = tf_agent.train(
        experience=trajectories, train_step_counter=global_step)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=tf.contrib.checkpoint.List(train_metrics))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy(),
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])
    summary_op = tf.contrib.summary.all_summary_ops()

    with eval_summary_writer.as_default(), \
        tf.contrib.summary.always_record_summaries():
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(sguada) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      tf.contrib.summary.initialize(session=sess)
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_op, global_step])

      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        time_acc += time.time() - start_time

        if global_step_val % log_interval == 0:
          tf.logging.info('step = %d, loss = %f', global_step_val,
                          loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          tf.logging.info('%.3f steps/sec' % steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
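# The train_eval above is a library-style function; a hedged sketch of how
# such a script is typically launched (the flag definition and entry point
# here are illustrative, not part of the original excerpt):
from absl import app, flags

flags.DEFINE_string('root_dir', None, 'Root directory for train/eval output.')

def main(_):
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  train_eval(flags.FLAGS.root_dir)

if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)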
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_whitelist='position',
    num_iterations=100000,
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100,),
    actor_lstm_size=(40,),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    critic_output_fc_layers=(100,),
    critic_lstm_size=(40,),
    # Params for collect
    initial_collect_steps=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=100000,
    exploration_noise_std=0.1,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=200,
    batch_size=64,
    actor_update_period=2,
    train_sequence_length=10,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    gamma=0.995,
    reward_scale_factor=1.0,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=10000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    eval_metrics_callback=None):
  """A simple train and eval for TD3."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if observations_whitelist is not None:
      env_wrappers = [
          functools.partial(
              wrappers.FlattenObservationsWrapper,
              observations_whitelist=[observations_whitelist])
      ]
    else:
      env_wrappers = []
    environment = suite_dm_control.load(
        env_name, task_name, env_wrappers=env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(environment)
    eval_py_env = suite_dm_control.load(
        env_name, task_name, env_wrappers=env_wrappers)

    actor_net = actor_rnn_network.ActorRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers)

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_rnn_network.CriticRnnNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
    )

    global_step = tf.compat.v1.train.get_or_create_global_step()
    tf_agent = td3_agent.Td3Agent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        exploration_noise_std=exploration_noise_std,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        actor_update_period=actor_update_period,
        dqda_clipping=dqda_clipping,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        debug_summaries=debug_summaries,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_steps).run(policy_state=policy_state)

    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run(
            policy_state=policy_state)

    # Need extra step to generate transitions of train_sequence_length.
    # Dataset generates trajectories with shape [BxTx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = iterator.get_next()

    train_fn = common.function(tf_agent.train)
    train_op = train_fn(experience=trajectories)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
        tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_whitelist='position',
    num_iterations=100000,
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100,),
    actor_lstm_size=(40,),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    critic_output_fc_layers=(100,),
    critic_lstm_size=(40,),
    # Params for collect
    initial_collect_episodes=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=200,
    batch_size=64,
    train_sequence_length=10,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=None,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=True,
    eval_metrics_callback=None):
  """A simple train and eval for DDPG."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if observations_whitelist is not None:
      env_wrappers = [
          functools.partial(
              wrappers.FlattenObservationsWrapper,
              observations_whitelist=[observations_whitelist])
      ]
    else:
      env_wrappers = []

    tf_env = tf_py_environment.TFPyEnvironment(
        suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers))
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers))

    actor_net = actor_rnn_network.ActorRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers)

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_rnn_network.CriticRnnNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_episodes)

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d episodes '
        'with a random policy.', initial_collect_episodes)
    initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [BxTx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience, train_step_counter=global_step)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)
    return train_loss
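# A note on the replay datasets above (a reading aid, not original code):
# sampling num_steps = train_sequence_length + 1 stacked time steps yields
# train_sequence_length transitions, since each transition is an adjacent
# (step, next_step) pair. With the defaults batch_size=64 and
# train_sequence_length=10, a sampled batch of trajectories has, e.g.,
#   trajectories.reward.shape == [64, 11]
# and the RNN agent trains on the 10 resulting transitions per row.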
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_allowlist='position',
    eval_env_name=None,
    num_iterations=1000000,
    # Params for networks.
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100,),
    actor_lstm_size=(40,),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    critic_output_fc_layers=(100,),
    critic_lstm_size=(40,),
    num_parallel_environments=1,
    # Params for collect
    initial_collect_episodes=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    critic_learning_rate=3e-4,
    train_sequence_length=20,
    actor_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for RNN SAC on DM control."""
  root_dir = os.path.expanduser(root_dir)

  summary_writer = tf.compat.v2.summary.create_file_writer(
      root_dir, flush_millis=summaries_flush_secs * 1000)
  summary_writer.set_as_default()

  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if observations_allowlist is not None:
      env_wrappers = [
          functools.partial(
              wrappers.FlattenObservationsWrapper,
              observations_allowlist=[observations_allowlist])
      ]
    else:
      env_wrappers = []

    env_load_fn = functools.partial(
        suite_dm_control.load, task_name=task_name, env_wrappers=env_wrappers)

    if num_parallel_environments == 1:
      py_env = env_load_fn(env_name)
    else:
      py_env = parallel_py_environment.ParallelPyEnvironment(
          [lambda: env_load_fn(env_name)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(eval_env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers,
        continuous_projection_net=tanh_normal_projection_network
        .TanhNormalProjectionNetwork)
    critic_net = critic_rnn_network.CriticRnnNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
    average_return = tf_metrics.AverageReturnMetric(
        prefix='Train',
        buffer_size=num_eval_episodes,
        batch_size=tf_env.batch_size)
    train_metrics = [
        tf_metrics.NumberOfEpisodes(prefix='Train'),
        env_steps,
        average_return,
        tf_metrics.AverageEpisodeLengthMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'train'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_episodes=initial_collect_episodes)

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
      logging.info(
          'Initializing replay buffer by collecting experience for %d '
          'episodes with a random policy.', initial_collect_episodes)
      initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=env_steps.result(),
        summary_writer=summary_writer,
        summary_prefix='Eval',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, env_steps.result())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    time_acc = 0
    env_steps_before = env_steps.result().numpy()

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      # Reduce filter_fn over full trajectory sampled. The sequence is kept
      # only if all elements except for the last one pass the filter. This is
      # to allow training on terminal steps.
      return tf.reduce_all(~trajectories.is_boundary()[:-1])

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [BxTx...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    for _ in range(num_iterations):
      start_time = time.time()
      start_env_steps = env_steps.result()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      episode_steps = env_steps.result() - start_env_steps
      # TODO(b/152648849)
      for _ in range(episode_steps):
        for _ in range(train_steps_per_iteration):
          train_step()
        time_acc += time.time() - start_time

        if global_step.numpy() % log_interval == 0:
          logging.info('env steps = %d, average return = %f',
                       env_steps.result(), average_return.result())
          env_steps_per_sec = (env_steps.result().numpy() -
                               env_steps_before) / time_acc
          logging.info('%.3f env steps/sec', env_steps_per_sec)
          tf.compat.v2.summary.scalar(
              name='env_steps_per_sec',
              data=env_steps_per_sec,
              step=env_steps.result())
          time_acc = 0
          env_steps_before = env_steps.result().numpy()

        for train_metric in train_metrics:
          train_metric.tf_summaries(train_step=env_steps.result())

        if global_step.numpy() % eval_interval == 0:
          results = metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=env_steps.result(),
              summary_writer=summary_writer,
              summary_prefix='Eval',
          )
          if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
          metric_utils.log_metrics(eval_metrics)

        global_step_val = global_step.numpy()
        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)
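# A small eager-mode sketch of what _filter_invalid_transition above keeps and
# drops, using hand-built is_boundary masks in place of real Trajectory
# batches: a sampled window survives only if no element before the last one is
# an episode boundary, so terminal steps may still appear in the final slot.
window_keep = tf.constant([False, False, False, True])   # boundary last: kept
window_drop = tf.constant([False, True, False, False])   # boundary inside: dropped
assert bool(tf.reduce_all(~window_keep[:-1]))
assert not bool(tf.reduce_all(~window_drop[:-1]))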
def testTrainWithRnn(self, cql_alpha, num_cql_samples,
                     include_critic_entropy_term, use_lagrange_cql_alpha,
                     expected_loss):
  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      self._obs_spec,
      self._action_spec,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      conv_layer_params=None,
      lstm_size=(40,),
  )
  critic_net = critic_rnn_network.CriticRnnNetwork(
      (self._obs_spec, self._action_spec),
      observation_fc_layer_params=(16,),
      action_fc_layer_params=(16,),
      joint_fc_layer_params=(16,),
      lstm_size=(16,),
      output_fc_layer_params=None,
  )
  counter = common.create_variable('test_train_counter')
  optimizer_fn = tf.compat.v1.train.AdamOptimizer
  agent = cql_sac_agent.CqlSacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_net,
      actor_network=actor_net,
      actor_optimizer=optimizer_fn(1e-3),
      critic_optimizer=optimizer_fn(1e-3),
      alpha_optimizer=optimizer_fn(1e-3),
      cql_alpha=cql_alpha,
      num_cql_samples=num_cql_samples,
      include_critic_entropy_term=include_critic_entropy_term,
      use_lagrange_cql_alpha=use_lagrange_cql_alpha,
      random_seed=self._random_seed,
      train_step_counter=counter,
  )
  batch_size = 5
  observations = tf.constant(
      [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
  actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)
  time_steps = ts.TimeStep(
      step_type=tf.constant([[1] * 3] * batch_size, dtype=tf.int32),
      reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      observation=observations)
  experience = trajectory.Trajectory(
      time_steps.step_type, observations, actions, (),
      time_steps.step_type, time_steps.reward, time_steps.discount)

  # Force variable creation.
  agent.policy.variables()
  if not tf.executing_eagerly():
    # Build the train op first to make sure optimizer variables are created
    # and can be initialized.
    train_op = agent.train(experience)
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertEqual(self.evaluate(counter), 0)
      self.evaluate(train_op)
      self.assertEqual(self.evaluate(counter), 1)
  else:
    self.assertEqual(self.evaluate(counter), 0)
    loss = self.evaluate(agent.train(experience))
    self.assertAllClose(loss.loss, expected_loss)
    self.assertEqual(self.evaluate(counter), 1)
def train_eval():
  """A simple train and eval for TD3."""
  logdir = FLAGS.logdir
  train_dir = os.path.join(logdir, 'train')
  eval_dir = os.path.join(logdir, 'eval')
  summary_flush_millis = 10 * 1000

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summary_flush_millis)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summary_flush_millis)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=FLAGS.num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(
          buffer_size=FLAGS.num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % 1000, 0)):
    env = FLAGS.environment
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env))

    if FLAGS.use_rnn:
      actor_net = actor_rnn_network.ActorRnnNetwork(
          tf_env.time_step_spec().observation,
          tf_env.action_spec(),
          input_fc_layer_params=parse_str_flag(FLAGS.actor_fc_layers),
          lstm_size=parse_str_flag(FLAGS.actor_lstm_sizes),
          output_fc_layer_params=parse_str_flag(FLAGS.actor_output_fc_layers))
      critic_net = critic_rnn_network.CriticRnnNetwork(
          input_tensor_spec=(tf_env.time_step_spec().observation,
                             tf_env.action_spec()),
          observation_fc_layer_params=parse_str_flag(
              FLAGS.critic_obs_fc_layers),
          action_fc_layer_params=parse_str_flag(FLAGS.critic_action_fc_layers),
          joint_fc_layer_params=parse_str_flag(FLAGS.critic_joint_fc_layers),
          lstm_size=parse_str_flag(FLAGS.critic_lstm_sizes),
          output_fc_layer_params=parse_str_flag(
              FLAGS.critic_output_fc_layers))
    else:
      actor_net = actor_network.ActorNetwork(
          tf_env.time_step_spec().observation,
          tf_env.action_spec(),
          fc_layer_params=parse_str_flag(FLAGS.actor_fc_layers))
      critic_net = critic_network.CriticNetwork(
          input_tensor_spec=(tf_env.time_step_spec().observation,
                             tf_env.action_spec()),
          observation_fc_layer_params=parse_str_flag(
              FLAGS.critic_obs_fc_layers),
          action_fc_layer_params=parse_str_flag(FLAGS.critic_action_fc_layers),
          joint_fc_layer_params=parse_str_flag(FLAGS.critic_joint_fc_layers))

    tf_agent = td3_agent.Td3Agent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.train.AdamOptimizer(
            learning_rate=FLAGS.actor_learning_rate),
        critic_optimizer=tf.train.AdamOptimizer(
            learning_rate=FLAGS.critic_learning_rate),
        exploration_noise_std=FLAGS.exploration_noise_std,
        target_update_tau=FLAGS.target_update_tau,
        target_update_period=FLAGS.target_update_period,
        td_errors_loss_fn=None,  # e.g. tf.compat.v1.losses.huber_loss
        gamma=FLAGS.gamma,
        train_step_counter=global_step)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=FLAGS.replay_buffer_size)

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=FLAGS.initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=FLAGS.collect_steps_per_iteration)

    initial_collect_driver.run = common.function(initial_collect_driver.run)
    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=FLAGS.num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset to generate trajectories.
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=FLAGS.batch_size,
        num_steps=(FLAGS.train_sequence_length + 1)).prefetch(3)
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    train_step = common.function(train_step)

    for _ in range(FLAGS.num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(FLAGS.train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      if global_step.numpy() % 1000 == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % 10000 == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=FLAGS.num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)
    return train_loss
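# parse_str_flag is referenced above but not defined in this excerpt. A
# minimal sketch consistent with how it is used (turning a '400,300'-style
# string flag into an int tuple, and passing empty or unset flags through as
# None) might look like:
def parse_str_flag(value):
  if not value:
    return None
  return tuple(int(v) for v in value.split(','))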
def train():
  num_iterations = 1000000

  # Params for networks.
  actor_fc_layers = (128, 64)
  actor_output_fc_layers = (64,)
  actor_lstm_size = (32,)
  critic_obs_fc_layers = None
  critic_action_fc_layers = None
  critic_joint_fc_layers = (128,)
  critic_output_fc_layers = (64,)
  critic_lstm_size = (32,)
  num_parallel_environments = 1
  # Params for collect
  initial_collect_episodes = 1
  collect_episodes_per_iteration = 1
  replay_buffer_capacity = 1000000
  # Params for target update
  target_update_tau = 0.05
  target_update_period = 5
  # Params for train
  train_steps_per_iteration = 1
  batch_size = 256
  critic_learning_rate = 3e-4
  train_sequence_length = 20
  actor_learning_rate = 3e-4
  alpha_learning_rate = 3e-4
  td_errors_loss_fn = tf.math.squared_difference
  gamma = 0.99
  reward_scale_factor = 0.1
  gradient_clipping = None
  use_tf_functions = True
  # Params for eval
  num_eval_episodes = 30
  eval_interval = 10000
  debug_summaries = False
  summarize_grads_and_vars = False

  global_step = tf.compat.v1.train.get_or_create_global_step()

  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()

  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      observation_spec,
      action_spec,
      input_fc_layer_params=actor_fc_layers,
      lstm_size=actor_lstm_size,
      output_fc_layer_params=actor_output_fc_layers,
      continuous_projection_net=tanh_normal_projection_network
      .TanhNormalProjectionNetwork)
  critic_net = critic_rnn_network.CriticRnnNetwork(
      (observation_spec, action_spec),
      observation_fc_layer_params=critic_obs_fc_layers,
      action_fc_layer_params=critic_action_fc_layers,
      joint_fc_layer_params=critic_joint_fc_layers,
      lstm_size=critic_lstm_size,
      output_fc_layer_params=critic_output_fc_layers,
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')

  tf_agent = sac_agent.SacAgent(
      time_step_spec,
      action_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_learning_rate),
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=td_errors_loss_fn,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=global_step)
  tf_agent.initialize()

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=tf_agent.collect_data_spec,
      batch_size=tf_env.batch_size,
      max_length=replay_buffer_capacity)
  replay_observer = [replay_buffer.add_batch]

  env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
  eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
  initial_collect_policy = random_tf_policy.RandomTFPolicy(
      tf_env.time_step_spec(), tf_env.action_spec())
  collect_policy = tf_agent.collect_policy

  initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
      tf_env,
      initial_collect_policy,
      observers=replay_observer,
      num_episodes=initial_collect_episodes)

  collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
      tf_env,
      collect_policy,
      observers=replay_observer,
      num_episodes=collect_episodes_per_iteration)

  if use_tf_functions:
    initial_collect_driver.run = common.function(initial_collect_driver.run)
    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

  if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
    logging.info(
        'Initializing replay buffer by collecting experience for %d episodes '
        'with a random policy.', initial_collect_episodes)
    initial_collect_driver.run()

  time_step = None
  policy_state = collect_policy.get_initial_state(tf_env.batch_size)

  time_acc = 0
  env_steps_before = env_steps.result().numpy()

  # Prepare replay buffer as dataset with invalid transitions filtered.
  def _filter_invalid_transition(trajectories, unused_arg1):
    # Reduce filter_fn over full trajectory sampled. The sequence is kept only
    # if all elements except for the last one pass the filter. This is to
    # allow training on terminal steps.
    return tf.reduce_all(~trajectories.is_boundary()[:-1])

  dataset = replay_buffer.as_dataset(
      sample_batch_size=batch_size,
      num_steps=train_sequence_length + 1).unbatch().filter(
          _filter_invalid_transition).batch(batch_size).prefetch(5)
  # Dataset generates trajectories with shape [BxTx...]
  iterator = iter(dataset)

  def train_step():
    experience, _ = next(iterator)
    return tf_agent.train(experience)

  if use_tf_functions:
    train_step = common.function(train_step)

  for _ in range(num_iterations):
    start_env_steps = env_steps.result()
    time_step, policy_state = collect_driver.run(
        time_step=time_step,
        policy_state=policy_state,
    )
    episode_steps = env_steps.result() - start_env_steps
    # TODO(b/152648849)
    for _ in range(episode_steps):
      for _ in range(train_steps_per_iteration):
        train_step()
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=1,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=0,
    eval_interval=1000000,
    # Params for summaries and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=200,
    rb_checkpoint_interval=300,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1000,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  global interrupted
  dir_key = root_dir
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  # eval_summary_writer = tf.compat.v2.summary.create_file_writer(
  #     eval_dir, flush_millis=summaries_flush_secs * 1000)
  # eval_metrics = [
  #     tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
  #     tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  # ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    train_env = T4TFEnv(metrics_key=dir_key)
    eval_env = T4TFEnv(fake=True)
    tf_env = tf_py_environment.TFPyEnvironment(train_env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        # fc_layer_params=actor_fc_layers,
        conv_layer_params=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        lstm_size=(8,),
        normal_projection_net=normal_projection_net)
    critic_net = critic_rnn_network.CriticRnnNetwork(
        (observation_spec, action_spec),
        observation_conv_layer_params=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        lstm_size=(8,),
        # observation_fc_layer_params=critic_obs_fc_layers,
        # action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)
    print(actor_net)
    print(critic_net)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps '
        'with a random policy.', initial_collect_steps)
    initial_collect_driver.run()

    # results = metric_utils.eager_compute(
    #     eval_metrics,
    #     eval_tf_env,
    #     eval_policy,
    #     num_episodes=num_eval_episodes,
    #     train_step=global_step,
    #     summary_writer=eval_summary_writer,
    #     summary_prefix='Metrics',
    # )
    # if eval_metrics_callback is not None:
    #   eval_metrics_callback(results, global_step.numpy())
    # metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=1, sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      if interrupted:
        train_env.interrupted = True
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        if interrupted:
          train_env.interrupted = True
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        # actor_net.save(os.path.join(root_dir, 'actor'), save_format='tf')
        # critic_net.save(os.path.join(root_dir, 'critic'), save_format='tf')
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        obs_shape = time_step.observation.shape
        tf.compat.v2.summary.image(
            name='input_image',
            data=np.reshape(time_step.observation,
                            (1, obs_shape[1], obs_shape[2], 3)),
            step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      # The eval metrics and eval summary writer are commented out above, so
      # this eval block is disabled as well; it would otherwise reference
      # undefined names when eval_interval is reached.
      # if global_step.numpy() % eval_interval == 0:
      #   results = metric_utils.eager_compute(
      #       eval_metrics,
      #       eval_tf_env,
      #       eval_policy,
      #       num_episodes=num_eval_episodes,
      #       train_step=global_step,
      #       summary_writer=eval_summary_writer,
      #       summary_prefix='Metrics',
      #   )
      #   if eval_metrics_callback is not None:
      #     eval_metrics_callback(results, global_step.numpy())
      #   metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      print('global step: %d' % global_step_val)
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
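# The `interrupted` global checked in the training loop above is not defined
# in this excerpt; a hedged sketch of one plausible wiring, forwarding Ctrl-C
# to the custom T4TFEnv via a SIGINT handler:
import signal

interrupted = False

def _handle_sigint(unused_signum, unused_frame):
  global interrupted
  interrupted = True

signal.signal(signal.SIGINT, _handle_sigint)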