def testGinConfig(self):
  gin.parse_config_file(
      test_utils.test_src_dir_path('environments/configs/suite_mujoco.gin'))
  env = suite_mujoco.load()
  self.assertIsInstance(env, py_environment.Base)
  self.assertIsInstance(env, wrappers.TimeLimit)
def create_env(env_name):
  """Creates Environment."""
  if env_name == 'Pendulum':
    env = gym.make('Pendulum-v0')
  elif env_name == 'Hopper':
    env = suite_mujoco.load('Hopper-v2')
  elif env_name == 'Walker2D':
    env = suite_mujoco.load('Walker2d-v2')
  elif env_name == 'HalfCheetah':
    env = suite_mujoco.load('HalfCheetah-v2')
  elif env_name == 'Ant':
    env = suite_mujoco.load('Ant-v2')
  elif env_name == 'Humanoid':
    env = suite_mujoco.load('Humanoid-v2')
  else:
    raise ValueError('Unsupported environment: %s' % env_name)
  return env
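# `create_env` mixes a raw Gym environment ('Pendulum' via gym.make) with
# TF-Agents PyEnvironments returned by suite_mujoco.load, so callers need to
# wrap the Gym case before building a TFPyEnvironment. A minimal usage sketch,
# assuming the standard tf_agents gym_wrapper / tf_py_environment modules
# (these imports are not part of the original snippet):
from tf_agents.environments import gym_wrapper, tf_py_environment

# MuJoCo names already come back as PyEnvironments.
tf_env = tf_py_environment.TFPyEnvironment(create_env('HalfCheetah'))

# 'Pendulum' returns a raw gym.Env and needs an explicit GymWrapper first.
tf_pendulum = tf_py_environment.TFPyEnvironment(
    gym_wrapper.GymWrapper(create_env('Pendulum')))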
def main(_):
  logging.set_verbosity(logging.INFO)
  tf.enable_v2_behavior()
  collect(
      FLAGS.task,
      FLAGS.root_dir,
      replay_buffer_server_address=FLAGS.replay_buffer_server_address,
      variable_container_server_address=FLAGS.variable_container_server_address,
      create_env_fn=lambda: suite_mujoco.load('HalfCheetah-v2'))
def main(_):
  logging.set_verbosity(logging.INFO)
  tf.enable_v2_behavior()
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  strategy = strategy_utils.get_strategy(FLAGS.tpu, FLAGS.use_gpu)
  train(
      FLAGS.root_dir,
      strategy,
      replay_buffer_server_address=FLAGS.replay_buffer_server_address,
      variable_container_server_address=FLAGS.variable_container_server_address,
      create_agent_fn=_create_agent,
      create_env_fn=lambda: suite_mujoco.load('HalfCheetah-v2'),
      num_iterations=FLAGS.num_iterations,
  )
def main():
  # environment
  eval_env = tf_py_environment.TFPyEnvironment(
      suite_mujoco.load('HalfCheetah-v2'))
  # deserialize saved policy
  saved_policy = tf.compat.v2.saved_model.load('checkpoints/policy_9500/')
  # apply policy and visualize
  total_return = 0.0
  for _ in range(10):
    episode_return = 0.0
    status = eval_env.reset()
    policy_state = saved_policy.get_initial_state(eval_env.batch_size)
    while not status.is_last():
      action = saved_policy.action(status, policy_state)
      status = eval_env.step(action.action)
      policy_state = action.state
      cv2.imshow('halfcheetah', eval_env.pyenv.envs[0].render())
      cv2.waitKey(25)
      episode_return += status.reward
    total_return += episode_return
  avg_return = total_return / 10
  print("average return is %f" % avg_return)
def main(_): tf.enable_v2_behavior() tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) base_env = suite_mujoco.load(FLAGS.env_name) if hasattr(base_env, 'max_episode_steps'): max_episode_steps = base_env.max_episode_steps else: logging.info('Unknown max episode steps. Setting to 1000.') max_episode_steps = 1000 env = base_env.gym env = wrappers.check_and_normalize_box_actions(env) env.seed(FLAGS.seed) eval_env = suite_mujoco.load(FLAGS.env_name).gym eval_env = wrappers.check_and_normalize_box_actions(eval_env) eval_env.seed(FLAGS.seed + 1) spec = ( tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32, 'observation'), tensor_spec.TensorSpec([env.action_space.shape[0]], tf.float32, 'action'), tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32, 'next_observation'), tensor_spec.TensorSpec([1], tf.float32, 'reward'), tensor_spec.TensorSpec([1], tf.float32, 'mask'), ) init_spec = tensor_spec.TensorSpec([env.observation_space.shape[0]], tf.float32, 'observation') replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( spec, batch_size=1, max_length=FLAGS.max_timesteps) init_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( init_spec, batch_size=1, max_length=FLAGS.max_timesteps) hparam_str_dict = dict(seed=FLAGS.seed, env=FLAGS.env_name) hparam_str = ','.join([ '%s=%s' % (k, str(hparam_str_dict[k])) for k in sorted(hparam_str_dict.keys()) ]) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) rl_algo = algae.ALGAE(env.observation_space.shape[0], env.action_space.shape[0], FLAGS.log_interval, critic_lr=FLAGS.critic_lr, actor_lr=FLAGS.actor_lr, use_dqn=FLAGS.use_dqn, use_init_states=FLAGS.use_init_states, algae_alpha=FLAGS.algae_alpha, exponent=FLAGS.f_exponent) episode_return = 0 episode_timesteps = 0 done = True total_timesteps = 0 previous_time = time.time() replay_buffer_iter = iter( replay_buffer.as_dataset(sample_batch_size=FLAGS.sample_batch_size)) init_replay_buffer_iter = iter( init_replay_buffer.as_dataset( sample_batch_size=FLAGS.sample_batch_size)) log_dir = os.path.join(FLAGS.save_dir, 'logs') log_filename = os.path.join(log_dir, hparam_str) if not gfile.isdir(log_dir): gfile.mkdir(log_dir) eval_returns = [] with tqdm(total=FLAGS.max_timesteps, desc='') as pbar: # Final return is the average of the last 10 measurmenets. 
final_returns = collections.deque(maxlen=10) final_timesteps = 0 while total_timesteps < FLAGS.max_timesteps: _update_pbar_msg(pbar, total_timesteps) if done: if episode_timesteps > 0: current_time = time.time() train_measurements = [ ('train/returns', episode_return), ('train/FPS', episode_timesteps / (current_time - previous_time)), ] _write_measurements(summary_writer, train_measurements, total_timesteps) obs = env.reset() episode_return = 0 episode_timesteps = 0 previous_time = time.time() init_replay_buffer.add_batch(np.array([obs.astype(np.float32) ])) if total_timesteps < FLAGS.num_random_actions: action = env.action_space.sample() else: _, action, _ = rl_algo.actor(np.array([obs])) action = action[0].numpy() if total_timesteps >= FLAGS.start_training_timesteps: with summary_writer.as_default(): target_entropy = (-env.action_space.shape[0] if FLAGS.target_entropy is None else FLAGS.target_entropy) for _ in range(FLAGS.num_updates_per_env_step): rl_algo.train( replay_buffer_iter, init_replay_buffer_iter, discount=FLAGS.discount, tau=FLAGS.tau, target_entropy=target_entropy, actor_update_freq=FLAGS.actor_update_freq) next_obs, reward, done, _ = env.step(action) if (max_episode_steps is not None and episode_timesteps + 1 == max_episode_steps): done = True if not done or episode_timesteps + 1 == max_episode_steps: # pylint: disable=protected-access mask = 1.0 else: mask = 0.0 replay_buffer.add_batch((np.array([obs.astype(np.float32)]), np.array([action.astype(np.float32)]), np.array([next_obs.astype(np.float32)]), np.array([[reward]]).astype(np.float32), np.array([[mask]]).astype(np.float32))) episode_return += reward episode_timesteps += 1 total_timesteps += 1 pbar.update(1) obs = next_obs if total_timesteps % FLAGS.eval_interval == 0: logging.info('Performing policy eval.') average_returns, evaluation_timesteps = rl_algo.evaluate( eval_env, max_episode_steps=max_episode_steps) eval_returns.append(average_returns) fin = gfile.GFile(log_filename, 'w') np.save(fin, np.array(eval_returns)) fin.close() eval_measurements = [ ('eval/average returns', average_returns), ('eval/average episode length', evaluation_timesteps), ] # TODO(sandrafaust) Make this average of the last N. final_returns.append(average_returns) final_timesteps = evaluation_timesteps _write_measurements(summary_writer, eval_measurements, total_timesteps) logging.info('Eval: ave returns=%f, ave episode length=%f', average_returns, evaluation_timesteps) # Final measurement. final_measurements = [ ('final/average returns', sum(final_returns) / len(final_returns)), ('final/average episode length', final_timesteps), ] _write_measurements(summary_writer, final_measurements, total_timesteps)
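# The ALGAE training loop above calls two helpers that are not shown:
# `_write_measurements`, which receives a list of (name, value) pairs plus a
# step, and `_update_pbar_msg`, which refreshes the tqdm description. A
# minimal sketch of what they might look like, inferred from the call sites
# (the original implementations may differ):
def _write_measurements(summary_writer, labels_and_values, step):
  """Writes a list of (label, value) pairs as scalar summaries."""
  with summary_writer.as_default():
    for label, value in labels_and_values:
      tf.summary.scalar(label, value, step=step)


def _update_pbar_msg(pbar, total_timesteps):
  """Updates the tqdm progress bar with the current timestep count."""
  msg = 'Training at timestep %d' % total_timesteps
  if pbar.desc != msg:
    pbar.set_description(msg)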
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] eval_summary_flush_op = eval_summary_writer.flush() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create the environment. tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_py_env = suite_mujoco.load(env_name) # Get the data specs from the environment time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] eval_py_policy = py_tf_policy.PyTFPolicy( greedy_policy.GreedyPolicy(tf_agent.policy)) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] collect_policy = tf_agent.collect_policy initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run() # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset( sample_batch_size=5 * batch_size, num_steps=2).apply(tf.data.experimental.unbatch()).filter( _filter_invalid_transition).batch(batch_size).prefetch( batch_size * 5) dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) trajectories, unused_info = dataset_iterator.get_next() train_op = tf_agent.train(trajectories) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) with tf.compat.v1.Session() as sess: # Initialize graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # Initialize training. sess.run(dataset_iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) global_step_val = sess.run(global_step) if global_step_val == 0: # Initial eval of randomly initialized policy metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) # Run initial collect. logging.info('Global step %d: Running initial collect op.', global_step_val) sess.run(initial_collect_op) # Checkpoint the initial replay buffer contents. 
rb_checkpointer.save(global_step=global_step_val) logging.info('Finished initial collect.') else: logging.info('Global step %d: Skipping initial collect op.', global_step_val) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_ops]) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): total_loss, _ = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
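# The SAC actor network in the train_eval above passes
# `continuous_projection_net=normal_projection_net`, a helper not defined in
# this snippet. In the TF-Agents SAC examples it is typically a thin wrapper
# around NormalProjectionNetwork; the sketch below follows that pattern, and
# the exact arguments should be treated as an assumption, not the original.
# `sac_agent` refers to tf_agents.agents.sac.sac_agent as imported above.
from tf_agents.networks import normal_projection_network


def normal_projection_net(action_spec, init_means_output_factor=0.1):
  return normal_projection_network.NormalProjectionNetwork(
      action_spec,
      mean_transform=None,
      state_dependent_std=True,
      init_means_output_factor=init_means_output_factor,
      std_transform=sac_agent.std_clip_transform,
      scale_distribution=True)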
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=100000,
    exploration_noise_std=0.1,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_update_period=2,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for TD3."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = td3_agent.Td3Agent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        exploration_noise_std=exploration_noise_std,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        actor_update_period=actor_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
    )
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)
    return train_loss
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=2000000, actor_fc_layers=(400, 300), critic_obs_fc_layers=(400,), critic_action_fc_layers=None, critic_joint_fc_layers=(300,), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, replay_buffer_capacity=100000, exploration_noise_std=0.1, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, actor_update_period=2, actor_learning_rate=1e-4, critic_learning_rate=1e-3, dqda_clipping=None, td_errors_loss_fn=tf.compat.v1.losses.huber_loss, gamma=0.995, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for checkpoints, summaries, and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for TD3.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_py_env = suite_mujoco.load(env_name) actor_net = actor_network.ActorNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=actor_fc_layers, ) critic_net_input_specs = (tf_env.time_step_spec().observation, tf_env.action_spec()) critic_net = critic_network.CriticNetwork( critic_net_input_specs, observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, ) tf_agent = td3_agent.Td3Agent( tf_env.time_step_spec(), tf_env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), exploration_noise_std=exploration_noise_std, target_update_tau=target_update_tau, target_update_period=target_update_period, actor_update_period=actor_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, ) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] collect_policy = tf_agent.collect_policy initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, 
collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=collect_steps_per_iteration).run() dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = tf.compat.v1.data.make_initializable_iterator(dataset) trajectories, unused_info = iterator.get_next() train_fn = common.function(tf_agent.train) train_op = train_fn(experience=trajectories) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) summary_ops = [] for train_metric in train_metrics: summary_ops.append(train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) init_agent_op = tf_agent.initialize() with tf.compat.v1.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) # TODO(b/126239733): Remove once Periodically can be saved. common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) sess.run(initial_collect_op) global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_ops, global_step]) timed_at_step = sess.run(global_step) time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): loss_info_value, _, global_step_val = train_step_call() time_acc += time.time() - start_time if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run( steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, )
def train_eval( root_dir, env_name='HalfCheetah-v2', # Training params num_iterations=1600, actor_fc_layers=(64, 64), value_fc_layers=(64, 64), learning_rate=3e-4, collect_sequence_length=2048, minibatch_size=64, num_epochs=10, # Agent params importance_ratio_clipping=0.2, lambda_value=0.95, discount_factor=0.99, entropy_regularization=0., value_pred_loss_coef=0.5, use_gae=True, use_td_lambda_return=True, gradient_clipping=0.5, value_clipping=None, # Replay params reverb_port=None, replay_capacity=10000, # Others policy_save_interval=5000, summary_interval=1000, eval_interval=10000, eval_episodes=100, debug_summaries=False, summarize_grads_and_vars=False): """Trains and evaluates PPO (Importance Ratio Clipping). Args: root_dir: Main directory path where checkpoints, saved_models, and summaries will be written to. env_name: Name for the Mujoco environment to load. num_iterations: The number of iterations to perform collection and training. actor_fc_layers: List of fully_connected parameters for the actor network, where each item is the number of units in the layer. value_fc_layers: : List of fully_connected parameters for the value network, where each item is the number of units in the layer. learning_rate: Learning rate used on the Adam optimizer. collect_sequence_length: Number of steps to take in each collect run. minibatch_size: Number of elements in each mini batch. If `None`, the entire collected sequence will be treated as one batch. num_epochs: Number of iterations to repeat over all collected data per data collection step. (Schulman,2017) sets this to 10 for Mujoco, 15 for Roboschool and 3 for Atari. importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective. For more detail, see explanation at the top of the doc. lambda_value: Lambda parameter for TD-lambda computation. discount_factor: Discount factor for return computation. Default to `0.99` which is the value used for all environments from (Schulman, 2017). entropy_regularization: Coefficient for entropy regularization loss term. Default to `0.0` because no entropy bonus was used in (Schulman, 2017). value_pred_loss_coef: Multiplier for value prediction loss to balance with policy gradient loss. Default to `0.5`, which was used for all environments in the OpenAI baseline implementation. This parameters is irrelevant unless you are sharing part of actor_net and value_net. In that case, you would want to tune this coeeficient, whose value depends on the network architecture of your choice. use_gae: If True (default False), uses generalized advantage estimation for computing per-timestep advantage. Else, just subtracts value predictions from empirical return. use_td_lambda_return: If True (default False), uses td_lambda_return for training value function; here: `td_lambda_return = gae_advantage + value_predictions`. `use_gae` must be set to `True` as well to enable TD -lambda returns. If `use_td_lambda_return` is set to True while `use_gae` is False, the empirical return will be used and a warning will be logged. gradient_clipping: Norm length to clip gradients. value_clipping: Difference between new and old value predictions are clipped to this threshold. Value clipping could be helpful when training very deep networks. Default: no clipping. reverb_port: Port for reverb server, if None, use a randomly chosen unused port. replay_capacity: The maximum number of elements for the replay buffer. Items will be wasted if this is smalled than collect_sequence_length. 
policy_save_interval: How often, in train_steps, the policy will be saved. summary_interval: How often to write data into Tensorboard. eval_interval: How often to run evaluation, in train_steps. eval_episodes: Number of episodes to evaluate over. debug_summaries: Boolean for whether to gather debug summaries. summarize_grads_and_vars: If true, gradient summaries will be written. """ collect_env = suite_mujoco.load(env_name) eval_env = suite_mujoco.load(env_name) num_environments = 1 observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = ( spec_utils.get_tensor_specs(collect_env)) # TODO(b/172267869): Remove this conversion once TensorNormalizer stops # converting float64 inputs to float32. observation_tensor_spec = tf.TensorSpec( dtype=tf.float32, shape=observation_tensor_spec.shape) train_step = train_utils.create_train_step() actor_net_builder = ppo_actor_network.PPOActorNetwork() actor_net = actor_net_builder.create_sequential_actor_net( actor_fc_layers, action_tensor_spec) value_net = value_network.ValueNetwork( observation_tensor_spec, fc_layer_params=value_fc_layers, kernel_initializer=tf.keras.initializers.Orthogonal()) current_iteration = tf.Variable(0, dtype=tf.int64) def learning_rate_fn(): # Linearly decay the learning rate. return learning_rate * (1 - current_iteration / num_iterations) agent = ppo_clip_agent.PPOClipAgent( time_step_tensor_spec, action_tensor_spec, optimizer=tf.keras.optimizers.Adam( learning_rate=learning_rate_fn, epsilon=1e-5), actor_net=actor_net, value_net=value_net, importance_ratio_clipping=importance_ratio_clipping, lambda_value=lambda_value, discount_factor=discount_factor, entropy_regularization=entropy_regularization, value_pred_loss_coef=value_pred_loss_coef, # This is a legacy argument for the number of times we repeat the data # inside of the train function, incompatible with mini batch learning. # We set the epoch number from the replay buffer and tf.Data instead. num_epochs=1, use_gae=use_gae, use_td_lambda_return=use_td_lambda_return, gradient_clipping=gradient_clipping, value_clipping=value_clipping, # TODO(b/150244758): Default compute_value_and_advantage_in_train to False # after Reverb open source. compute_value_and_advantage_in_train=False, # Skips updating normalizers in the agent, as it's handled in the learner. update_normalizers_in_train=False, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=train_step) agent.initialize() reverb_server = reverb.Server( [ reverb.Table( # Replay buffer storing experience for training. name='training_table', sampler=reverb.selectors.Fifo(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), max_size=replay_capacity, max_times_sampled=1, ), reverb.Table( # Replay buffer storing experience for normalization. name='normalization_table', sampler=reverb.selectors.Fifo(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), max_size=replay_capacity, max_times_sampled=1, ) ], port=reverb_port) # Create the replay buffer. reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=collect_sequence_length, table_name='training_table', server_address='localhost:{}'.format(reverb_server.port), # The only collected sequence is used to populate the batches. 
max_cycle_length=1, rate_limiter_timeout_ms=1000) reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=collect_sequence_length, table_name='normalization_table', server_address='localhost:{}'.format(reverb_server.port), # The only collected sequence is used to populate the batches. max_cycle_length=1, rate_limiter_timeout_ms=1000) rb_observer = reverb_utils.ReverbTrajectorySequenceObserver( reverb_replay_train.py_client, ['training_table', 'normalization_table'], sequence_length=collect_sequence_length, stride_length=collect_sequence_length) saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) collect_env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={ triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric }), triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval), ] def training_dataset_fn(): return reverb_replay_train.as_dataset( sample_batch_size=num_environments, sequence_preprocess_fn=agent.preprocess_sequence) def normalization_dataset_fn(): return reverb_replay_normalization.as_dataset( sample_batch_size=num_environments, sequence_preprocess_fn=agent.preprocess_sequence) agent_learner = ppo_learner.PPOLearner( root_dir, train_step, agent, experience_dataset_fn=training_dataset_fn, normalization_dataset_fn=normalization_dataset_fn, num_samples=1, num_epochs=num_epochs, minibatch_size=minibatch_size, shuffle_buffer_size=collect_sequence_length, triggers=learning_triggers) tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy( tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor( collect_env, collect_policy, train_step, steps_per_run=collect_sequence_length, observers=[rb_observer], metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric], reference_metrics=[collect_env_step_metric], summary_dir=os.path.join(root_dir, learner.TRAIN_DIR), summary_interval=summary_interval) eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy( agent.policy, use_tf_function=True) if eval_interval: logging.info('Intial evaluation.') eval_actor = actor.Actor( eval_env, eval_greedy_policy, train_step, metrics=actor.eval_metrics(eval_episodes), reference_metrics=[collect_env_step_metric], summary_dir=os.path.join(root_dir, 'eval'), episodes_per_run=eval_episodes) eval_actor.run_and_log() logging.info('Training on %s', env_name) last_eval_step = 0 for i in range(num_iterations): collect_actor.run() rb_observer.flush() agent_learner.run() reverb_replay_train.clear() reverb_replay_normalization.clear() current_iteration.assign_add(1) # Eval only if `eval_interval` has been set. Then, eval if the current train # step is equal or greater than the `last_eval_step` + `eval_interval` or if # this is the last iteration. This logic exists because agent_learner.run() # does not return after every train step. if (eval_interval and (agent_learner.train_step_numpy >= eval_interval + last_eval_step or i == num_iterations - 1)): logging.info('Evaluating.') eval_actor.run_and_log() last_eval_step = agent_learner.train_step_numpy rb_observer.close() reverb_server.stop()
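# Like the other entry points in this collection, the PPO train_eval above
# would typically be driven by an absl app.run wrapper. A hypothetical wiring
# sketch; the flag name and defaults are illustrative, not taken from the
# original script:
from absl import app, flags, logging

FLAGS = flags.FLAGS
flags.DEFINE_string('root_dir', None,
                    'Root directory for checkpoints and summaries.')


def main(_):
  logging.set_verbosity(logging.INFO)
  train_eval(FLAGS.root_dir, env_name='HalfCheetah-v2')


if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)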
def testObservationSpec(self):
  env = suite_mujoco.load('HalfCheetah-v1')
  self.assertEqual(np.float32, env.observation_spec().dtype)
  self.assertEqual((17,), env.observation_spec().shape)
def load_multiple_mugs_env(
    universe,
    action_mode,
    env_name=None,
    render_size=128,
    observation_render_size=64,
    observations_whitelist=None,
    action_repeat=1,
    num_train_tasks=30,
    num_eval_tasks=10,
    eval_on_holdout_tasks=True,
    return_multiple_tasks=False,
    model_input=None,
    auto_reset_task_each_episode=False,
):
  ### HARDCODED: temporary sanity checks.
  assert env_name == 'SawyerShelfMT-v0'
  assert return_multiple_tasks
  assert universe == 'gym'

  # Get eval and train tasks by loading a sample env.
  sample_env = suite_mujoco.load(env_name)
  train_tasks = sample_env.init_tasks(
      num_tasks=num_train_tasks, is_eval_env=False)
  eval_tasks = sample_env.init_tasks(
      num_tasks=num_eval_tasks, is_eval_env=eval_on_holdout_tasks)
  del sample_env

  print("train weights", train_tasks)
  print("eval weights", eval_tasks)

  if env_name == 'SawyerShelfMT-v0':
    from meld.environments.envs.shelf.assets.generate_sawyer_shelf_xml import generate_and_save_xml_file
  else:
    raise NotImplementedError

  train_xml_path = generate_and_save_xml_file(
      train_tasks, action_mode, is_eval=False)
  eval_xml_path = generate_and_save_xml_file(
      eval_tasks, action_mode, is_eval=True)

  ### Train env.
  # Get gym wrappers. Renamed from `wrappers` to avoid shadowing the
  # tf_agents `wrappers` module used for ActionRepeat below.
  train_gym_wrappers = get_wrappers(
      device_id=0,
      model_input=model_input,
      render_size=render_size,
      observation_render_size=observation_render_size,
      observations_whitelist=observations_whitelist)
  gym_kwargs = {"action_mode": action_mode, "xml_path": train_xml_path}
  py_env = suite_gym.load(
      env_name, gym_env_wrappers=train_gym_wrappers, gym_kwargs=gym_kwargs)
  if action_repeat > 1:
    py_env = wrappers.ActionRepeat(py_env, action_repeat)

  ### Eval env.
  eval_gym_wrappers = get_wrappers(
      device_id=1,
      model_input=model_input,
      render_size=render_size,
      observation_render_size=observation_render_size,
      observations_whitelist=observations_whitelist)
  gym_kwargs = {"action_mode": action_mode, "xml_path": eval_xml_path}
  eval_py_env = suite_gym.load(
      env_name, gym_env_wrappers=eval_gym_wrappers, gym_kwargs=gym_kwargs)
  eval_py_env = video_wrapper.VideoWrapper(eval_py_env)
  if action_repeat > 1:
    eval_py_env = wrappers.ActionRepeat(eval_py_env, action_repeat)

  py_env.assign_tasks(train_tasks)
  eval_py_env.assign_tasks(eval_tasks)

  # Set the task list and the auto-reset flag.
  if auto_reset_task_each_episode:
    py_env.wrapped_env().set_auto_reset_task(train_tasks)
    eval_py_env.wrapped_env().set_auto_reset_task(eval_tasks)

  return py_env, eval_py_env, train_tasks, eval_tasks
def load_env(env_name,
             seed,
             action_repeat=0,
             frame_stack=1,
             obs_type='pixels'):
  """Loads a learning environment.

  Args:
    env_name: Name of the environment.
    seed: Random seed.
    action_repeat: (optional) action repeat multiplier. Useful for DM control
      suite tasks.
    frame_stack: (optional) frame stack.
    obs_type: `pixels` or `state`

  Returns:
    Learning environment.
  """
  action_repeat_applied = False
  state_env = None

  if env_name.startswith('dm'):
    _, domain_name, task_name = env_name.split('-')
    if 'manipulation' in domain_name:
      env = manipulation.load(task_name)
      env = dm_control_wrapper.DmControlWrapper(env)
    else:
      env = _load_dm_env(
          domain_name, task_name, pixels=False, action_repeat=action_repeat)
      action_repeat_applied = True
    env = wrappers.FlattenObservationsWrapper(env)

  elif env_name.startswith('pixels-dm'):
    if 'distractor' in env_name:
      _, _, domain_name, task_name, _ = env_name.split('-')
      distractor = True
    else:
      _, _, domain_name, task_name = env_name.split('-')
      distractor = False

    # TODO(tompson): Are there DMC environments that have other
    # max_episode_steps?
    env = _load_dm_env(
        domain_name,
        task_name,
        pixels=True,
        action_repeat=action_repeat,
        max_episode_steps=1000,
        obs_type=obs_type,
        distractor=distractor)
    action_repeat_applied = True
    if obs_type == 'pixels':
      env = FlattenImageObservationsWrapper(env)
      state_env = None
    else:
      env = JointImageObservationsWrapper(env)
      state_env = tf_py_environment.TFPyEnvironment(
          wrappers.FlattenObservationsWrapper(
              _load_dm_env(
                  domain_name,
                  task_name,
                  pixels=False,
                  action_repeat=action_repeat)))
  else:
    env = suite_mujoco.load(env_name)

  env.seed(seed)

  if action_repeat > 1 and not action_repeat_applied:
    env = wrappers.ActionRepeat(env, action_repeat)
  if frame_stack > 1:
    env = FrameStackWrapperTfAgents(env, frame_stack)

  env = tf_py_environment.TFPyEnvironment(env)
  return env, state_env
def testMujocoEnvRegistered(self):
  env = suite_mujoco.load('HalfCheetah-v1')
  self.assertIsInstance(env, py_environment.Base)
  self.assertIsInstance(env, wrappers.TimeLimit)
def main(_):
  # The environment serves as the dataset in reinforcement learning.
  train_env = tf_py_environment.TFPyEnvironment(
      ParallelPyEnvironment(
          [lambda: suite_mujoco.load('HalfCheetah-v2')] * batch_size))
  eval_env = tf_py_environment.TFPyEnvironment(
      suite_mujoco.load('HalfCheetah-v2'))

  # Create the agent.
  actor_net = ActorDistributionRnnNetwork(
      train_env.observation_spec(),
      train_env.action_spec(),
      lstm_size=(100, 100))
  value_net = ValueRnnNetwork(train_env.observation_spec())
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
  tf_agent = ppo_agent.PPOAgent(
      train_env.time_step_spec(),
      train_env.action_spec(),
      optimizer=optimizer,
      actor_net=actor_net,
      value_net=value_net,
      normalize_observations=False,
      normalize_rewards=False,
      use_gae=True,
      num_epochs=25)
  tf_agent.initialize()

  # Replay buffer.
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      tf_agent.collect_data_spec,
      batch_size=train_env.batch_size,
      max_length=1000000)

  # Policy saver.
  saver = policy_saver.PolicySaver(tf_agent.policy)

  # Define the trajectory collector.
  train_episode_count = tf_metrics.NumberOfEpisodes()
  train_total_steps = tf_metrics.EnvironmentSteps()
  train_avg_reward = tf_metrics.AverageReturnMetric(
      batch_size=train_env.batch_size)
  train_avg_episode_len = tf_metrics.AverageEpisodeLengthMetric(
      batch_size=train_env.batch_size)
  train_driver = dynamic_episode_driver.DynamicEpisodeDriver(
      train_env,
      tf_agent.collect_policy,  # NOTE: use PPOPolicy to collect episodes.
      observers=[
          replay_buffer.add_batch, train_episode_count, train_total_steps,
          train_avg_reward, train_avg_episode_len
      ],  # Callbacks invoked when an episode is completely collected.
      num_episodes=30,  # How many episodes are collected per iteration.
  )

  # Training.
  eval_avg_reward = tf_metrics.AverageReturnMetric(buffer_size=30)
  eval_avg_episode_len = tf_metrics.AverageEpisodeLengthMetric(buffer_size=30)
  while train_total_steps.result() < 25000000:
    train_driver.run()
    trajectories = replay_buffer.gather_all()
    loss, _ = tf_agent.train(experience=trajectories)
    replay_buffer.clear()  # Clear collected episodes right after training.
    if tf_agent.train_step_counter.numpy() % 50 == 0:
      print('step = {0}: loss = {1}'.format(
          tf_agent.train_step_counter.numpy(), loss))
    if tf_agent.train_step_counter.numpy() % 500 == 0:
      # Save a checkpoint.
      saver.save('checkpoints/policy_%d' % tf_agent.train_step_counter.numpy())
      # Evaluate the updated policy.
      eval_avg_reward.reset()
      eval_avg_episode_len.reset()
      eval_driver = dynamic_episode_driver.DynamicEpisodeDriver(
          eval_env,
          tf_agent.policy,
          observers=[
              eval_avg_reward,
              eval_avg_episode_len,
          ],
          num_episodes=30,  # How many episodes are collected per evaluation.
      )
      eval_driver.run()
      print('step = {0}: Average Return = {1} Average Episode Length = {2}'
            .format(tf_agent.train_step_counter.numpy(),
                    eval_avg_reward.result(), eval_avg_episode_len.result()))

  # Play HalfCheetah the last 3 times and visualize.
  import cv2
  for _ in range(3):
    status = eval_env.reset()
    policy_state = tf_agent.policy.get_initial_state(eval_env.batch_size)
    while not status.is_last():
      # NOTE: use the greedy policy to test.
      action = tf_agent.policy.action(status, policy_state)
      status = eval_env.step(action.action)
      policy_state = action.state
      cv2.imshow('halfcheetah', eval_env.pyenv.envs[0].render())
      cv2.waitKey(25)
def train_eval(
    root_dir,
    strategy: tf.distribute.Strategy,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=10000,
    replay_buffer_save_interval=100000,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  logging.info('Training SAC on: %s', env_name)
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)

  _, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  actor_net = create_sequential_actor_network(
      actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec)

  critic_net = create_sequential_critic_network(
      obs_fc_layer_units=critic_obs_fc_layers,
      action_fc_layer_units=critic_action_fc_layers,
      joint_fc_layer_units=critic_joint_fc_layers)

  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.keras.optimizers.Adam(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.keras.optimizers.Adam(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.keras.optimizers.Adam(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))

  reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,
                                       learner.REPLAY_BUFFER_CHECKPOINT_DIR)
  reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
      path=reverb_checkpoint_dir)
  reverb_server = reverb.Server(
      [table], port=reverb_port, checkpointer=reverb_checkpointer)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  def experience_dataset_fn():
    return reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=2).prefetch(50)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.ReverbCheckpointTrigger(
          train_step,
          interval=replay_buffer_save_interval,
          reverb_client=reverb_replay.py_client),
      # TODO(b/165023684): Add SIGTERM handler to checkpoint before preemption.
      triggers.StepPerSecondLogTrigger(train_step, interval=1000),
  ]

  agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      observers=[rb_observer, env_step_metric])

  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)
  eval_actor = actor.Actor(
      eval_env,
      eval_greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    agent_learner.run(iterations=1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
def load_environments(universe,
                      env_name=None,
                      domain_name=None,
                      task_name=None,
                      render_size=128,
                      observation_render_size=64,
                      observations_whitelist=None,
                      action_repeat=1):
  """Loads train and eval environments.

  The universe can either be gym, in which case domain_name and task_name are
  ignored, or dm_control, in which case env_name is ignored.
  """
  if universe == 'gym':
    tf.compat.v1.logging.info(
        'Using environment {} from {} universe.'.format(env_name, universe))
    gym_env_wrappers = [
        functools.partial(
            gym_wrappers.RenderGymWrapper,
            render_kwargs={
                'height': render_size,
                'width': render_size,
                'device_id': 0
            }),
        functools.partial(
            gym_wrappers.PixelObservationsGymWrapper,
            observations_whitelist=observations_whitelist,
            render_kwargs={
                'height': observation_render_size,
                'width': observation_render_size,
                'device_id': 0
            })
    ]
    eval_gym_env_wrappers = [
        functools.partial(
            gym_wrappers.RenderGymWrapper,
            render_kwargs={
                'height': render_size,
                'width': render_size,
                'device_id': 1
            }),  # segfaults if the device is the same as the train env
        functools.partial(
            gym_wrappers.PixelObservationsGymWrapper,
            observations_whitelist=observations_whitelist,
            render_kwargs={
                'height': observation_render_size,
                'width': observation_render_size,
                'device_id': 1
            })
    ]  # segfaults if the device is the same as the train env
    py_env = suite_mujoco.load(env_name, gym_env_wrappers=gym_env_wrappers)
    eval_py_env = suite_mujoco.load(
        env_name, gym_env_wrappers=eval_gym_env_wrappers)
  elif universe == 'dm_control':
    tf.compat.v1.logging.info(
        'Using domain {} and task {} from {} universe.'.format(
            domain_name, task_name, universe))
    render_kwargs = {
        'height': render_size,
        'width': render_size,
        'camera_id': 0,
    }
    dm_env_wrappers = [
        wrappers.FlattenObservationsWrapper,  # combine position and velocity
        functools.partial(
            dm_control_wrappers.PixelObservationsDmControlWrapper,
            observations_whitelist=observations_whitelist,
            render_kwargs={
                'height': observation_render_size,
                'width': observation_render_size,
                'camera_id': 0
            })
    ]
    py_env = suite_dm_control.load(
        domain_name,
        task_name,
        render_kwargs=render_kwargs,
        env_wrappers=dm_env_wrappers)
    eval_py_env = suite_dm_control.load(
        domain_name,
        task_name,
        render_kwargs=render_kwargs,
        env_wrappers=dm_env_wrappers)
  else:
    raise ValueError('Invalid universe %s.' % universe)

  eval_py_env = video_wrapper.VideoWrapper(eval_py_env)

  if action_repeat > 1:
    py_env = wrappers.ActionRepeat(py_env, action_repeat)
    eval_py_env = wrappers.ActionRepeat(eval_py_env, action_repeat)

  return py_env, eval_py_env
def env_factory(env_name):
  py_env = suite_mujoco.load(env_name)
  tf_env = tf_py_environment.TFPyEnvironment(py_env)
  return tf_env
def main(_):
  tf.random.set_seed(FLAGS.seed)

  if FLAGS.models_dir is None:
    raise ValueError('You must set a value for models_dir.')

  env = suite_mujoco.load(FLAGS.env_name)
  env.seed(FLAGS.seed)
  env = tf_py_environment.TFPyEnvironment(env)

  sac = actor_lib.Actor(env.observation_spec().shape[0], env.action_spec())

  model_filename = os.path.join(FLAGS.models_dir, 'DM-' + FLAGS.env_name,
                                str(FLAGS.model_seed), '1000000')
  sac.load_weights(model_filename)

  if FLAGS.std is None:
    if 'Reacher' in FLAGS.env_name:
      std = 0.5
    elif 'Ant' in FLAGS.env_name:
      std = 0.4
    elif 'Walker' in FLAGS.env_name:
      std = 2.0
    else:
      std = 0.75
  else:
    std = FLAGS.std

  def get_action(state):
    _, action, log_prob = sac(state, std)
    return action, log_prob

  dataset = dict(
      model_filename=model_filename,
      behavior_std=std,
      trajectories=dict(
          states=[], actions=[], log_probs=[], next_states=[], rewards=[],
          masks=[]))

  for i in range(FLAGS.num_episodes):
    timestep = env.reset()
    trajectory = dict(
        states=[], actions=[], log_probs=[], next_states=[], rewards=[],
        masks=[])

    while not timestep.is_last():
      action, log_prob = get_action(timestep.observation)
      next_timestep = env.step(action)

      trajectory['states'].append(timestep.observation)
      trajectory['actions'].append(action)
      trajectory['log_probs'].append(log_prob)
      trajectory['next_states'].append(next_timestep.observation)
      trajectory['rewards'].append(next_timestep.reward)
      trajectory['masks'].append(next_timestep.discount)

      timestep = next_timestep

    for k, v in trajectory.items():
      dataset['trajectories'][k].append(tf.concat(v, 0).numpy())
    logging.info('%d trajectories', i + 1)

  data_save_dir = os.path.join(FLAGS.save_dir, FLAGS.env_name,
                               str(FLAGS.model_seed))
  if not tf.io.gfile.isdir(data_save_dir):
    tf.io.gfile.makedirs(data_save_dir)

  save_filename = os.path.join(data_save_dir, f'dualdice_{FLAGS.std}.pckl')
  with tf.io.gfile.GFile(save_filename, 'wb') as f:
    pickle.dump(dataset, f)
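# The pickled dataset written above can be read back with the same
# tf.io.gfile API. A small hypothetical reader; the path is illustrative and
# depends on the save_dir / env_name / model_seed / std used at collection
# time:
import pickle

import tensorflow as tf

load_filename = '/tmp/data/HalfCheetah-v2/0/dualdice_0.75.pckl'
with tf.io.gfile.GFile(load_filename, 'rb') as f:
  dataset = pickle.load(f)

print(dataset['behavior_std'], len(dataset['trajectories']['states']))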
def main(_): tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) hparam_str = make_hparam_string(seed=FLAGS.seed, env_name=FLAGS.env_name) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) summary_writer.set_as_default() if FLAGS.d4rl: d4rl_env = gym.make(FLAGS.env_name) gym_spec = gym.spec(FLAGS.env_name) if gym_spec.max_episode_steps in [0, None]: # Add TimeLimit wrapper. gym_env = time_limit.TimeLimit(d4rl_env, max_episode_steps=1000) else: gym_env = d4rl_env gym_env.seed(FLAGS.seed) env = tf_py_environment.TFPyEnvironment( gym_wrapper.GymWrapper(gym_env)) behavior_dataset = D4rlDataset( d4rl_env, normalize_states=FLAGS.normalize_states, normalize_rewards=FLAGS.normalize_rewards, noise_scale=FLAGS.noise_scale, bootstrap=FLAGS.bootstrap) else: env = suite_mujoco.load(FLAGS.env_name) env.seed(FLAGS.seed) env = tf_py_environment.TFPyEnvironment(env) data_file_name = os.path.join( FLAGS.data_dir, FLAGS.env_name, '0', f'dualdice_{FLAGS.behavior_policy_std}.pckl') behavior_dataset = Dataset(data_file_name, FLAGS.num_trajectories, normalize_states=FLAGS.normalize_states, normalize_rewards=FLAGS.normalize_rewards, noise_scale=FLAGS.noise_scale, bootstrap=FLAGS.bootstrap) tf_dataset = behavior_dataset.with_uniform_sampling( FLAGS.sample_batch_size) tf_dataset_iter = iter(tf_dataset) if FLAGS.d4rl: with tf.io.gfile.GFile(FLAGS.d4rl_policy_filename, 'rb') as f: policy_weights = pickle.load(f) actor = utils.D4rlActor(env, policy_weights, is_dapg='dapg' in FLAGS.d4rl_policy_filename) else: actor = Actor(env.observation_spec().shape[0], env.action_spec()) actor.load_weights(behavior_dataset.model_filename) policy_returns = utils.estimate_monte_carlo_returns( env, FLAGS.discount, actor, FLAGS.target_policy_std, FLAGS.num_mc_episodes) logging.info('Estimated Per-Step Average Returns=%f', policy_returns) if 'fqe' in FLAGS.algo or 'dr' in FLAGS.algo: model = QFitter(env.observation_spec().shape[0], env.action_spec().shape[0], FLAGS.lr, FLAGS.weight_decay, FLAGS.tau) elif 'mb' in FLAGS.algo: model = ModelBased(env.observation_spec().shape[0], env.action_spec().shape[0], learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay) elif 'dual_dice' in FLAGS.algo: model = DualDICE(env.observation_spec().shape[0], env.action_spec().shape[0], FLAGS.weight_decay) if 'iw' in FLAGS.algo or 'dr' in FLAGS.algo: behavior = BehaviorCloning(env.observation_spec().shape[0], env.action_spec(), FLAGS.lr, FLAGS.weight_decay) @tf.function def get_target_actions(states): return actor(tf.cast(behavior_dataset.unnormalize_states(states), env.observation_spec().dtype), std=FLAGS.target_policy_std)[1] @tf.function def get_target_logprobs(states, actions): log_probs = actor(tf.cast(behavior_dataset.unnormalize_states(states), env.observation_spec().dtype), actions=actions, std=FLAGS.target_policy_std)[2] if tf.rank(log_probs) > 1: log_probs = tf.reduce_sum(log_probs, -1) return log_probs min_reward = tf.reduce_min(behavior_dataset.rewards) max_reward = tf.reduce_max(behavior_dataset.rewards) min_state = tf.reduce_min(behavior_dataset.states, 0) max_state = tf.reduce_max(behavior_dataset.states, 0) @tf.function def update_step(): (states, actions, next_states, rewards, masks, weights, _) = next(tf_dataset_iter) initial_actions = get_target_actions(behavior_dataset.initial_states) next_actions = get_target_actions(next_states) if 'fqe' in FLAGS.algo or 'dr' in FLAGS.algo: model.update(states, actions, next_states, next_actions, rewards, masks, weights, FLAGS.discount, min_reward, 
max_reward) elif 'mb' in FLAGS.algo: model.update(states, actions, next_states, rewards, masks, weights) elif 'dual_dice' in FLAGS.algo: model.update(behavior_dataset.initial_states, initial_actions, behavior_dataset.initial_weights, states, actions, next_states, next_actions, masks, weights, FLAGS.discount) if 'iw' in FLAGS.algo or 'dr' in FLAGS.algo: behavior.update(states, actions, weights) gc.collect() for i in tqdm.tqdm(range(FLAGS.num_updates), desc='Running Training'): update_step() if i % FLAGS.eval_interval == 0: if 'fqe' in FLAGS.algo: pred_returns = model.estimate_returns( behavior_dataset.initial_states, behavior_dataset.initial_weights, get_target_actions) elif 'mb' in FLAGS.algo: pred_returns = model.estimate_returns( behavior_dataset.initial_states, behavior_dataset.initial_weights, get_target_actions, FLAGS.discount, min_reward, max_reward, min_state, max_state) elif FLAGS.algo in ['dual_dice']: pred_returns, pred_ratio = model.estimate_returns( iter(tf_dataset)) tf.summary.scalar('train/pred ratio', pred_ratio, step=i) elif 'iw' in FLAGS.algo or 'dr' in FLAGS.algo: discount = FLAGS.discount _, behavior_log_probs = behavior(behavior_dataset.states, behavior_dataset.actions) target_log_probs = get_target_logprobs( behavior_dataset.states, behavior_dataset.actions) offset = 0.0 rewards = behavior_dataset.rewards if 'dr' in FLAGS.algo: # Doubly-robust is effectively the same as importance-weighting but # transforming rewards at (s,a) to r(s,a) + gamma * V^pi(s') - # Q^pi(s,a) and adding an offset to each trajectory equal to V^pi(s0). offset = model.estimate_returns( behavior_dataset.initial_states, behavior_dataset.initial_weights, get_target_actions) q_values = (model(behavior_dataset.states, behavior_dataset.actions) / (1 - discount)) n_samples = 10 next_actions = [ get_target_actions(behavior_dataset.next_states) for _ in range(n_samples) ] next_q_values = sum([ model(behavior_dataset.next_states, next_action) / (1 - discount) for next_action in next_actions ]) / n_samples rewards = rewards + discount * next_q_values - q_values # Now we compute the self-normalized importance weights. # Self-normalization happens over trajectories per-step, so we # restructure the dataset as [num_trajectories, num_steps]. num_trajectories = len(behavior_dataset.initial_states) max_trajectory_length = np.max(behavior_dataset.steps) + 1 trajectory_weights = behavior_dataset.initial_weights trajectory_starts = np.where( np.equal(behavior_dataset.steps, 0))[0] batched_rewards = np.zeros( [num_trajectories, max_trajectory_length]) batched_masks = np.zeros( [num_trajectories, max_trajectory_length]) batched_log_probs = np.zeros( [num_trajectories, max_trajectory_length]) for traj_idx, traj_start in enumerate(trajectory_starts): traj_end = (trajectory_starts[traj_idx + 1] if traj_idx + 1 < len(trajectory_starts) else len(rewards)) traj_length = traj_end - traj_start batched_rewards[ traj_idx, :traj_length] = rewards[traj_start:traj_end] batched_masks[traj_idx, :traj_length] = 1. batched_log_probs[traj_idx, :traj_length] = ( -behavior_log_probs[traj_start:traj_end] + target_log_probs[traj_start:traj_end]) batched_weights = ( batched_masks * (discount**np.arange(max_trajectory_length))[None, :]) clipped_log_probs = np.clip(batched_log_probs, -6., 2.) 
cum_log_probs = batched_masks * np.cumsum(clipped_log_probs, axis=1) cum_log_probs_offset = np.max(cum_log_probs, axis=0) cum_probs = np.exp(cum_log_probs - cum_log_probs_offset[None, :]) avg_cum_probs = ( np.sum(cum_probs * trajectory_weights[:, None], axis=0) / (1e-10 + np.sum( batched_masks * trajectory_weights[:, None], axis=0))) norm_cum_probs = cum_probs / (1e-10 + avg_cum_probs[None, :]) weighted_rewards = batched_weights * batched_rewards * norm_cum_probs trajectory_values = np.sum(weighted_rewards, axis=1) avg_trajectory_value = ( (1 - discount) * np.sum(trajectory_values * trajectory_weights) / np.sum(trajectory_weights)) pred_returns = offset + avg_trajectory_value pred_returns = behavior_dataset.unnormalize_rewards(pred_returns) tf.summary.scalar('train/pred returns', pred_returns, step=i) logging.info('pred returns=%f', pred_returns) tf.summary.scalar('train/true minus pred returns', policy_returns - pred_returns, step=i) logging.info('true minus pred returns=%f', policy_returns - pred_returns)
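# A self-contained toy sketch of the per-step self-normalized importance
# weighting computed above, on synthetic data. The shapes and random values
# here are illustrative assumptions; the arithmetic mirrors the batched_*
# computation in the evaluation branch above.
import numpy as np

num_trajectories, max_len, discount = 4, 5, 0.99
rng = np.random.default_rng(0)
batched_rewards = rng.normal(size=(num_trajectories, max_len))
batched_masks = np.ones((num_trajectories, max_len))
# Per-step log importance ratios: log pi_target(a|s) - log pi_behavior(a|s).
batched_log_probs = rng.normal(scale=0.1, size=(num_trajectories, max_len))
trajectory_weights = np.ones(num_trajectories)

batched_weights = batched_masks * (discount ** np.arange(max_len))[None, :]
cum_log_probs = batched_masks * np.cumsum(batched_log_probs, axis=1)
cum_log_probs -= np.max(cum_log_probs, axis=0, keepdims=True)  # numerical stability
cum_probs = np.exp(cum_log_probs)
avg_cum_probs = (np.sum(cum_probs * trajectory_weights[:, None], axis=0) /
                 (1e-10 + np.sum(batched_masks * trajectory_weights[:, None], axis=0)))
norm_cum_probs = cum_probs / (1e-10 + avg_cum_probs[None, :])
weighted_rewards = batched_weights * batched_rewards * norm_cum_probs
trajectory_values = np.sum(weighted_rewards, axis=1)
estimate = ((1 - discount) * np.sum(trajectory_values * trajectory_weights) /
            np.sum(trajectory_weights))
print('self-normalized IS estimate:', estimate)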
def main(_): tf.enable_v2_behavior() tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) base_env = suite_mujoco.load(FLAGS.env_name) if hasattr(base_env, 'max_episode_steps'): max_episode_steps = base_env.max_episode_steps else: logging.info('Unknown max episode steps. Setting to 1000.') max_episode_steps = 1000 env = base_env.gym env = wrappers.check_and_normalize_box_actions(env) env.seed(FLAGS.seed) eval_env = suite_mujoco.load(FLAGS.env_name).gym eval_env = wrappers.check_and_normalize_box_actions(eval_env) eval_env.seed(FLAGS.seed + 1) hparam_str_dict = dict(algo=FLAGS.algo, seed=FLAGS.seed, env=FLAGS.env_name, dqn=FLAGS.use_dqn) hparam_str = ','.join([ '%s=%s' % (k, str(hparam_str_dict[k])) for k in sorted(hparam_str_dict.keys()) ]) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) rl_algo = algae.ALGAE(env.observation_space.shape[0], env.action_space.shape[0], [ float(env.action_space.low.min()), float(env.action_space.high.max()) ], FLAGS.log_interval, critic_lr=FLAGS.critic_lr, actor_lr=FLAGS.actor_lr, use_dqn=FLAGS.use_dqn, use_init_states=FLAGS.use_init_states, algae_alpha=FLAGS.algae_alpha, exponent=FLAGS.f_exponent) episode_return = 0 episode_timesteps = 0 done = True total_timesteps = 0 previous_time = time.time() replay_buffer = utils.ReplayBuffer(obs_shape=env.observation_space.shape, action_shape=env.action_space.shape, capacity=FLAGS.max_timesteps * 2, batch_size=FLAGS.sample_batch_size, device=device) log_dir = os.path.join(FLAGS.save_dir, 'logs') log_filename = os.path.join(log_dir, hparam_str) if not gfile.isdir(log_dir): gfile.mkdir(log_dir) eval_returns = [] with tqdm(total=FLAGS.max_timesteps, desc='') as pbar: # Final return is the average of the last 10 measurements.
final_returns = collections.deque(maxlen=10) final_timesteps = 0 while total_timesteps < FLAGS.max_timesteps: _update_pbar_msg(pbar, total_timesteps) if done: print('episodic return: {}'.format(episode_return)) if episode_timesteps > 0: current_time = time.time() train_measurements = [ ('train/returns', episode_return), ('train/FPS', episode_timesteps / (current_time - previous_time)), ] _write_measurements(summary_writer, train_measurements, total_timesteps) obs = env.reset() episode_return = 0 episode_timesteps = 0 previous_time = time.time() #init_replay_buffer.add_batch(np.array([obs.astype(np.float32)])) if total_timesteps < FLAGS.num_random_actions: action = env.action_space.sample() else: action = rl_algo.act(obs, sample=True) if total_timesteps >= FLAGS.start_training_timesteps: with summary_writer.as_default(): target_entropy = (-env.action_space.shape[0] if FLAGS.target_entropy is None else FLAGS.target_entropy) for _ in range(FLAGS.num_updates_per_env_step): rl_algo.update( replay_buffer, total_timesteps=total_timesteps, discount=FLAGS.discount, tau=FLAGS.tau, target_entropy=target_entropy, actor_update_freq=FLAGS.actor_update_freq) next_obs, reward, done, _ = env.step(action) if (max_episode_steps is not None and episode_timesteps + 1 == max_episode_steps): done = True done_bool = 0 if episode_timesteps + 1 == max_episode_steps else float( done) replay_buffer.add(obs, action, reward, next_obs, done_bool) episode_return += reward episode_timesteps += 1 total_timesteps += 1 pbar.update(1) obs = next_obs if total_timesteps % FLAGS.eval_interval == 0: logging.info('Performing policy eval.') average_returns, evaluation_timesteps = evaluate( eval_env, rl_algo, max_episode_steps=max_episode_steps) eval_returns.append(average_returns) fin = gfile.GFile(log_filename, 'w') np.save(fin, np.array(eval_returns)) fin.close() eval_measurements = [ ('eval/average returns', average_returns), ('eval/average episode length', evaluation_timesteps), ] # TODO(sandrafaust) Make this average of the last N. final_returns.append(average_returns) final_timesteps = evaluation_timesteps _write_measurements(summary_writer, eval_measurements, total_timesteps) logging.info('Eval: ave returns=%f, ave episode length=%f', average_returns, evaluation_timesteps) # Final measurement. final_measurements = [ ('final/average returns', sum(final_returns) / len(final_returns)), ('final/average episode length', final_timesteps), ] _write_measurements(summary_writer, final_measurements, total_timesteps)
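# The training loop above calls an `evaluate(eval_env, rl_algo,
# max_episode_steps=...)` helper that is not shown in this section. A minimal
# sketch with the same signature follows; the greedy `sample=False` flag and
# the `num_episodes` default are assumptions for illustration.
def evaluate(env, rl_algo, max_episode_steps=1000, num_episodes=10):
  """Returns (average return, average episode length) over `num_episodes`."""
  total_return = 0.0
  total_steps = 0
  for _ in range(num_episodes):
    obs = env.reset()
    done = False
    steps = 0
    while not done and steps < max_episode_steps:
      action = rl_algo.act(obs, sample=False)  # assumed greedy action selection
      obs, reward, done, _ = env.step(action)
      total_return += reward
      steps += 1
    total_steps += steps
  return total_return / num_episodes, total_steps / num_episodes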
def train_eval( root_dir, env_name='HalfCheetah-v2', # Training params initial_collect_steps=10000, num_iterations=3200000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Agent params batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, gamma=0.99, target_update_tau=0.005, target_update_period=1, reward_scale_factor=0.1, # Replay params reverb_port=None, replay_capacity=1000000, # Others # Defaults to not checkpointing saved policy. If you wish to enable this, # please note the caveat explained in README.md. policy_save_interval=-1, eval_interval=10000, eval_episodes=30, debug_summaries=False, summarize_grads_and_vars=False): """Trains and evaluates SAC.""" logging.info('Training SAC on: %s', env_name) collect_env = suite_mujoco.load(env_name) eval_env = suite_mujoco.load(env_name) observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = ( spec_utils.get_tensor_specs(collect_env)) train_step = train_utils.create_train_step() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_tensor_spec, action_tensor_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=tanh_normal_projection_network. TanhNormalProjectionNetwork) critic_net = critic_network.CriticNetwork( (observation_tensor_spec, action_tensor_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') agent = sac_agent.SacAgent( time_step_tensor_spec, action_tensor_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=tf.math.squared_difference, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=None, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=train_step) agent.initialize() table_name = 'uniform_table' table = reverb.Table(table_name, max_size=replay_capacity, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1)) reverb_server = reverb.Server([table], port=reverb_port) reverb_replay = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=2, table_name=table_name, local_server=reverb_server) rb_observer = reverb_utils.ReverbAddTrajectoryObserver( reverb_replay.py_client, table_name, sequence_length=2, stride_length=1) dataset = reverb_replay.as_dataset(sample_batch_size=batch_size, num_steps=2).prefetch(50) experience_dataset_fn = lambda: dataset saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}), triggers.StepPerSecondLogTrigger(train_step, interval=1000), ] agent_learner = learner.Learner(root_dir, train_step, agent, experience_dataset_fn, triggers=learning_triggers) random_policy = random_py_policy.RandomPyPolicy( 
collect_env.time_step_spec(), collect_env.action_spec()) initial_collect_actor = actor.Actor(collect_env, random_policy, train_step, steps_per_run=initial_collect_steps, observers=[rb_observer]) logging.info('Doing initial collect.') initial_collect_actor.run() tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor(collect_env, collect_policy, train_step, steps_per_run=1, metrics=actor.collect_metrics(10), summary_dir=os.path.join( root_dir, learner.TRAIN_DIR), observers=[rb_observer, env_step_metric]) tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy) eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy( tf_greedy_policy, use_tf_function=True) eval_actor = actor.Actor( eval_env, eval_greedy_policy, train_step, episodes_per_run=eval_episodes, metrics=actor.eval_metrics(eval_episodes), summary_dir=os.path.join(root_dir, 'eval'), ) if eval_interval: logging.info('Evaluating.') eval_actor.run_and_log() logging.info('Training.') for _ in range(num_iterations): collect_actor.run() agent_learner.run(iterations=1) if eval_interval and agent_learner.train_step_numpy % eval_interval == 0: logging.info('Evaluating.') eval_actor.run_and_log() rb_observer.close() reverb_server.stop()
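# A hypothetical command-line entry point for the Reverb-based train_eval
# above, sketched in the same style as the other mains in this file. The flag
# names (root_dir, num_iterations, reverb_port) are illustrative assumptions.
from absl import app
from absl import flags
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_string('root_dir', '/tmp/sac_halfcheetah', 'Root output directory.')
flags.DEFINE_integer('num_iterations', 3200000, 'Number of training iterations.')
flags.DEFINE_integer('reverb_port', None, 'Port for the local Reverb server.')


def main(_):
  tf.enable_v2_behavior()
  train_eval(FLAGS.root_dir,
             num_iterations=FLAGS.num_iterations,
             reverb_port=FLAGS.reverb_port)


if __name__ == '__main__':
  app.run(main)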
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function(initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] 
dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): experience, _ = next(iterator) train_loss = tf_agent.train(experience) time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss.loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
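# The actor network in the train_eval above uses a `normal_projection_net`
# helper that is not defined in this section. A sketch of the usual TF-Agents
# pattern for such a helper follows; the exact hyperparameters are assumptions.
from tf_agents.agents.sac import sac_agent
from tf_agents.networks import normal_projection_network


def normal_projection_net(action_spec, init_means_output_factor=0.1):
  # Projects the actor trunk's output to a squashed Normal action distribution.
  return normal_projection_network.NormalProjectionNetwork(
      action_spec,
      mean_transform=None,
      state_dependent_std=True,
      init_means_output_factor=init_means_output_factor,
      std_transform=sac_agent.std_clip_transform,
      scale_distribution=True)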