def train_eval( root_dir, tf_master='', env_name='HalfCheetah-v2', env_load_fn=suite_mujoco.load, random_seed=0, # TODO(b/127576522): rename to policy_fc_layers. actor_fc_layers=(200, 100), value_fc_layers=(200, 100), use_rnns=False, # Params for collect num_environment_steps=10000000, collect_episodes_per_iteration=30, num_parallel_environments=30, replay_buffer_capacity=1001, # Per-environment # Params for train num_epochs=25, learning_rate=1e-4, # Params for eval num_eval_episodes=30, eval_interval=500, # Params for summaries and logging train_checkpoint_interval=100, policy_checkpoint_interval=50, rb_checkpoint_interval=200, log_interval=50, summary_interval=50, summaries_flush_secs=1, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for PPO.""" if root_dir is None: raise AttributeError('train_eval requires a root_dir.') root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ batched_py_metric.BatchedPyMetric( AverageReturnMetric, metric_args={'buffer_size': num_eval_episodes}, batch_size=num_parallel_environments), batched_py_metric.BatchedPyMetric( AverageEpisodeLengthMetric, metric_args={'buffer_size': num_eval_episodes}, batch_size=num_parallel_environments), ] eval_summary_writer_flush_op = eval_summary_writer.flush() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf.compat.v1.set_random_seed(random_seed) eval_py_env = parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * num_parallel_environments) tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * num_parallel_environments)) optimizer = tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate) if use_rnns: actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( tf_env.observation_spec(), tf_env.action_spec(), input_fc_layer_params=actor_fc_layers, output_fc_layer_params=None) value_net = value_rnn_network.ValueRnnNetwork( tf_env.observation_spec(), input_fc_layer_params=value_fc_layers, output_fc_layer_params=None) else: actor_net = actor_distribution_network.ActorDistributionNetwork( tf_env.observation_spec(), tf_env.action_spec(), fc_layer_params=actor_fc_layers) value_net = value_network.ValueNetwork( tf_env.observation_spec(), fc_layer_params=value_fc_layers) tf_agent = ppo_agent.PPOAgent( tf_env.time_step_spec(), tf_env.action_spec(), optimizer, actor_net=actor_net, value_net=value_net, num_epochs=num_epochs, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=num_parallel_environments, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) environment_steps_metric = tf_metrics.EnvironmentSteps() environment_steps_count = environment_steps_metric.result() step_metrics = [ tf_metrics.NumberOfEpisodes(), environment_steps_metric, ] train_metrics = step_metrics + [ tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] # Add to replay buffer and other agent specific observers. replay_buffer_observer = [replay_buffer.add_batch] collect_policy = tf_agent.collect_policy collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=replay_buffer_observer + train_metrics, num_episodes=collect_episodes_per_iteration).run() trajectories = replay_buffer.gather_all() train_op, _ = tf_agent.train(experience=trajectories) with tf.control_dependencies([train_op]): clear_replay_op = replay_buffer.clear() with tf.control_dependencies([clear_replay_op]): train_op = tf.identity(train_op) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=step_metrics) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(step_metrics=step_metrics) init_agent_op = tf_agent.initialize() with tf.compat.v1.Session(tf_master) as sess: # Initialize graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) collect_time = 0 train_time = 0 timed_at_step = sess.run(global_step) steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) while sess.run(environment_steps_count) < num_environment_steps: global_step_val = sess.run(global_step) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_writer_flush_op) start_time = time.time() sess.run(collect_op) collect_time += time.time() - start_time start_time = time.time() total_loss = sess.run(train_op) train_time += time.time() - start_time global_step_val = sess.run(global_step) if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss) steps_per_sec = ((global_step_val - timed_at_step) / (collect_time + train_time)) logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) logging.info( '%s', 'collect_time = {}, train_time = {}'.format( collect_time, train_time)) timed_at_step = global_step_val collect_time = 0 train_time = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) # One final eval before exiting. metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_writer_flush_op)
def train_eval( root_dir, env_name='HalfCheetah-v1', env_load_fn=suite_mujoco.load, num_iterations=2000000, actor_fc_layers=(400, 300), critic_obs_fc_layers=(400,), critic_action_fc_layers=None, critic_joint_fc_layers=(300,), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, num_parallel_environments=1, replay_buffer_capacity=100000, ou_stddev=0.2, ou_damping=0.15, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, actor_learning_rate=1e-4, critic_learning_rate=1e-3, dqda_clipping=None, td_errors_loss_fn=tf.losses.huber_loss, gamma=0.995, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for checkpoints, summaries, and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DDPG.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.contrib.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] # TODO(kbanoop): Figure out if it is possible to avoid the with block. with tf.contrib.summary.record_summaries_every_n_global_steps( summary_interval): if num_parallel_environments > 1: tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * num_parallel_environments)) else: tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name)) eval_py_env = env_load_fn(env_name) actor_net = actor_network.ActorNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=actor_fc_layers, ) critic_net_input_specs = (tf_env.time_step_spec().observation, tf_env.action_spec()) critic_net = critic_network.CriticNetwork( critic_net_input_specs, observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, ) tf_agent = ddpg_agent.DdpgAgent( tf_env.time_step_spec(), tf_env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.train.AdamOptimizer( learning_rate=critic_learning_rate), ou_stddev=ou_stddev, ou_damping=ou_damping, target_update_tau=target_update_tau, target_update_period=target_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec(), batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy()) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] global_step = tf.train.get_or_create_global_step() collect_policy = tf_agent.collect_policy() initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch], num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=collect_steps_per_iteration).run() # Dataset generates trajectories with shape [Bx2x...] dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = dataset.make_initializable_iterator() trajectories, unused_info = iterator.get_next() train_op = tf_agent.train( experience=trajectories, train_step_counter=global_step) train_checkpointer = common_utils.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=tf.contrib.checkpoint.List(train_metrics)) policy_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy(), global_step=global_step) rb_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) for train_metric in train_metrics: train_metric.tf_summaries(step_metrics=train_metrics[:2]) summary_op = tf.contrib.summary.all_summary_ops() with eval_summary_writer.as_default(), \ tf.contrib.summary.always_record_summaries(): for eval_metric in eval_metrics: eval_metric.tf_summaries() init_agent_op = tf_agent.initialize() with tf.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) # TODO(sguada) Remove once Periodically can be saved. common_utils.initialize_uninitialized_variables(sess) sess.run(init_agent_op) tf.contrib.summary.initialize(session=sess) sess.run(initial_collect_op) global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_op, global_step]) timed_at_step = sess.run(global_step) time_acc = 0 steps_per_second_ph = tf.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): loss_info_value, _, global_step_val = train_step_call() time_acc += time.time() - start_time if global_step_val % log_interval == 0: tf.logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc tf.logging.info('%.3f steps/sec' % steps_per_sec) sess.run( steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, )
def train_eval( root_dir, env_name='CartPole-v0', num_iterations=1000, # TODO(kbanoop): rename to policy_fc_layers. actor_fc_layers=(100, ), # Params for collect collect_episodes_per_iteration=2, replay_buffer_capacity=2000, # Params for train learning_rate=1e-3, gradient_clipping=None, normalize_returns=True, # Params for eval num_eval_episodes=10, eval_interval=100, # Params for checkpoints, summaries, and logging train_checkpoint_interval=100, policy_checkpoint_interval=100, rb_checkpoint_interval=200, log_interval=100, summary_interval=100, summaries_flush_secs=1, debug_summaries=True, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for Reinforce.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): eval_py_env = suite_gym.load(env_name) tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name)) # TODO(kbanoop): Handle distributions without gin. actor_net = actor_distribution_network.ActorDistributionNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=actor_fc_layers) tf_agent = reinforce_agent.ReinforceAgent( tf_env.time_step_spec(), tf_env.action_spec(), actor_network=actor_net, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate), normalize_returns=normalize_returns, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] collect_policy = tf_agent.collect_policy collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=collect_episodes_per_iteration).run() experience = replay_buffer.gather_all() train_op = tf_agent.train(experience) clear_rb_op = replay_buffer.clear() train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries() init_agent_op = tf_agent.initialize() with tf.compat.v1.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # TODO(sguada) Remove once Periodically can be saved. common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) # Compute evaluation metrics. global_step_call = sess.make_callable(global_step) global_step_val = global_step_call() metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable(train_op) clear_rb_call = sess.make_callable(clear_rb_op) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) for _ in range(num_iterations): start_time = time.time() collect_call() total_loss = train_step_call() clear_rb_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, )
def train_eval( root_dir, env_name='HalfCheetah-v1', num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=100, eval_interval=500000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=10000, summary_interval=10000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create the environment. tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_py_env = suite_mujoco.load(env_name) # Get the data specs from the environment time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) # Make the replay buffer. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] collect_policy = tf_agent.collect_policy initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run() # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset( sample_batch_size=5 * batch_size, num_steps=2).apply(tf.data.experimental.unbatch()).filter( _filter_invalid_transition).batch(batch_size).prefetch( batch_size * 5) dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) trajectories, unused_info = dataset_iterator.get_next() train_op = tf_agent.train(trajectories) for train_metric in train_metrics: train_metric.tf_summaries(step_metrics=train_metrics[:2]) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries() train_checkpointer = common_utils.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) with tf.compat.v1.Session() as sess: # Initialize graph. train_checkpointer.initialize_or_restore(sess) policy_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # Initialize training. sess.run(dataset_iterator.initializer) common_utils.initialize_uninitialized_variables(sess) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) global_step_val = sess.run(global_step) if global_step_val == 0: # Initial eval of randomly initialized policy metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) # Run initial collect. logging.info('Global step %d: Running initial collect op.', global_step_val) sess.run(initial_collect_op) # Checkpoint the initial replay buffer contents. rb_checkpointer.save(global_step=global_step_val) logging.info('Finished initial collect.') else: logging.info('Global step %d: Skipping initial collect op.', global_step_val) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable(train_op) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): total_loss = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
def train_eval( root_dir, env_name='CartPole-v0', num_iterations=100000, fc_layer_params=(100, ), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, epsilon_greedy=0.1, replay_buffer_capacity=100000, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, learning_rate=1e-3, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=1000, # Params for checkpoints, summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, log_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DQN.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] # Note this is a python environment. env = batched_py_environment.BatchedPyEnvironment( [suite_gym.load(env_name)]) eval_py_env = suite_gym.load(env_name) # Convert specs to BoundedTensorSpec. action_spec = tensor_spec.from_spec(env.action_spec()) observation_spec = tensor_spec.from_spec(env.observation_spec()) time_step_spec = ts.time_step_spec(observation_spec) q_net = q_network.QNetwork(tensor_spec.from_spec(env.observation_spec()), tensor_spec.from_spec(env.action_spec()), fc_layer_params=fc_layer_params) # The agent must be in graph. global_step = tf.compat.v1.train.get_or_create_global_step() agent = dqn_agent.DqnAgent( time_step_spec, action_spec, q_network=q_net, epsilon_greedy=epsilon_greedy, target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate), td_errors_loss_fn=dqn_agent.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_collect_policy = agent.collect_policy collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy) greedy_policy = py_tf_policy.PyTFPolicy(agent.policy) random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(), env.action_spec()) # Python replay buffer. replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer( capacity=replay_buffer_capacity, data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec)) time_step = env.reset() # Initialize the replay buffer with some transitions. We use the random # policy to initialize the replay buffer to make sure we get a good # distribution of actions. for _ in range(initial_collect_steps): time_step = collect_step(env, time_step, random_policy, replay_buffer) # TODO(b/112041045) Use global_step as counter. train_checkpointer = common.Checkpointer(ckpt_dir=train_dir, agent=agent, global_step=global_step) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=agent.policy, global_step=global_step) ds = replay_buffer.as_dataset(sample_batch_size=batch_size, num_steps=2) ds = ds.prefetch(4) itr = tf.compat.v1.data.make_initializable_iterator(ds) experience = itr.get_next() train_op = common.function(agent.train)(experience) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries() with tf.compat.v1.Session() as session: train_checkpointer.initialize_or_restore(session) common.initialize_uninitialized_variables(session) session.run(itr.initializer) # Copy critic network values to the target critic network. session.run(agent.initialize()) train = session.make_callable(train_op) global_step_call = session.make_callable(global_step) session.run(train_summary_writer.init()) session.run(eval_summary_writer.init()) # Compute inital evaluation metrics. global_step_val = global_step_call() metric_utils.compute_summaries( eval_metrics, eval_py_env, greedy_policy, num_episodes=num_eval_episodes, global_step=global_step_val, log=True, callback=eval_metrics_callback, ) timed_at_step = global_step_val collect_time = 0 train_time = 0 steps_per_second_ph = tf.compat.v1.placeholder(tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) for _ in range(num_iterations): start_time = time.time() for _ in range(collect_steps_per_iteration): time_step = collect_step(env, time_step, collect_policy, replay_buffer) collect_time += time.time() - start_time start_time = time.time() for _ in range(train_steps_per_iteration): loss = train() train_time += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, loss.loss) steps_per_sec = ((global_step_val - timed_at_step) / (collect_time + train_time)) session.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) logging.info('%.3f steps/sec', steps_per_sec) logging.info( '%s', 'collect_time = {}, train_time = {}'.format( collect_time, train_time)) timed_at_step = global_step_val collect_time = 0 train_time = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, greedy_policy, num_episodes=num_eval_episodes, global_step=global_step_val, log=True, callback=eval_metrics_callback, ) # Reset timing to avoid counting eval time. timed_at_step = global_step_val start_time = time.time()
def train_eval( root_dir, env_name='cartpole', task_name='balance', observations_whitelist='position', num_iterations=100000, actor_fc_layers=(400, 300), actor_output_fc_layers=(100, ), actor_lstm_size=(40, ), critic_obs_fc_layers=(400, ), critic_action_fc_layers=None, critic_joint_fc_layers=(300, ), critic_output_fc_layers=(100, ), critic_lstm_size=(40, ), # Params for collect initial_collect_steps=1, collect_episodes_per_iteration=1, replay_buffer_capacity=100000, ou_stddev=0.2, ou_damping=0.15, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=200, batch_size=64, train_sequence_length=10, actor_learning_rate=1e-4, critic_learning_rate=1e-3, dqda_clipping=None, gamma=0.995, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=1000, # Params for checkpoints, summaries, and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=10000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DDPG.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.contrib.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): if observations_whitelist is not None: env_wrappers = [ functools.partial( wrappers.FlattenObservationsWrapper, observations_whitelist=[observations_whitelist]) ] else: env_wrappers = [] environment = suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers) tf_env = tf_py_environment.TFPyEnvironment(environment) eval_py_env = suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers) actor_net = actor_rnn_network.ActorRnnNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers) critic_net_input_specs = (tf_env.time_step_spec().observation, tf_env.action_spec()) critic_net = critic_rnn_network.CriticRnnNetwork( critic_net_input_specs, observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers, ) tf_agent = ddpg_agent.DdpgAgent( tf_env.time_step_spec(), tf_env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), ou_stddev=ou_stddev, ou_damping=ou_damping, target_update_tau=target_update_tau, target_update_period=target_update_period, dqda_clipping=dqda_clipping, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] collect_policy = tf_agent.collect_policy initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=initial_collect_steps).run() collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=collect_episodes_per_iteration).run() # Need extra step to generate transitions of train_sequence_length. # Dataset generates trajectories with shape [BxTx...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=train_sequence_length + 1).prefetch(3) iterator = tf.compat.v1.data.make_initializable_iterator(dataset) trajectories, unused_info = iterator.get_next() train_fn = common.function(tf_agent.train) train_op = train_fn(experience=trajectories, train_step_counter=global_step) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) summary_op = tf.contrib.summary.all_summary_ops() with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries() init_agent_op = tf_agent.initialize() with tf.compat.v1.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) # TODO(b/126239733) Remove once Periodically can be saved. common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) tf.contrib.summary.initialize(session=sess) sess.run(initial_collect_op) global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_op]) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): loss_info_value, _ = train_step_call() global_step_val = global_step_call() time_acc += time.time() - start_time if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True)
def train_eval( root_dir, env_name='MaskedCartPole-v0', num_iterations=100000, input_fc_layer_params=(50, ), lstm_size=(20, ), output_fc_layer_params=(20, ), train_sequence_length=10, # Params for collect initial_collect_steps=50, collect_episodes_per_iteration=1, epsilon_greedy=0.1, replay_buffer_capacity=100000, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=10, batch_size=128, learning_rate=1e-3, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=1000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, log_interval=100, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DQN.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): eval_py_env = suite_gym.load(env_name) tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name)) q_net = q_rnn_network.QRnnNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), input_fc_layer_params=input_fc_layer_params, lstm_size=lstm_size, output_fc_layer_params=output_fc_layer_params) # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839 tf_agent = dqn_agent.DqnAgent( tf_env.time_step_spec(), tf_env.action_spec(), q_network=q_net, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate), epsilon_greedy=epsilon_greedy, target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=dqn_agent.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, initial_collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=initial_collect_steps).run() collect_policy = tf_agent.collect_policy collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=collect_episodes_per_iteration).run() # Need extra step to generate transitions of train_sequence_length. # Dataset generates trajectories with shape [BxTx...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=train_sequence_length + 1).prefetch(3) iterator = tf.compat.v1.data.make_initializable_iterator(dataset) experience, _ = iterator.get_next() loss_info = common.function(tf_agent.train)(experience=experience) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) init_agent_op = tf_agent.initialize() with tf.compat.v1.Session() as sess: sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) logging.info('Collecting initial experience.') sess.run(initial_collect_op) # Compute evaluation metrics. global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([loss_info, summary_ops]) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): # Train/collect/eval. start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): loss_info_value, _ = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, log=True, callback=eval_metrics_callback, )
def train_eval( root_dir, env_name='CartPole-v0', num_iterations=100000, fc_layer_params=(100, ), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, epsilon_greedy=0.1, replay_buffer_capacity=100000, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, learning_rate=1e-3, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=1000, # Params for checkpoints, summaries, and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, agent_class=dqn_agent.DqnAgent, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DQN.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.contrib.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] with tf.contrib.summary.record_summaries_every_n_global_steps( summary_interval): tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name)) eval_py_env = suite_gym.load(env_name) q_net = q_network.QNetwork(tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=fc_layer_params) tf_agent = agent_class( tf_env.time_step_spec(), tf_env.action_spec(), q_network=q_net, optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate), # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839 epsilon_greedy=epsilon_greedy, target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=dqn_agent.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec(), batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy()) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] global_step = tf.train.get_or_create_global_step() replay_observer = [replay_buffer.add_batch] initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer, num_steps=initial_collect_steps).run() collect_policy = tf_agent.collect_policy() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run() # Dataset generates trajectories with shape [Bx2x...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = dataset.make_initializable_iterator() trajectories, _ = iterator.get_next() train_op = tf_agent.train(experience=trajectories, train_step_counter=global_step) train_checkpointer = common_utils.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=tf.contrib.checkpoint.List(train_metrics)) policy_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy(), global_step=global_step) rb_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) for train_metric in train_metrics: train_metric.tf_summaries(step_metrics=train_metrics[:2]) summary_op = tf.contrib.summary.all_summary_ops() with eval_summary_writer.as_default(), \ tf.contrib.summary.always_record_summaries(): for eval_metric in eval_metrics: eval_metric.tf_summaries() init_agent_op = tf_agent.initialize() with tf.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) # TODO(sguada) Remove once Periodically can be saved. common_utils.initialize_uninitialized_variables(sess) sess.run(init_agent_op) tf.contrib.summary.initialize(session=sess) sess.run(initial_collect_op) global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable( [train_op, summary_op, global_step]) timed_at_step = sess.run(global_step) collect_time = 0 train_time = 0 steps_per_second_ph = tf.placeholder(tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) for _ in range(num_iterations): # Train/collect/eval. start_time = time.time() collect_call() collect_time += time.time() - start_time start_time = time.time() for _ in range(train_steps_per_iteration): loss_info_value, _, global_step_val = train_step_call() train_time += time.time() - start_time if global_step_val % log_interval == 0: tf.logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = ((global_step_val - timed_at_step) / (collect_time + train_time)) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) tf.logging.info('%.3f steps/sec' % steps_per_sec) tf.logging.info( 'collect_time = {}, train_time = {}'.format( collect_time, train_time)) timed_at_step = global_step_val collect_time = 0 train_time = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, )
def train(): summary_interval = 1000 summaries_flush_secs = 10 num_eval_episodes = 5 root_dir = '/tmp/tensorflow/logs/tfenv01' train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs*1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs*1000) # maybe py_metrics? eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] environment = TradeEnvironment() # utils.validate_py_environment(environment, episodes=5) # Environments global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): train_env = tf_py_environment.TFPyEnvironment(environment) eval_env = tf_py_environment.TFPyEnvironment(environment) num_iterations = 50 fc_layer_params = (512, ) # ~ (17 + 1001) / 2 input_fc_layer_params = (50, ) output_fc_layer_params = (20, ) lstm_size = (30, ) initial_collect_steps = 20 collect_steps_per_iteration = 1 collect_episodes_per_iteration = 1 # the same as above batch_size = 64 replay_buffer_capacity = 10000 train_sequence_length = 10 gamma = 0.99 # check if 1.0 works as well target_update_tau = 0.05 target_update_period = 5 epsilon_greedy = 0.1 gradient_clipping = None reward_scale_factor = 1.0 learning_rate = 1e-2 log_interval = 30 eval_interval = 15 # train_env.observation_spec(), q_net = q_rnn_network.QRnnNetwork( train_env.time_step_spec().observation, train_env.action_spec(), input_fc_layer_params=input_fc_layer_params, lstm_size=lstm_size, output_fc_layer_params=output_fc_layer_params, ) optimizer = tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate) tf_agent = dqn_agent.DqnAgent( train_env.time_step_spec(), train_env.action_spec(), q_network=q_net, optimizer=optimizer, epsilon_greedy=epsilon_greedy, target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=dqn_agent.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=False, summarize_grads_and_vars=False, train_step_counter=global_step, ) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=train_env.batch_size, max_length=replay_buffer_capacity, ) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] # Policy which does not allow some actions in certain states q_policy = FilteredQPolicy( tf_agent._time_step_spec, tf_agent._action_spec, q_network=tf_agent._q_network, ) # Valid policy to pre-fill replay buffer initial_collect_policy = DummyTradePolicy( train_env.time_step_spec(), train_env.action_spec(), ) print('Initial collecting...') initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver( train_env, initial_collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=initial_collect_steps, ).run() # Main agent's policy; greedy one policy = greedy_policy.GreedyPolicy(q_policy) # Policy used for evaluation, the same as above eval_policy = greedy_policy.GreedyPolicy(q_policy) tf_agent._policy = policy collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy( q_policy, epsilon=tf_agent._epsilon_greedy) # Patch random policy for epsilon greedy collect policy filtered_random_tf_policy = FilteredRandomTFPolicy( time_step_spec=policy.time_step_spec, action_spec=policy.action_spec, ) collect_policy._random_policy = filtered_random_tf_policy tf_agent._collect_policy = collect_policy collect_op = dynamic_episode_driver.DynamicEpisodeDriver( train_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=collect_episodes_per_iteration, ).run() dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=train_sequence_length+1, ).prefetch(3) iterator = iter(dataset) experience, _ = next(iterator) loss_info = common.function(tf_agent.train)(experience=experience) # Checkpoints train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), ) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step, ) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer, ) summary_ops = [] for train_metric in train_metrics: summary_ops.append(train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2], )) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) init_agent_op = tf_agent.initialize() with tf.compat.v1.Session() as sess: # sess.run(train_summary_writer.init()) # sess.run(eval_summary_writer.init()) # Initialize the graph # tfe.Saver().restore() # train_checkpointer.initialize_or_restore() # rb_checkpointer.initialize_or_restore() # sess.run(iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) print('Collecting initial experience...') sess.run(initial_collect_op) global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_env, eval_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([loss_info, summary_ops]) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step, ) # Train for i in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): loss_info_value, _ = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_inerval == 0: print('step=%d, loss=%f', global_step_val, loss_info_value.loss) steps_per_sec = (global_step_val-timed_at_step) / time_acc print('%.3f steps/sec', steps_per_sec) sess.run( steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}, ) timed_at_step = global_step_val time_acc = 0 # Save checkpoints if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) # Evaluate if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_env, eval_policy, num_episodes=num_eval_episodes, global_step=global_step_val, log=True, callback=eval_metrics_callback, ) print('Done!')