def train_eval(
    root_dir,
    tf_master='',
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=0,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=10000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=50,
    rb_checkpoint_interval=200,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for PPO."""
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      batched_py_metric.BatchedPyMetric(
          AverageReturnMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
      batched_py_metric.BatchedPyMetric(
          AverageEpisodeLengthMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
  ]
  eval_summary_writer_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf.compat.v1.set_random_seed(random_seed)
    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(), fc_layer_params=value_fc_layers)

    tf_agent = ppo_agent.PPOAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    environment_steps_count = environment_steps_metric.result()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Add to replay buffer and other agent specific observers.
    replay_buffer_observer = [replay_buffer.add_batch]
    collect_policy = tf_agent.collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=replay_buffer_observer + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    trajectories = replay_buffer.gather_all()

    train_op, _ = tf_agent.train(experience=trajectories)

    with tf.control_dependencies([train_op]):
      clear_replay_op = replay_buffer.clear()

    with tf.control_dependencies([clear_replay_op]):
      train_op = tf.identity(train_op)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(
          train_step=global_step, step_metrics=step_metrics)

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(step_metrics=step_metrics)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session(tf_master) as sess:
      # Initialize graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      collect_time = 0
      train_time = 0
      timed_at_step = sess.run(global_step)
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      while sess.run(environment_steps_count) < num_environment_steps:
        global_step_val = sess.run(global_step)
        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
          sess.run(eval_summary_writer_flush_op)

        start_time = time.time()
        sess.run(collect_op)
        collect_time += time.time() - start_time
        start_time = time.time()
        total_loss = sess.run(train_op)
        train_time += time.time() - start_time

        global_step_val = sess.run(global_step)
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (
              (global_step_val - timed_at_step) / (collect_time + train_time))
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info('%s', 'collect_time = {}, train_time = {}'.format(
              collect_time, train_time))
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

      # One final eval before exiting.
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )
      sess.run(eval_summary_writer_flush_op)
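# ---------------------------------------------------------------------------
# A minimal sketch (not part of the example above) of how this PPO train_eval
# might be launched from a script. The flag name and the
# enable_resource_variables() call follow the usual TF-Agents v1 example
# binaries; treat both as assumptions rather than the canonical launcher.
import os

import tensorflow as tf
from absl import app, flags, logging

FLAGS = flags.FLAGS
flags.DEFINE_string('root_dir', os.getenv('TEST_TMPDIR', '/tmp/ppo'),
                    'Root directory for logs, summaries and checkpoints.')


def main(_):
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_resource_variables()
  train_eval(FLAGS.root_dir, num_parallel_environments=4)


if __name__ == '__main__':
  app.run(main)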
def train_eval(
    root_dir,
    env_name='HalfCheetah-v1',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=100,
    eval_interval=500000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=10000,
    summary_interval=10000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Create the environment.
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_py_env = suite_mujoco.load(env_name)

    # Get the data specs from the environment
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    collect_policy = tf_agent.collect_policy
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=5 * batch_size,
        num_steps=2).apply(tf.data.experimental.unbatch()).filter(
            _filter_invalid_transition).batch(batch_size).prefetch(
                batch_size * 5)
    dataset_iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = dataset_iterator.get_next()
    train_op = tf_agent.train(trajectories)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    with tf.compat.v1.Session() as sess:
      # Initialize graph.
      train_checkpointer.initialize_or_restore(sess)
      policy_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)

      # Initialize training.
      sess.run(dataset_iterator.initializer)
      common_utils.initialize_uninitialized_variables(sess)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      global_step_val = sess.run(global_step)

      if global_step_val == 0:
        # Initial eval of randomly initialized policy
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            eval_py_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            callback=eval_metrics_callback,
            log=True,
        )

        # Run initial collect.
        logging.info('Global step %d: Running initial collect op.',
                     global_step_val)
        sess.run(initial_collect_op)

        # Checkpoint the initial replay buffer contents.
        rb_checkpointer.save(global_step=global_step_val)

        logging.info('Finished initial collect.')
      else:
        logging.info('Global step %d: Skipping initial collect op.',
                     global_step_val)

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable(train_op)
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          total_loss = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       total_loss.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)
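# ---------------------------------------------------------------------------
# The SAC actor network above passes a `normal_projection_net` helper that is
# not defined in this excerpt. Below is a sketch of the tanh-squashed normal
# projection typically used with SAC in TF-Agents; the exact keyword set is an
# assumption based on the usual example code, not a verbatim copy.
from tf_agents.agents.sac import sac_agent
from tf_agents.networks import normal_projection_network


def normal_projection_net(action_spec, init_means_output_factor=0.1):
  # State-dependent std, clipped via std_transform, and a scaled (squashed)
  # distribution so samples respect the bounded action spec.
  return normal_projection_network.NormalProjectionNetwork(
      action_spec,
      mean_transform=None,
      state_dependent_std=True,
      init_means_output_factor=init_means_output_factor,
      std_transform=sac_agent.std_clip_transform,
      scale_distribution=True)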
def train_eval(
    root_dir,
    env_name='HalfCheetah-v1',
    env_load_fn=suite_mujoco.load,
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_parallel_environments=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DDPG."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if num_parallel_environments > 1:
      tf_env = tf_py_environment.TFPyEnvironment(
          parallel_py_environment.ParallelPyEnvironment(
              [lambda: env_load_fn(env_name)] * num_parallel_environments))
    else:
      tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    eval_py_env = env_load_fn(env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = iterator.get_next()

    train_fn = common_utils.function(tf_agent.train)
    train_op = train_fn(experience=trajectories)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])
    summary_op = tf.contrib.summary.all_summary_ops()

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(b/126239733) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      tf.contrib.summary.initialize(session=sess)
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_op, global_step])

      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        time_acc += time.time() - start_time

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
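# ---------------------------------------------------------------------------
# The DDPG pipeline above samples with num_steps=2: each dataset element is a
# two-step trajectory slice, i.e. one (s, a, r, s') transition stacked on a
# time axis of length 2. A self-contained, eager-mode sketch of that sampling
# pattern with a toy spec (the 3-dim float spec is illustrative, not the
# agent's collect_data_spec):
import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer

toy_spec = tf.TensorSpec(shape=(3,), dtype=tf.float32, name='obs')
toy_rb = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    toy_spec, batch_size=1, max_length=100)
for i in range(10):
  toy_rb.add_batch(tf.fill([1, 3], float(i)))  # one batched item per step
# Each sampled element has shape [sample_batch_size, 2, 3]: index 0 on the
# time axis is step t, index 1 is step t+1.
toy_ds = toy_rb.as_dataset(sample_batch_size=4, num_steps=2)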
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    train_sequence_length=1,
    # Params for QNetwork
    fc_layer_params=(100,),
    # Params for QRnnNetwork
    input_fc_layer_params=(50,),
    lstm_size=(20,),
    output_fc_layer_params=(20,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    if train_sequence_length > 1:
      q_net = q_rnn_network.QRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=input_fc_layer_params,
          lstm_size=lstm_size,
          output_fc_layer_params=output_fc_layer_params)
    else:
      q_net = q_network.QNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=fc_layer_params)

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    if use_tf_functions:
      # To speed up collect use common.function.
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps '
        'with a random policy.', initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.contrib.summary.scalar(
            name='global_steps/sec', tensor=steps_per_sec)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(step_metrics=train_metrics[:2])

      if global_step.numpy() % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)
    return train_loss
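# ---------------------------------------------------------------------------
# The TODO above asks for epsilon decayed by global step rather than a fixed
# epsilon_greedy. A sketch of one way to wire that, mirroring the
# polynomial_decay pattern used by the Atari example later in this section;
# the 100000-step horizon is an arbitrary assumption.
import tensorflow as tf

global_step = tf.compat.v1.train.get_or_create_global_step()
decayed_epsilon = tf.compat.v1.train.polynomial_decay(
    1.0,                      # initial epsilon
    global_step,
    100000,                   # decay horizon, in train steps (assumed)
    end_learning_rate=0.01)   # final epsilon
# DqnAgent accepts a tensor here (the Atari example passes one), so
# `epsilon_greedy=decayed_epsilon` could replace the fixed value above.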
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=1000,
    # TODO(kbanoop): rename to policy_fc_layers.
    actor_fc_layers=(100,),
    # Params for collect
    collect_episodes_per_iteration=2,
    replay_buffer_capacity=2000,
    # Params for train
    learning_rate=1e-3,
    gradient_clipping=None,
    normalize_returns=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=100,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=100,
    rb_checkpoint_interval=200,
    log_interval=100,
    summary_interval=100,
    summaries_flush_secs=1,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for Reinforce."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    eval_py_env = suite_gym.load(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    # TODO(kbanoop): Handle distributions without gin.
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers)

    tf_agent = reinforce_agent.ReinforceAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        normalize_returns=normalize_returns,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    experience = replay_buffer.gather_all()
    train_op = tf_agent.train(experience)
    clear_rb_op = replay_buffer.clear()

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2])

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      # TODO(sguada) Remove once Periodically can be saved.
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      # Compute evaluation metrics.
      global_step_call = sess.make_callable(global_step)
      global_step_val = global_step_call()
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable(train_op)
      clear_rb_call = sess.make_callable(clear_rb_op)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        total_loss = train_step_call()
        clear_rb_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       total_loss.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
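# ---------------------------------------------------------------------------
# A hypothetical smoke-test invocation of the REINFORCE example above. Note
# the on-policy pattern in its loop: collect, train on gather_all(), then
# clear the buffer, since stale episodes must not be reused. The directory
# and shortened intervals below are arbitrary assumptions.
import tensorflow as tf

tf.compat.v1.enable_resource_variables()
train_eval('/tmp/reinforce_cartpole', num_iterations=50, eval_interval=25,
           log_interval=25)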
def __init__(
    self,
    root_dir,
    env_name,
    num_iterations=200,
    max_episode_frames=108000,  # ALE frames
    terminal_on_life_loss=False,
    conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)),
    fc_layer_params=(512,),
    # Params for collect
    initial_collect_steps=80000,  # ALE frames
    epsilon_greedy=0.01,
    epsilon_decay_period=1000000,  # ALE frames
    replay_buffer_capacity=1000000,
    # Params for train
    train_steps_per_iteration=1000000,  # ALE frames
    update_period=16,  # ALE frames
    target_update_tau=1.0,
    target_update_period=32000,  # ALE frames
    batch_size=32,
    learning_rate=2.5e-4,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    do_eval=True,
    eval_steps_per_iteration=500000,  # ALE frames
    eval_epsilon_greedy=0.001,
    # Params for checkpoints, summaries, and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple Atari train and eval for DQN.

  Args:
    root_dir: Directory to write log files to.
    env_name: Fully-qualified name of the Atari environment (e.g. Pong-v0).
    num_iterations: Number of train/eval iterations to run.
    max_episode_frames: Maximum length of a single episode, in ALE frames.
    terminal_on_life_loss: Whether to simulate an episode termination when a
      life is lost.
    conv_layer_params: Params for convolutional layers of QNetwork.
    fc_layer_params: Params for fully connected layers of QNetwork.
    initial_collect_steps: Number of ALE frames to process before beginning
      to train. Since this is in ALE frames, there will be
      initial_collect_steps/4 items in the RB when training starts.
    epsilon_greedy: Final epsilon value to decay to for training.
    epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
      epsilon_greedy (defined above).
    replay_buffer_capacity: Maximum number of items to store in the RB.
    train_steps_per_iteration: Number of ALE frames to run through for each
      iteration of training.
    update_period: Run a train operation every update_period ALE frames.
    target_update_tau: Coefficient for soft target network updates (1.0 ==
      hard updates).
    target_update_period: Period, in ALE frames, to copy the live network to
      the target network.
    batch_size: Number of frames to include in each training batch.
    learning_rate: RMS optimizer learning rate.
    gamma: Discount for future rewards.
    reward_scale_factor: Scaling factor for rewards.
    gradient_clipping: Norm length to clip gradients.
    do_eval: If True, run an eval every iteration. If False, skip eval.
    eval_steps_per_iteration: Number of ALE frames to run through for each
      iteration of evaluation.
    eval_epsilon_greedy: Epsilon value to use for the evaluation policy (0 ==
      totally greedy policy).
    log_interval: Log stats to the terminal every log_interval training
      steps.
    summary_interval: Write TF summaries every summary_interval training
      steps.
    summaries_flush_secs: Flush summaries to disk every summaries_flush_secs
      seconds.
    debug_summaries: If True, write additional summaries for debugging (see
      dqn_agent for which summaries are written).
    summarize_grads_and_vars: Include gradients in summaries.
    eval_metrics_callback: A callback function that takes (metric_dict,
      global_step) as parameters. Called after every eval with the results of
      the evaluation.
  """
  self._update_period = update_period / ATARI_FRAME_SKIP
  self._train_steps_per_iteration = (
      train_steps_per_iteration / ATARI_FRAME_SKIP)
  self._do_eval = do_eval
  self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP
  self._eval_epsilon_greedy = eval_epsilon_greedy
  self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP
  self._summary_interval = summary_interval
  self._num_iterations = num_iterations
  self._log_interval = log_interval
  self._eval_metrics_callback = eval_metrics_callback

  with gin.unlock_config():
    gin.bind_parameter('AtariPreprocessing.terminal_on_life_loss',
                       terminal_on_life_loss)

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()
  self._train_summary_writer = train_summary_writer

  self._eval_summary_writer = None
  if self._do_eval:
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    self._eval_summary_writer = eval_summary_writer
    self._eval_metrics = [
        py_metrics.AverageReturnMetric(
            name='PhaseAverageReturn', buffer_size=np.inf),
        py_metrics.AverageEpisodeLengthMetric(
            name='PhaseAverageEpisodeLength', buffer_size=np.inf),
    ]

  self._global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(self._global_step % self._summary_interval, 0)):
    self._env = suite_atari.load(
        env_name,
        max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP,
        gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
    self._env = batched_py_environment.BatchedPyEnvironment([self._env])

    observation_spec = tensor_spec.from_spec(self._env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)
    action_spec = tensor_spec.from_spec(self._env.action_spec())

    with tf.device('/cpu:0'):
      epsilon = tf.compat.v1.train.polynomial_decay(
          1.0,
          self._global_step,
          epsilon_decay_period / ATARI_FRAME_SKIP / self._update_period,
          end_learning_rate=epsilon_greedy)

    with tf.device('/gpu:0'):
      optimizer = tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          decay=0.95,
          momentum=0.0,
          epsilon=0.00001,
          centered=True)
      q_net = AtariQNetwork(
          observation_spec,
          action_spec,
          conv_layer_params=conv_layer_params,
          fc_layer_params=fc_layer_params)
      tf_agent = dqn_agent.DqnAgent(
          time_step_spec,
          action_spec,
          q_network=q_net,
          optimizer=optimizer,
          epsilon_greedy=epsilon,
          target_update_tau=target_update_tau,
          target_update_period=(
              target_update_period / ATARI_FRAME_SKIP / self._update_period),
          td_errors_loss_fn=dqn_agent.element_wise_huber_loss,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=self._global_step)

      self._collect_policy = py_tf_policy.PyTFPolicy(tf_agent.collect_policy)

      if self._do_eval:
        self._eval_policy = py_tf_policy.PyTFPolicy(
            epsilon_greedy_policy.EpsilonGreedyPolicy(
                policy=tf_agent.policy, epsilon=self._eval_epsilon_greedy))

      py_observation_spec = self._env.observation_spec()
      py_time_step_spec = ts.time_step_spec(py_observation_spec)
      py_action_spec = policy_step.PolicyStep(self._env.action_spec())
      data_spec = trajectory.from_transition(py_time_step_spec,
                                             py_action_spec,
                                             py_time_step_spec)
      self._replay_buffer = py_hashed_replay_buffer.PyHashedReplayBuffer(
          data_spec=data_spec, capacity=replay_buffer_capacity)

    with tf.device('/cpu:0'):
      ds = self._replay_buffer.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(4)
      # TODO(b/123242430): Add prefetch_to_device back here in order to
      # improve performance once errors are resolved.
      self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
      experience = self._ds_itr.get_next()

    with tf.device('/gpu:0'):
      self._train_op = tf_agent.train(experience)

    self._env_steps_metric = py_metrics.EnvironmentSteps()
    self._step_metrics = [
        py_metrics.NumberOfEpisodes(),
        self._env_steps_metric,
    ]
    self._train_metrics = self._step_metrics + [
        py_metrics.AverageReturnMetric(buffer_size=10),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
    ]
    # The _train_phase_metrics average over an entire train iteration,
    # rather than the rolling average of the last 10 episodes.
    self._train_phase_metrics = [
        py_metrics.AverageReturnMetric(
            name='PhaseAverageReturn', buffer_size=np.inf),
        py_metrics.AverageEpisodeLengthMetric(
            name='PhaseAverageEpisodeLength', buffer_size=np.inf),
    ]
    self._iteration_metric = py_metrics.CounterMetric(name='Iteration')

    # Summaries written from python should run every time they are
    # generated.
    with tf.compat.v2.summary.record_if(True):
      self._steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      self._steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=self._steps_per_second_ph)

      for metric in self._train_metrics:
        metric.tf_summaries(step_metrics=self._step_metrics)

      for metric in self._train_phase_metrics:
        metric.tf_summaries(step_metrics=(self._iteration_metric,))
      self._iteration_metric.tf_summaries()

      if self._do_eval:
        with eval_summary_writer.as_default():
          for metric in self._eval_metrics:
            metric.tf_summaries(step_metrics=(self._iteration_metric,))

    self._train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=self._global_step,
        optimizer=optimizer,
        metrics=metric_utils.MetricsGroup(
            self._train_metrics + self._train_phase_metrics +
            [self._iteration_metric], 'train_metrics'))
    self._policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=self._global_step)
    self._rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=self._replay_buffer)

    self._init_agent_op = tf_agent.initialize()
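# ---------------------------------------------------------------------------
# Every frame-denominated hyperparameter above is divided by ATARI_FRAME_SKIP
# before use. A worked example of that arithmetic with the defaults, assuming
# the conventional frame skip of 4 (the constant itself is defined outside
# this excerpt):
ATARI_FRAME_SKIP = 4

update_period = 16 / ATARI_FRAME_SKIP       # train every 4 agent steps
initial_collect = 80000 / ATARI_FRAME_SKIP  # 20000 items in the RB at start
train_steps = 1000000 / ATARI_FRAME_SKIP    # 250000 agent steps per iteration
# target_update_period is additionally divided by update_period: the live
# network is copied to the target every 32000 / 4 / 4 = 2000 train ops.
target_updates = 32000 / ATARI_FRAME_SKIP / update_period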
def train_eval(
    root_dir,
    env_name='MaskedCartPole-v0',
    num_iterations=100000,
    input_fc_layer_params=(50,),
    lstm_size=(20,),
    output_fc_layer_params=(20,),
    train_sequence_length=10,
    # Params for collect
    initial_collect_steps=50,
    collect_episodes_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=10,
    batch_size=128,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=100,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    eval_py_env = suite_gym.load(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    q_net = q_rnn_network.QRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=input_fc_layer_params,
        lstm_size=lstm_size,
        output_fc_layer_params=output_fc_layer_params)

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_steps).run()

    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    # Need extra step to generate transitions of train_sequence_length.
    # Dataset generates trajectories with shape [BxTx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    experience, _ = iterator.get_next()
    loss_info = common.function(tf_agent.train)(experience=experience)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      logging.info('Collecting initial experience.')
      sess.run(initial_collect_op)

      # Compute evaluation metrics.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([loss_info, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        # Train/collect/eval.
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              log=True,
              callback=eval_metrics_callback,
          )
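# ---------------------------------------------------------------------------
# Sampling num_steps=train_sequence_length + 1 above gives the RNN agent the
# extra step it needs: T+1 stacked time steps yield T transitions to unroll
# the LSTM over. A shape sketch under this example's defaults (the
# observation_dim placeholder is illustrative):
batch_size = 128
train_sequence_length = 10
num_steps = train_sequence_length + 1  # 11 stacked time steps per sample
# experience.observation: [128, 11, observation_dim]; the agent trains on the
# 10 transitions between consecutive steps in each sampled sequence.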
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_whitelist='position',
    num_iterations=100000,
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100,),
    actor_lstm_size=(40,),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    critic_output_fc_layers=(100,),
    critic_lstm_size=(40,),
    # Params for collect
    initial_collect_steps=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=200,
    batch_size=64,
    train_sequence_length=10,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    gamma=0.995,
    reward_scale_factor=1.0,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=10000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    eval_metrics_callback=None):
  """A simple train and eval for TD3."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if observations_whitelist is not None:
      env_wrappers = [
          functools.partial(
              wrappers.FlattenObservationsWrapper,
              observations_whitelist=[observations_whitelist])
      ]
    else:
      env_wrappers = []
    environment = suite_dm_control.load(
        env_name, task_name, env_wrappers=env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(environment)
    eval_py_env = suite_dm_control.load(
        env_name, task_name, env_wrappers=env_wrappers)

    actor_net = actor_rnn_network.ActorRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers)

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_rnn_network.CriticRnnNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
    )

    tf_agent = td3_agent.Td3Agent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        debug_summaries=debug_summaries,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # TODO(oars): Refactor drivers to better handle policy states. Remove the
    # policy reset and passing down an empty policy state to the driver.
    collect_policy = tf_agent.collect_policy
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_steps).run(policy_state=policy_state)

    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run(
            policy_state=policy_state)

    # Need extra step to generate transitions of train_sequence_length.
    # Dataset generates trajectories with shape [BxTx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = iterator.get_next()

    train_op = tf_agent.train(experience=trajectories)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(sguada) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable(train_op)
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
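# ---------------------------------------------------------------------------
# The observations_whitelist plumbing above filters dm_control's dict
# observations down to a single key and flattens it. A standalone sketch of
# the same wrapper usage; 'position' matches this example's default, and the
# snippet is illustrative rather than part of the original file.
import functools

from tf_agents.environments import suite_dm_control, wrappers

env_wrappers = [
    functools.partial(
        wrappers.FlattenObservationsWrapper,
        observations_whitelist=['position'])
]
# The wrapped environment exposes one flat observation vector instead of the
# dm_control dict of named arrays.
env = suite_dm_control.load('cartpole', 'balance', env_wrappers=env_wrappers)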
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=100,
    eval_interval=50000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        suite_mujoco.load(env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps '
        'with a random policy.', initial_collect_steps)
    initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [B x 2 x ...].
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
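

# --- Added sketch (not in the original): an eval_metrics_callback.
# Each train_eval here accepts an eval_metrics_callback; TF-Agents'
# metric_utils.eager_compute / compute_summaries pass it the computed metric
# results plus the current global step. This illustrative callback assumes
# the results arrive as a name -> value mapping (an OrderedDict in the
# TF-Agents versions these scripts target).
def print_eval_metrics(results, step):
  for name, value in results.items():
    print('eval @ step %d: %s = %s' % (step, name, value))

# Usage sketch: enable eager behavior first, since the SAC loop above relies
# on global_step.numpy() and iter(dataset):
#   tf.compat.v1.enable_v2_behavior()
#   train_eval('/tmp/sac/HalfCheetah-v2',
#              eval_metrics_callback=print_eval_metrics)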
def train():
  summary_interval = 1000
  summaries_flush_secs = 10
  num_eval_episodes = 5
  root_dir = '/tmp/tensorflow/logs/tfenv01'
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  # maybe py_metrics?
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  environment = TradeEnvironment()
  # utils.validate_py_environment(environment, episodes=5)

  # Environments
  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    train_env = tf_py_environment.TFPyEnvironment(environment)
    # Use a separate underlying environment for evaluation; sharing one
    # py-environment instance between train and eval lets eval episodes
    # clobber training state.
    eval_env = tf_py_environment.TFPyEnvironment(TradeEnvironment())

    num_iterations = 50
    fc_layer_params = (512,)  # ~ (17 + 1001) / 2
    input_fc_layer_params = (50,)
    output_fc_layer_params = (20,)
    lstm_size = (30,)

    initial_collect_steps = 20
    collect_steps_per_iteration = 1
    collect_episodes_per_iteration = 1  # the same as above
    batch_size = 64
    replay_buffer_capacity = 10000
    train_sequence_length = 10

    gamma = 0.99  # check if 1.0 works as well
    target_update_tau = 0.05
    target_update_period = 5
    epsilon_greedy = 0.1
    gradient_clipping = None
    reward_scale_factor = 1.0
    learning_rate = 1e-2

    log_interval = 30
    eval_interval = 15
    # The following were referenced below but never defined in the original;
    # the values here are arbitrary placeholders.
    train_steps_per_iteration = 1
    train_checkpoint_interval = 25
    policy_checkpoint_interval = 25
    rb_checkpoint_interval = 25
    eval_metrics_callback = None

    q_net = q_rnn_network.QRnnNetwork(
        # train_env.observation_spec(),
        train_env.time_step_spec().observation,
        train_env.action_spec(),
        input_fc_layer_params=input_fc_layer_params,
        lstm_size=lstm_size,
        output_fc_layer_params=output_fc_layer_params,
    )

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    tf_agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        train_step_counter=global_step,
    )

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_capacity,
    )

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Policy which does not allow some actions in certain states
    q_policy = FilteredQPolicy(
        tf_agent._time_step_spec,
        tf_agent._action_spec,
        q_network=tf_agent._q_network,
    )

    # Valid policy to pre-fill replay buffer
    initial_collect_policy = DummyTradePolicy(
        train_env.time_step_spec(),
        train_env.action_spec(),
    )

    print('Initial collecting...')
    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        train_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_steps,
    ).run()

    # Main agent's policy; greedy one
    policy = greedy_policy.GreedyPolicy(q_policy)
    # Policy used for evaluation, the same as above
    eval_policy = greedy_policy.GreedyPolicy(q_policy)
    tf_agent._policy = policy

    collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
        q_policy, epsilon=tf_agent._epsilon_greedy)
    # Patch random policy for epsilon greedy collect policy
    filtered_random_tf_policy = FilteredRandomTFPolicy(
        time_step_spec=policy.time_step_spec,
        action_spec=policy.action_spec,
    )
    collect_policy._random_policy = filtered_random_tf_policy
    tf_agent._collect_policy = collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        train_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration,
    ).run()

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1,
    ).prefetch(3)
    # This function builds a graph, so the dataset needs an initializable
    # iterator; iter(dataset) / next(iterator) only work in eager mode.
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    experience, _ = iterator.get_next()
    loss_info = common.function(tf_agent.train)(experience=experience)

    # Checkpoints
    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
    )
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step,
    )
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer,
    )

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step,
          step_metrics=train_metrics[:2],
      ))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      # tfe.Saver().restore()
      # train_checkpointer.initialize_or_restore()
      # rb_checkpointer.initialize_or_restore()
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(iterator.initializer)
      common.initialize_uninitialized_variables(sess)
      sess.run(init_agent_op)

      print('Collecting initial experience...')
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      # NOTE: compute_summaries expects py-side metrics, environment and
      # policy; depending on the TF-Agents version, the TF objects here may
      # need wrapping (e.g. py_tf_policy.PyTFPolicy).
      metric_utils.compute_summaries(
          eval_metrics,
          eval_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([loss_info, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step,
      )

      # Train
      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        time_acc += time.time() - start_time

        global_step_val = global_step_call()
        if global_step_val % log_interval == 0:
          print('step = %d, loss = %f' %
                (global_step_val, loss_info_value.loss))
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          print('%.3f steps/sec' % steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec},
          )
          timed_at_step = global_step_val
          time_acc = 0

        # Save checkpoints
        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)
        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        # Evaluate
        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              log=True,
              callback=eval_metrics_callback,
          )

  print('Done!')
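

# --- Added sketch (not in the original): entry point for the trading example.
# train() hard-codes its own root_dir, so a bare call suffices; the
# enable_resource_variables() line is an assumption carried over from the
# other graph-mode examples in this document.
if __name__ == '__main__':
  tf.compat.v1.enable_resource_variables()
  train()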