def _run(self,
         strategy,
         batch_size=64,
         tf_function=True,
         replay_buffer_max_length=1000,
         train_steps=110,
         log_steps=10):
  """Runs the DQN CartPole benchmark.

  Args:
    strategy: Strategy to use, None is a valid value.
    batch_size: Total batch size to use for the run.
    tf_function: If True, tf.function is used.
    replay_buffer_max_length: Max length of the replay buffer.
    train_steps: Number of steps to run.
    log_steps: How often to log step statistics, e.g. step time.
  """
  obs_spec = array_spec.BoundedArraySpec([4], np.float32, -4., 4.)
  action_spec = array_spec.BoundedArraySpec((), np.int64, 0, 1)
  py_env = random_py_environment.RandomPyEnvironment(
      obs_spec,
      action_spec,
      batch_size=1,
      reward_fn=lambda *_: np.random.randint(1, 10, 1))
  env = tf_py_environment.TFPyEnvironment(py_env)
  policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                           env.action_spec())

  with distribution_strategy_utils.strategy_scope_context(strategy):
    q_net = q_network.QNetwork(
        env.time_step_spec().observation,
        env.action_spec(),
        fc_layer_params=(100,))
    tf_agent = dqn_agent.DqnAgent(
        env.time_step_spec(),
        env.action_spec(),
        q_network=q_net,
        optimizer=tf.keras.optimizers.Adam(),
        td_errors_loss_fn=common.element_wise_squared_loss)
    tf_agent.initialize()

  # Model.summary() already prints the summary; wrapping it in print() would
  # only print an extra "None".
  q_net.summary()

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=tf_agent.collect_data_spec,
      batch_size=1,
      max_length=replay_buffer_max_length)

  # Fill the replay buffer with experience collected by the random policy.
  driver = dynamic_step_driver.DynamicStepDriver(env, policy,
                                                 [replay_buffer.add_batch])
  if tf_function:
    driver.run = common.function(driver.run)
  for _ in range(replay_buffer_max_length):
    driver.run()

  check_values = ['QNetwork/EncodingNetwork/dense/bias:0']
  initial_values = utils.get_initial_values(tf_agent, check_values)

  with distribution_strategy_utils.strategy_scope_context(strategy):
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        sample_batch_size=batch_size,
        num_steps=2)
    if strategy:
      iterator = iter(strategy.experimental_distribute_dataset(dataset))
    else:
      iterator = iter(dataset)

    def train_step(inputs):
      experience, _ = inputs
      return tf_agent.train(experience)

    if tf_function:
      train_step = common.function(train_step)

  self.run_and_report(
      train_step,
      strategy,
      batch_size,
      train_steps=train_steps,
      log_steps=log_steps,
      iterator=iterator)
  utils.check_values_changed(tf_agent, initial_values, check_values)
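# A hypothetical harness around the method above, showing how it might be
# invoked with and without a distribution strategy. The class name and the
# benchmark method names are assumptions; `_run` is the method defined above
# and is assumed to live on this class, with `run_and_report` and the `utils`
# helpers supplied by the real benchmark code.
class DqnCartPoleBenchmark(tf.test.Benchmark):

  def benchmark_mirrored_strategy(self):
    self._run(strategy=tf.distribute.MirroredStrategy(), batch_size=64)

  def benchmark_no_strategy(self):
    self._run(strategy=None, batch_size=64)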
def train_eval( root_dir, env_name='CartPole-v0', num_iterations=100000, fc_layer_params=(100, ), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, epsilon_greedy=0.1, replay_buffer_capacity=100000, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, learning_rate=1e-3, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=1000, # Params for summaries and logging log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DQN.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.contrib.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] with tf.contrib.summary.record_summaries_every_n_global_steps( summary_interval): tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name)) eval_tf_env = tf_py_environment.TFPyEnvironment( suite_gym.load(env_name)) trajectory_spec = trajectory.from_transition( time_step=tf_env.time_step_spec(), action_step=policy_step.PolicyStep(action=tf_env.action_spec()), next_time_step=tf_env.time_step_spec()) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=trajectory_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) q_net = q_network.QNetwork(tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=fc_layer_params) tf_agent = dqn_agent.DqnAgent( tf_env.time_step_spec(), tf_env.action_spec(), q_network=q_net, # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839 epsilon_greedy=epsilon_greedy, target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate), td_errors_loss_fn=dqn_agent.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) tf_agent.initialize() train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] eval_policy = tf_agent.policy() collect_policy = tf_agent.collect_policy() collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=collect_steps_per_iteration) global_step = tf.compat.v1.train.get_or_create_global_step() initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) # Collect initial replay data. 
logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=initial_collect_steps).run() metrics = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(metrics, global_step.numpy()) time_step = None policy_state = () timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): experience, _ = next(iterator) train_loss = tf_agent.train(experience, train_step_counter=global_step) time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.contrib.summary.scalar(name='global_steps/sec', tensor=steps_per_sec) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: metrics = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(metrics, global_step.numpy()) return train_loss
agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()

# Policies
eval_policy = agent.policy
collect_policy = agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
random_policy.action(time_step)

# Metrics and evaluation
#@test {"skip": true}
def compute_avg_return(environment, policy, num_episodes=10):
  total_return = 0.0
  for _ in range(num_episodes):
    time_step = environment.reset()
    episode_return = 0.0
    # Roll out one full episode and accumulate its reward.
    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return
  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]
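# Hypothetical usage of the helper above as a baseline before training;
# `num_eval_episodes` is an assumed variable, and the random policy gives a
# reference return for an untrained agent.
num_eval_episodes = 10
baseline_return = compute_avg_return(example_environment, random_policy,
                                     num_eval_episodes)
print('Average return of the random policy: {0:.2f}'.format(baseline_return))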
def train_eval( root_dir, env_name='HalfCheetah-v2', eval_env_name=None, env_load_fn=suite_mujoco.load, num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name)) eval_env_name = eval_env_name or env_name eval_tf_env = tf_py_environment.TFPyEnvironment( env_load_fn(eval_env_name)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] 
dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): train_loss = train_step() time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss.loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
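# A minimal, hypothetical invocation of the SAC train_eval above. In the real
# script the arguments are normally supplied through absl flags and/or gin
# bindings, so the root_dir and iteration count here are only illustrative.
train_eval(root_dir='/tmp/sac/HalfCheetah-v2', num_iterations=100000)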
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    environment.time_step_spec(),
    environment.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()

# Policy
random_policy = random_tf_policy.RandomTFPolicy(environment.time_step_spec(),
                                                environment.action_spec())
time_step = environment.reset()
print(random_policy.action(time_step))

# Replay buffer
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=batch_size,
    max_length=replay_buffer_max_length)

# print(compute_avg_return(environment, random_policy))
# time_step = environment.reset()
# print(time_step)
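# A minimal sketch of how the buffer constructed above is typically filled and
# consumed, mirroring the other snippets in this file. It assumes the buffer's
# batch_size matches the environment batch size; the sample batch size of 64
# and the 100 warm-up steps are illustrative choices only.
collect_driver = dynamic_step_driver.DynamicStepDriver(
    environment,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=1)
for _ in range(100):
  collect_driver.run()

dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=64, num_steps=2).prefetch(3)
iterator = iter(dataset)
experience, _ = next(iterator)
loss_info = agent.train(experience)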
base_dir = os.path.abspath('experiments/env_logs/playpen_reduced/symmetric/')
env_log_dir = os.path.join(base_dir, 'rc_o/traj1/')
# env = ResetFreeWrapper(env, reset_goal_frequency=500,
#                        full_reset_frequency=max_episode_steps)
env = GoalTerminalResetWrapper(
    env,
    episodes_before_full_reset=max_episode_steps // 500,
    goal_reset_frequency=500)
# env = Monitor(env, env_log_dir, video_callable=lambda x: x % 1 == 0, force=True)

env = wrap_env(env)
tf_env = tf_py_environment.TFPyEnvironment(env)
tf_env.render = env.render

time_step_spec = tf_env.time_step_spec()
action_spec = tf_env.action_spec()
policy = random_tf_policy.RandomTFPolicy(action_spec=action_spec,
                                         time_step_spec=time_step_spec)

collect_data_spec = trajectory.Trajectory(
    step_type=time_step_spec.step_type,
    observation=time_step_spec.observation,
    action=action_spec,
    policy_info=policy.info_spec,
    next_step_type=time_step_spec.step_type,
    reward=time_step_spec.reward,
    discount=time_step_spec.discount)

offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=collect_data_spec, batch_size=1, max_length=int(1e5))
rb_checkpointer = common.Checkpointer(
    ckpt_dir=os.path.join(env_log_dir, 'replay_buffer'),
    max_to_keep=10_000,
    replay_buffer=offline_data)
rb_checkpointer.initialize_or_restore()
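# A small sketch (sample sizes are assumptions) of reading the restored
# offline trajectories back out of the buffer to sanity-check the checkpoint.
offline_dataset = offline_data.as_dataset(sample_batch_size=32, num_steps=2)
for experience, _ in offline_dataset.take(1):
  print(experience.observation.shape)  # [sample_batch, num_steps, obs_dim...]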
def train_eval( root_dir, env_name='MaskedCartPole-v0', num_iterations=100000, input_fc_layer_params=(50, ), lstm_size=(20, ), output_fc_layer_params=(20, ), train_sequence_length=10, # Params for collect initial_collect_steps=50, collect_episodes_per_iteration=1, epsilon_greedy=0.1, replay_buffer_capacity=100000, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=10, batch_size=128, learning_rate=1e-3, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=1000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, log_interval=100, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DQN.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): eval_py_env = suite_gym.load(env_name) tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name)) q_net = q_rnn_network.QRnnNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), input_fc_layer_params=input_fc_layer_params, lstm_size=lstm_size, output_fc_layer_params=output_fc_layer_params) # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839 tf_agent = dqn_agent.DqnAgent( tf_env.time_step_spec(), tf_env.action_spec(), q_network=q_net, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate), epsilon_greedy=epsilon_greedy, target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=dqn_agent.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, initial_collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=initial_collect_steps).run() collect_policy = tf_agent.collect_policy collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=collect_episodes_per_iteration).run() # Need extra step to generate transitions of train_sequence_length. 
# Dataset generates trajectories with shape [BxTx...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=train_sequence_length + 1).prefetch(3) iterator = tf.compat.v1.data.make_initializable_iterator(dataset) experience, _ = iterator.get_next() loss_info = common.function(tf_agent.train)(experience=experience) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries() init_agent_op = tf_agent.initialize() with tf.compat.v1.Session() as sess: sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) logging.info('Collecting initial experience.') sess.run(initial_collect_op) # Compute evaluation metrics. global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable(loss_info) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) for _ in range(num_iterations): # Train/collect/eval. start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): loss_info_value = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, log=True, callback=eval_metrics_callback, )
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=3000000, actor_fc_layers=(), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, dual_learning_rate=3e-4, td_errors_loss_fn=tf.math.squared_difference, gamma=0.99, reward_scale_factor=0.1, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=50000, policy_checkpoint_interval=50000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None, latent_dim=10, log_prob_reward_scale=0.0, predictor_updates_encoder=False, predict_prior=True, use_recurrent_actor=False, rnn_sequence_length=20, clip_max_stddev=10.0, clip_min_stddev=0.1, clip_mean=30.0, predictor_num_layers=2, use_identity_encoder=False, identity_encoder_single_stddev=False, kl_constraint=1.0, eval_dropout=(), use_residual_predictor=True, gym_kwargs=None, predict_prior_std=True, random_seed=0, ): """A simple train and eval for SAC.""" np.random.seed(random_seed) tf.random.set_seed(random_seed) if use_recurrent_actor: batch_size = batch_size // rnn_sequence_length root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): _build_env = functools.partial( suite_gym.load, environment_name=env_name, # pylint: disable=invalid-name gym_env_wrappers=(), gym_kwargs=gym_kwargs) tf_env = tf_py_environment.TFPyEnvironment(_build_env()) eval_vec = [] # (name, env, metrics) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes) ] eval_tf_env = tf_py_environment.TFPyEnvironment(_build_env()) name = '' eval_vec.append((name, eval_tf_env, eval_metrics)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() if latent_dim == 'obs': latent_dim = observation_spec.shape[0] def _activation(t): t1, t2 = tf.split(t, 2, axis=1) low = -np.inf if clip_mean is None else -clip_mean high = np.inf if clip_mean is None else clip_mean t1 = rpc_utils.squash_to_range(t1, low, high) if clip_min_stddev is None: low = -np.inf else: low = tf.math.log(tf.exp(clip_min_stddev) - 1.0) if clip_max_stddev is None: high = np.inf else: high = tf.math.log(tf.exp(clip_max_stddev) - 1.0) t2 = rpc_utils.squash_to_range(t2, low, high) return tf.concat([t1, t2], axis=1) if use_identity_encoder: assert latent_dim == observation_spec.shape[0] obs_input = tf.keras.layers.Input(observation_spec.shape) zeros = 0.0 * obs_input[:, :1] stddev_dim = 1 if identity_encoder_single_stddev else latent_dim 
pre_stddev = tf.keras.layers.Dense(stddev_dim, activation=None)(zeros) ones = zeros + tf.ones((1, latent_dim)) pre_stddev = pre_stddev * ones # Multiply to broadcast to latent_dim. pre_mean_stddev = tf.concat([obs_input, pre_stddev], axis=1) output = tfp.layers.IndependentNormal(latent_dim)(pre_mean_stddev) encoder_net = tf.keras.Model(inputs=obs_input, outputs=output) else: encoder_net = tf.keras.Sequential([ tf.keras.layers.Dense(256, activation='relu'), tf.keras.layers.Dense(256, activation='relu'), tf.keras.layers.Dense( tfp.layers.IndependentNormal.params_size(latent_dim), activation=_activation, kernel_initializer='glorot_uniform'), tfp.layers.IndependentNormal(latent_dim), ]) # Build the predictor net obs_input = tf.keras.layers.Input(observation_spec.shape) action_input = tf.keras.layers.Input(action_spec.shape) class ConstantIndependentNormal(tfp.layers.IndependentNormal): """A keras layer that always returns N(0, 1) distribution.""" def call(self, inputs): loc_scale = tf.concat([ tf.zeros((latent_dim, )), tf.fill((latent_dim, ), tf.math.log(tf.exp(1.0) - 1)) ], axis=0) # Multiple by [B x 1] tensor to broadcast batch dimension. loc_scale = loc_scale * tf.ones_like(inputs[:, :1]) return super(ConstantIndependentNormal, self).call(loc_scale) if predict_prior: z = encoder_net(obs_input) if not predictor_updates_encoder: z = tf.stop_gradient(z) za = tf.concat([z, action_input], axis=1) if use_residual_predictor: za_input = tf.keras.layers.Input(za.shape[1]) loc_scale = tf.keras.Sequential( predictor_num_layers * [tf.keras.layers.Dense(256, activation='relu')] + [ # pylint: disable=line-too-long tf.keras.layers.Dense(tfp.layers.IndependentNormal. params_size(latent_dim), activation=_activation, kernel_initializer='zeros'), ])(za_input) if predict_prior_std: combined_loc_scale = tf.concat([ loc_scale[:, :latent_dim] + za_input[:, :latent_dim], loc_scale[:, latent_dim:] ], axis=1) else: # Note that softplus(log(e - 1)) = 1. combined_loc_scale = tf.concat([ loc_scale[:, :latent_dim] + za_input[:, :latent_dim], tf.math.log(np.e - 1) * tf.ones_like(loc_scale[:, latent_dim:]) ], axis=1) dist = tfp.layers.IndependentNormal(latent_dim)( combined_loc_scale) output = tf.keras.Model(inputs=za_input, outputs=dist)(za) else: assert predict_prior_std output = tf.keras.Sequential( predictor_num_layers * [tf.keras.layers.Dense(256, activation='relu')] + # pylint: disable=line-too-long [ tf.keras.layers.Dense(tfp.layers.IndependentNormal. params_size(latent_dim), activation=_activation, kernel_initializer='zeros'), tfp.layers.IndependentNormal(latent_dim), ])(za) else: # scale is chosen by inverting the softplus function to equal 1. if len(obs_input.shape) > 2: input_reshaped = tf.reshape( obs_input, [-1, tf.math.reduce_prod(obs_input.shape[1:])]) # Multiply by [B x 1] tensor to broadcast batch dimension. za = tf.zeros(latent_dim + action_spec.shape[0], ) * tf.ones_like(input_reshaped[:, :1]) # pylint: disable=line-too-long else: # Multiple by [B x 1] tensor to broadcast batch dimension. 
za = tf.zeros(latent_dim + action_spec.shape[0], ) * tf.ones_like(obs_input[:, :1]) # pylint: disable=line-too-long output = tf.keras.Sequential([ ConstantIndependentNormal(latent_dim), ])(za) predictor_net = tf.keras.Model(inputs=(obs_input, action_input), outputs=output) if use_recurrent_actor: ActorClass = rpc_utils.RecurrentActorNet # pylint: disable=invalid-name else: ActorClass = rpc_utils.ActorNet # pylint: disable=invalid-name actor_net = ActorClass(input_tensor_spec=observation_spec, output_tensor_spec=action_spec, encoder=encoder_net, predictor=predictor_net, fc_layers=actor_fc_layers) critic_net = rpc_utils.CriticNet( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') critic_net_2 = None target_critic_net_1 = None target_critic_net_2 = None tf_agent = rpc_agent.RpAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, critic_network_2=critic_net_2, target_critic_network=target_critic_net_1, target_critic_network_2=target_critic_net_2, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) dual_optimizer = tf.compat.v1.train.AdamOptimizer( learning_rate=dual_learning_rate) tf_agent.initialize() # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] kl_metric = rpc_utils.AverageKLMetric(encoder=encoder_net, predictor=predictor_net, batch_size=tf_env.batch_size) eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy checkpoint_items = { 'ckpt_dir': train_dir, 'agent': tf_agent, 'global_step': global_step, 'metrics': metric_utils.MetricsGroup(train_metrics, 'train_metrics'), 'dual_optimizer': dual_optimizer, } train_checkpointer = common.Checkpointer(**checkpoint_items) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps, transition_observers=[kl_metric]) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration, transition_observers=[kl_metric]) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) if replay_buffer.num_frames() == 0: # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d steps ' 'with a random policy.', initial_collect_steps) initial_collect_driver.run() for name, eval_tf_env, eval_metrics in eval_vec: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics-%s' % name, ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics, prefix=name) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 train_time_acc = 0 env_time_acc = 0 if use_recurrent_actor: # default from sac/train_eval_rnn.py num_steps = rnn_sequence_length + 1 def _filter_invalid_transition(trajectories, unused_arg1): return tf.reduce_all(~trajectories.is_boundary()[:-1]) tf_agent._as_transition = data_converter.AsTransition( # pylint: disable=protected-access tf_agent.data_context, squeeze_time_dim=False) else: num_steps = 2 def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=num_steps).unbatch().filter(_filter_invalid_transition) dataset = dataset.batch(batch_size).prefetch(5) # Dataset generates trajectories with shape [Bx2x...] 
iterator = iter(dataset) @tf.function def train_step(): experience, _ = next(iterator) prior = predictor_net( (experience.observation[:, 0], experience.action[:, 0]), training=False) z_next = encoder_net(experience.observation[:, 1], training=False) # predictor_kl is a vector of size batch_size. predictor_kl = tfp.distributions.kl_divergence(z_next, prior) with tf.GradientTape() as tape: tape.watch(actor_net._log_kl_coefficient) # pylint: disable=protected-access dual_loss = -1.0 * actor_net._log_kl_coefficient * ( # pylint: disable=protected-access tf.stop_gradient(tf.reduce_mean(predictor_kl)) - kl_constraint) dual_grads = tape.gradient(dual_loss, [actor_net._log_kl_coefficient]) # pylint: disable=protected-access grads_and_vars = list( zip(dual_grads, [actor_net._log_kl_coefficient])) # pylint: disable=protected-access dual_optimizer.apply_gradients(grads_and_vars) # Clip the dual variable so exp(log_kl_coef) <= 1e6. log_kl_coef = tf.clip_by_value( actor_net._log_kl_coefficient, # pylint: disable=protected-access -1.0 * np.log(1e6), np.log(1e6)) actor_net._log_kl_coefficient.assign(log_kl_coef) # pylint: disable=protected-access with tf.name_scope('dual_loss'): tf.compat.v2.summary.scalar(name='dual_loss', data=tf.reduce_mean(dual_loss), step=global_step) tf.compat.v2.summary.scalar( name='log_kl_coefficient', data=actor_net._log_kl_coefficient, # pylint: disable=protected-access step=global_step) z_entropy = z_next.entropy() log_prob = prior.log_prob(z_next.sample()) with tf.name_scope('rp-metrics'): common.generate_tensor_summaries('predictor_kl', predictor_kl, global_step) common.generate_tensor_summaries('z_entropy', z_entropy, global_step) common.generate_tensor_summaries('log_prob', log_prob, global_step) common.generate_tensor_summaries('z_mean', z_next.mean(), global_step) common.generate_tensor_summaries('z_stddev', z_next.stddev(), global_step) common.generate_tensor_summaries('prior_mean', prior.mean(), global_step) common.generate_tensor_summaries('prior_stddev', prior.stddev(), global_step) if log_prob_reward_scale == 'auto': coef = tf.stop_gradient(tf.exp(actor_net._log_kl_coefficient)) # pylint: disable=protected-access else: coef = log_prob_reward_scale tf.debugging.check_numerics(tf.reduce_mean(predictor_kl), 'predictor_kl is inf or nan.') tf.debugging.check_numerics(coef, 'coef is inf or nan.') new_reward = experience.reward - coef * predictor_kl[:, None] experience = experience._replace(reward=new_reward) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) # Save the hyperparameters operative_filename = os.path.join(root_dir, 'operative.gin') with tf.compat.v1.gfile.Open(operative_filename, 'w') as f: f.write(gin.operative_config_str()) print(gin.operative_config_str()) global_step_val = global_step.numpy() while global_step_val < num_iterations: start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) env_time_acc += time.time() - start_time train_start_time = time.time() for _ in range(train_steps_per_iteration): train_loss = train_step() train_time_acc += time.time() - train_start_time time_acc += time.time() - start_time global_step_val = global_step.numpy() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, train_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, 
step=global_step) train_steps_per_sec = (global_step_val - timed_at_step) / train_time_acc logging.info('Train: %.3f steps/sec', train_steps_per_sec) tf.compat.v2.summary.scalar(name='train_steps_per_sec', data=train_steps_per_sec, step=global_step) env_steps_per_sec = (global_step_val - timed_at_step) / env_time_acc logging.info('Env: %.3f steps/sec', env_steps_per_sec) tf.compat.v2.summary.scalar(name='env_steps_per_sec', data=env_steps_per_sec, step=global_step) timed_at_step = global_step_val time_acc = 0 train_time_acc = 0 env_time_acc = 0 for train_metric in train_metrics + [kl_metric]: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if global_step_val % eval_interval == 0: start_time = time.time() for name, eval_tf_env, eval_metrics in eval_vec: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics-%s' % name, ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step_val) metric_utils.log_metrics(eval_metrics, prefix=name) logging.info('Evaluation: %d min', (time.time() - start_time) / 60) for prob_dropout in eval_dropout: rpc_utils.eval_dropout_fn(eval_tf_env, actor_net, global_step, prob_dropout=prob_dropout) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
def train_eval( root_dir, env_name='sawyer_reach', num_iterations=3000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, gamma=0.99, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=200000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, random_seed=0, max_future_steps=50, actor_std=None, log_subset=None, ): """A simple train and eval for SAC.""" np.random.seed(random_seed) tf.random.set_seed(random_seed) root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env, eval_tf_env, obs_dim = c_learning_envs.load(env_name) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() if actor_std is None: proj_net = tanh_normal_projection_network.TanhNormalProjectionNetwork else: proj_net = functools.partial( tanh_normal_projection_network.TanhNormalProjectionNetwork, std_transform=lambda t: actor_std * tf.ones_like(t)) actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=proj_net) critic_net = c_learning_utils.ClassifierCriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') tf_agent = c_learning_agent.CLearningAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=bce_loss, gamma=gamma, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes), c_learning_utils.FinalDistance(buffer_size=num_eval_episodes, obs_dim=obs_dim), c_learning_utils.MinimumDistance(buffer_size=num_eval_episodes, obs_dim=obs_dim), c_learning_utils.DeltaDistance(buffer_size=num_eval_episodes, obs_dim=obs_dim), ] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), 
c_learning_utils.InitialDistance(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, obs_dim=obs_dim), c_learning_utils.FinalDistance(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, obs_dim=obs_dim), c_learning_utils.MinimumDistance(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, obs_dim=obs_dim), c_learning_utils.DeltaDistance(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, obs_dim=obs_dim), ] if log_subset is not None: start_index, end_index = log_subset for name, metrics in [('train', train_metrics), ('eval', eval_metrics)]: metrics.extend([ c_learning_utils.InitialDistance( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size if name == 'train' else 10, obs_dim=obs_dim, start_index=start_index, end_index=end_index, name='SubsetInitialDistance'), c_learning_utils.FinalDistance( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size if name == 'train' else 10, obs_dim=obs_dim, start_index=start_index, end_index=end_index, name='SubsetFinalDistance'), c_learning_utils.MinimumDistance( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size if name == 'train' else 10, obs_dim=obs_dim, start_index=start_index, end_index=end_index, name='SubsetMinimumDistance'), c_learning_utils.DeltaDistance( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size if name == 'train' else 10, obs_dim=obs_dim, start_index=start_index, end_index=end_index, name='SubsetDeltaDistance'), ]) eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), max_to_keep=None) train_checkpointer.initialize_or_restore() replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Save the hyperparameters operative_filename = os.path.join(root_dir, 'operative.gin') with tf.compat.v1.gfile.Open(operative_filename, 'w') as f: f.write(gin.operative_config_str()) logging.info(gin.operative_config_str()) if replay_buffer.num_frames() == 0: # Collect initial replay data. 
logging.info( 'Initializing replay buffer by collecting experience for %d steps ' 'with a random policy.', initial_collect_steps) initial_collect_driver.run() metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset(sample_batch_size=batch_size, num_steps=max_future_steps) dataset = dataset.unbatch().filter(_filter_invalid_transition) dataset = dataset.batch(batch_size, drop_remainder=True) goal_fn = functools.partial(c_learning_utils.goal_fn, batch_size=batch_size, obs_dim=obs_dim, gamma=gamma) dataset = dataset.map(goal_fn) dataset = dataset.prefetch(5) iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) global_step_val = global_step.numpy() while global_step_val < num_iterations: start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): train_loss = train_step() time_acc += time.time() - start_time global_step_val = global_step.numpy() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, train_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step_val time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if global_step_val % eval_interval == 0: metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) metric_utils.log_metrics(eval_metrics) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) return train_loss
def main(_): if FLAGS.eager: tf.config.experimental_run_functions_eagerly(FLAGS.eager) tf.random.set_seed(FLAGS.seed) np.random.seed(FLAGS.seed) random.seed(FLAGS.seed) action_repeat = FLAGS.action_repeat _, _, domain_name, _ = FLAGS.env_name.split('-') if domain_name in ['cartpole']: FLAGS.set_default('action_repeat', 8) elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']: FLAGS.set_default('action_repeat', 4) elif domain_name in ['finger', 'walker']: FLAGS.set_default('action_repeat', 2) FLAGS.set_default('max_timesteps', FLAGS.max_timesteps // FLAGS.action_repeat) env = utils.load_env( FLAGS.env_name, FLAGS.seed, action_repeat, FLAGS.frame_stack) eval_env = utils.load_env( FLAGS.env_name, FLAGS.seed, action_repeat, FLAGS.frame_stack) is_image_obs = (isinstance(env.observation_spec(), TensorSpec) and len(env.observation_spec().shape) == 3) spec = ( env.observation_spec(), env.action_spec(), env.reward_spec(), env.reward_spec(), # discount spec env.observation_spec() # next observation spec ) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( spec, batch_size=1, max_length=FLAGS.max_length_replay_buffer) @tf.function def add_to_replay(state, action, reward, discount, next_states): replay_buffer.add_batch((state, action, reward, discount, next_states)) hparam_str = utils.make_hparam_string( FLAGS.xm_parameters, seed=FLAGS.seed, env_name=FLAGS.env_name, algo_name=FLAGS.algo_name) summary_writer = tf.summary.create_file_writer( os.path.join(FLAGS.save_dir, 'tb', hparam_str)) results_writer = tf.summary.create_file_writer( os.path.join(FLAGS.save_dir, 'results', hparam_str)) if 'ddpg' in FLAGS.algo_name: model = ddpg.DDPG( env.observation_spec(), env.action_spec(), cross_norm='crossnorm' in FLAGS.algo_name) elif 'crr' in FLAGS.algo_name: model = awr.AWR( env.observation_spec(), env.action_spec(), f='bin_max') elif 'awr' in FLAGS.algo_name: model = awr.AWR( env.observation_spec(), env.action_spec(), f='exp_mean') elif 'sac_v1' in FLAGS.algo_name: model = sac_v1.SAC( env.observation_spec(), env.action_spec(), target_entropy=-env.action_spec().shape[0]) elif 'asac' in FLAGS.algo_name: model = asac.ASAC( env.observation_spec(), env.action_spec(), target_entropy=-env.action_spec().shape[0]) elif 'sac' in FLAGS.algo_name: model = sac.SAC( env.observation_spec(), env.action_spec(), target_entropy=-env.action_spec().shape[0], cross_norm='crossnorm' in FLAGS.algo_name, pcl_actor_update='pc' in FLAGS.algo_name) elif 'pcl' in FLAGS.algo_name: model = pcl.PCL( env.observation_spec(), env.action_spec(), target_entropy=-env.action_spec().shape[0]) initial_collect_policy = random_tf_policy.RandomTFPolicy( env.time_step_spec(), env.action_spec()) dataset = replay_buffer.as_dataset( num_parallel_calls=tf.data.AUTOTUNE, sample_batch_size=FLAGS.sample_batch_size) if is_image_obs: # Augment images as in DRQ. 
dataset = dataset.map(image_aug, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False).prefetch(3) else: dataset = dataset.prefetch(3) def repack(*data): return data[0] dataset = dataset.map(repack) replay_buffer_iter = iter(dataset) previous_time = time.time() timestep = env.reset() episode_return = 0 episode_timesteps = 0 step_mult = 1 if action_repeat < 1 else action_repeat for i in tqdm.tqdm(range(FLAGS.max_timesteps)): if i % FLAGS.deployment_batch_size == 0: for _ in range(FLAGS.deployment_batch_size): if timestep.is_last(): if episode_timesteps > 0: current_time = time.time() with summary_writer.as_default(): tf.summary.scalar( 'train/returns', episode_return, step=(i + 1) * step_mult) tf.summary.scalar( 'train/FPS', episode_timesteps / (current_time - previous_time), step=(i + 1) * step_mult) timestep = env.reset() episode_return = 0 episode_timesteps = 0 previous_time = time.time() if (replay_buffer.num_frames() < FLAGS.num_random_actions or replay_buffer.num_frames() < FLAGS.deployment_batch_size): # Use policy only after the first deployment. policy_step = initial_collect_policy.action(timestep) action = policy_step.action else: action = model.actor(timestep.observation, sample=True) next_timestep = env.step(action) add_to_replay(timestep.observation, action, next_timestep.reward, next_timestep.discount, next_timestep.observation) episode_return += next_timestep.reward[0] episode_timesteps += 1 timestep = next_timestep if i + 1 >= FLAGS.start_training_timesteps: with summary_writer.as_default(): info_dict = model.update_step(replay_buffer_iter) if (i + 1) % FLAGS.log_interval == 0: with summary_writer.as_default(): for k, v in info_dict.items(): tf.summary.scalar(f'training/{k}', v, step=(i + 1) * step_mult) if (i + 1) % FLAGS.eval_interval == 0: logging.info('Performing policy eval.') average_returns, evaluation_timesteps = evaluation.evaluate( eval_env, model) with results_writer.as_default(): tf.summary.scalar( 'evaluation/returns', average_returns, step=(i + 1) * step_mult) tf.summary.scalar( 'evaluation/length', evaluation_timesteps, step=(i+1) * step_mult) logging.info('Eval at %d: ave returns=%f, ave episode length=%f', (i + 1) * step_mult, average_returns, evaluation_timesteps) if (i + 1) % FLAGS.eval_interval == 0: model.save_weights( os.path.join(FLAGS.save_dir, 'results', FLAGS.env_name + '__' + str(i + 1)))
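# Hypothetical entry point, assuming the standard absl pattern these scripts
# use (requires `from absl import app`); not shown in the original snippet.
if __name__ == '__main__':
  app.run(main)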
def train_eval(root_dir,
               env_name="CartPole-v0",
               agent_class=Agent,
               num_iterations=10000,
               initial_collect_steps=1000,
               collect_steps_per_iteration=1,
               epsilon_greedy=0.1,
               replay_buffer_capacity=10000,
               train_steps_per_iteration=1,
               batch_size=32):
  global_step = tf.compat.v1.train.get_or_create_global_step()
  tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
  eval_py_env = suite_gym.load(env_name)

  network = Network(input_tensor_spec=tf_env.time_step_spec().observation,
                    action_spec=tf_env.action_spec())
  tf_agent = agent_class(time_step_spec=tf_env.time_step_spec(),
                         action_spec=tf_env.action_spec(),
                         network=network,
                         optimizer=tf.compat.v1.train.AdamOptimizer(),
                         epsilon_greedy=epsilon_greedy)

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      tf_agent.collect_data_spec,
      batch_size=tf_env.batch_size,
      max_length=replay_buffer_capacity)
  eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)
  replay_observer = [replay_buffer.add_batch]

  initial_collect_policy = random_tf_policy.RandomTFPolicy(
      tf_env.time_step_spec(), tf_env.action_spec())
  initial_collect_op = dynamic_step_driver.DynamicStepDriver(
      tf_env,
      initial_collect_policy,
      observers=replay_observer,
      num_steps=initial_collect_steps).run()

  collect_policy = tf_agent.collect_policy
  collect_op = dynamic_step_driver.DynamicStepDriver(
      tf_env,
      collect_policy,
      observers=replay_observer,
      num_steps=collect_steps_per_iteration).run()

  dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                     sample_batch_size=batch_size,
                                     num_steps=2).prefetch(3)
  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
  experience, _ = iterator.get_next()
  train_op = common.function(tf_agent.train)(experience=experience)
  init_agent_op = tf_agent.initialize()

  with tf.compat.v1.Session() as sess:
    sess.run(iterator.initializer)
    common.initialize_uninitialized_variables(sess)
    sess.run(init_agent_op)
    sess.run(initial_collect_op)

    global_step_val = sess.run(global_step)
    collect_call = sess.make_callable(collect_op)
    global_step_call = sess.make_callable(global_step)
    train_step_call = sess.make_callable(train_op)

    for _ in range(num_iterations):
      collect_call()
      for _ in range(train_steps_per_iteration):
        loss_info_value, _ = train_step_call()
      global_step_val = global_step_call()
      # Log the loss as a float (%f); it is not an integer.
      logging.info("step = %d, loss = %f", global_step_val, loss_info_value)
def train( root_dir, load_root_dir=None, env_load_fn=None, env_name=None, num_parallel_environments=1, # pylint: disable=unused-argument agent_class=None, initial_collect_random=True, # pylint: disable=unused-argument initial_collect_driver_class=None, collect_driver_class=None, num_global_steps=1000000, train_steps_per_iteration=1, train_metrics=None, # Safety Critic training args train_sc_steps=10, train_sc_interval=300, online_critic=False, # Params for eval run_eval=False, num_eval_episodes=30, eval_interval=1000, eval_metrics_callback=None, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, keep_rb_checkpoint=False, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, early_termination_fn=None, env_metric_factories=None): # pylint: disable=unused-argument """A simple train and eval for SC-SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() train_metrics = train_metrics or [] if run_eval: eval_dir = os.path.join(root_dir, 'eval') eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes), ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = env_load_fn(env_name) if not isinstance(tf_env, tf_py_environment.TFPyEnvironment): tf_env = tf_py_environment.TFPyEnvironment(tf_env) if run_eval: eval_py_env = env_load_fn(env_name) eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() print('obs spec:', observation_spec) print('action spec:', action_spec) if online_critic: resample_metric = tf_py_metric.TfPyMetric( py_metrics.CounterMetric('unsafe_ac_samples')) tf_agent = agent_class(time_step_spec, action_spec, train_step_counter=global_step, resample_metric=resample_metric) else: tf_agent = agent_class(time_step_spec, action_spec, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. collect_data_spec = tf_agent.collect_data_spec logging.info('Allocating replay buffer ...') # Add to replay buffer and other agent specific observers. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, max_length=1000000) logging.info('RB capacity: %i', replay_buffer.capacity) logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec) agent_observers = [replay_buffer.add_batch] if online_critic: online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, max_length=10000) online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer') online_rb_checkpointer = common.Checkpointer( ckpt_dir=online_rb_ckpt_dir, max_to_keep=1, replay_buffer=online_replay_buffer) clear_rb = common.function(online_replay_buffer.clear) agent_observers.append(online_replay_buffer.add_batch) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics] if not online_critic: eval_policy = tf_agent.policy else: eval_policy = tf_agent._safe_policy # pylint: disable=protected-access initial_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) if not online_critic: collect_policy = tf_agent.collect_policy else: collect_policy = tf_agent._safe_policy # pylint: disable=protected-access train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=eval_policy, global_step=global_step) safety_critic_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'safety_critic'), safety_critic=tf_agent._safety_critic_network, # pylint: disable=protected-access global_step=global_step) rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer') rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer) if load_root_dir: load_root_dir = os.path.expanduser(load_root_dir) load_train_dir = os.path.join(load_root_dir, 'train') misc.load_pi_ckpt(load_train_dir, tf_agent) # loads tf_agent if load_root_dir is None: train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() safety_critic_checkpointer.initialize_or_restore() collect_driver = collect_driver_class(tf_env, collect_policy, observers=agent_observers + train_metrics) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) if not rb_checkpointer.checkpoint_exists: logging.info('Performing initial collection ...') common.function( initial_collect_driver_class(tf_env, initial_collect_policy, observers=agent_observers + train_metrics).run)() last_id = replay_buffer._get_last_id() # pylint: disable=protected-access logging.info('Data saved after initial collection: %d steps', last_id) tf.print( replay_buffer._get_rows_for_id(last_id), # pylint: disable=protected-access output_stream=logging.info) if run_eval: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) if FLAGS.viz_pm: eval_fig_dir = osp.join(eval_dir, 'figs') if not tf.io.gfile.isdir(eval_fig_dir): 
tf.io.gfile.makedirs(eval_fig_dir) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, num_steps=2).prefetch(3) iterator = iter(dataset) if online_critic: online_dataset = online_replay_buffer.as_dataset( num_parallel_calls=3, num_steps=2).prefetch(3) online_iterator = iter(online_dataset) @common.function def critic_train_step(): """Builds critic training step.""" experience, buf_info = next(online_iterator) if env_name in [ 'IndianWell', 'IndianWell2', 'IndianWell3', 'DrunkSpider', 'DrunkSpiderShort' ]: safe_rew = experience.observation['task_agn_rew'] else: safe_rew = agents.process_replay_buffer( online_replay_buffer, as_tensor=True) safe_rew = tf.gather(safe_rew, tf.squeeze(buf_info.ids), axis=1) ret = tf_agent.train_sc(experience, safe_rew) clear_rb() return ret @common.function def train_step(): experience, _ = next(iterator) ret = tf_agent.train(experience) return ret if not early_termination_fn: early_termination_fn = lambda: False loss_diverged = False # How many consecutive steps was loss diverged for. loss_divergence_counter = 0 mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss') if online_critic: mean_resample_ac = tf.keras.metrics.Mean( name='mean_unsafe_ac_samples') resample_metric.reset() while (global_step.numpy() <= num_global_steps and not early_termination_fn()): # Collect and train. start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) if online_critic: mean_resample_ac(resample_metric.result()) resample_metric.reset() if time_step.is_last(): resample_ac_freq = mean_resample_ac.result() mean_resample_ac.reset_states() tf.compat.v2.summary.scalar(name='unsafe_ac_samples', data=resample_ac_freq, step=global_step) for _ in range(train_steps_per_iteration): train_loss = train_step() mean_train_loss(train_loss.loss) if online_critic: if global_step.numpy() % train_sc_interval == 0: for _ in range(train_sc_steps): sc_loss, lambda_loss = critic_train_step() # pylint: disable=unused-variable total_loss = mean_train_loss.result() mean_train_loss.reset_states() # Check for exploding losses. 
if (math.isnan(total_loss) or math.isinf(total_loss) or total_loss > MAX_LOSS): loss_divergence_counter += 1 if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS: loss_diverged = True break else: loss_divergence_counter = 0 time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), total_loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) safety_critic_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: if online_critic: online_rb_checkpointer.save(global_step=global_step_val) rb_checkpointer.save(global_step=global_step_val) if run_eval and global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) if FLAGS.viz_pm: savepath = 'step{}.png'.format(global_step_val) savepath = osp.join(eval_fig_dir, savepath) misc.record_episode_vis_summary(eval_tf_env, eval_policy, savepath) if not keep_rb_checkpoint: misc.cleanup_checkpoints(rb_ckpt_dir) if loss_diverged: # Raise an error at the very end after the cleanup. raise ValueError('Loss diverged to {} at step {}, terminating.'.format( total_loss, global_step.numpy())) return total_loss
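The training loop above raises an error once the mean loss has been NaN, infinite, or above MAX_LOSS for more than TERMINATE_AFTER_DIVERGED_LOSS_STEPS consecutive iterations. A minimal sketch of that guard as a standalone helper; the two constants are assumed values here, since the module above defines its own:

import math

MAX_LOSS = 1e6                              # assumed value for illustration
TERMINATE_AFTER_DIVERGED_LOSS_STEPS = 100   # assumed value for illustration

def update_divergence_counter(total_loss, counter):
  """Returns (new_counter, should_terminate) following the guard above."""
  if math.isnan(total_loss) or math.isinf(total_loss) or total_loss > MAX_LOSS:
    counter += 1
  else:
    counter = 0
  return counter, counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS

# A single bad loss only bumps the counter; training continues.
counter, stop = update_divergence_counter(float('nan'), 0)
assert counter == 1 and not stop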
## Driver and Observer # The driver is responsible for collecting trajectories from the environment trainMetrics = [ tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric() ] collectDriver = dynamic_step_driver.DynamicStepDriver( trainEnv, tfAgent.collect_policy, observers=[replayBufferObserver] + trainMetrics, num_steps=5048) ## Initialize Training Environment # Using a random policy to seed the replay buffer initialCollectPolicy = random_tf_policy.RandomTFPolicy( trainEnv.time_step_spec(), trainEnv.action_spec()) initDriver = dynamic_step_driver.DynamicStepDriver( trainEnv, initialCollectPolicy, observers=[replayBuffer.add_batch], num_steps=int(2e4)) finalTimeStep, finalPolicyState = initDriver.run() dataset = replayBuffer.as_dataset(sample_batch_size=64, num_steps=2, num_parallel_calls=3).prefetch(3) tfAgent = train_agent(int(1.5e6), tfAgent, trainEnv, collectDriver, trainMetrics, dataset)
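Observers handed to DynamicStepDriver, like replayBufferObserver and the metrics above, are simply callables that receive each batched trajectory the driver produces, so a plain Python object can sit alongside them. A minimal sketch of a custom observer; the class name and its placement in the driver are illustrative, not part of the snippet above:

class TrajectoryCounter(object):
  """Counts how many trajectory batches the driver has emitted."""

  def __init__(self):
    self.count = 0

  def __call__(self, trajectory):
    self.count += 1

trajectoryCounter = TrajectoryCounter()
# Hypothetical usage with the driver above:
# collectDriver = dynamic_step_driver.DynamicStepDriver(
#     trainEnv, tfAgent.collect_policy,
#     observers=[replayBufferObserver, trajectoryCounter] + trainMetrics,
#     num_steps=5048)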
def train(): num_iterations=1000000 # Params for networks. actor_fc_layers=(128, 64) actor_output_fc_layers=(64,) actor_lstm_size=(32,) critic_obs_fc_layers=None critic_action_fc_layers=None critic_joint_fc_layers=(128,) critic_output_fc_layers=(64,) critic_lstm_size=(32,) num_parallel_environments=1 # Params for collect initial_collect_episodes=1 collect_episodes_per_iteration=1 replay_buffer_capacity=1000000 # Params for target update target_update_tau=0.05 target_update_period=5 # Params for train train_steps_per_iteration=1 batch_size=256 critic_learning_rate=3e-4 train_sequence_length=20 actor_learning_rate=3e-4 alpha_learning_rate=3e-4 td_errors_loss_fn=tf.math.squared_difference gamma=0.99 reward_scale_factor=0.1 gradient_clipping=None use_tf_functions=True # Params for eval num_eval_episodes=30 eval_interval=10000 debug_summaries=False summarize_grads_and_vars=False global_step = tf.compat.v1.train.get_or_create_global_step() time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( observation_spec, action_spec, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers, continuous_projection_net=tanh_normal_projection_network .TanhNormalProjectionNetwork) critic_net = critic_rnn_network.CriticRnnNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] env_steps = tf_metrics.EnvironmentSteps(prefix='Train') eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, initial_collect_policy, observers=replay_observer, num_episodes=initial_collect_episodes) collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=replay_observer, num_episodes=collect_episodes_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function(initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) if env_steps.result() == 0 or replay_buffer.num_frames() == 0: logging.info( 
'Initializing replay buffer by collecting experience for %d episodes ' 'with a random policy.', initial_collect_episodes) initial_collect_driver.run() time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) time_acc = 0 env_steps_before = env_steps.result().numpy() # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): # Reduce filter_fn over full trajectory sampled. The sequence is kept only # if all elements except for the last one pass the filter. This is to # allow training on terminal steps. return tf.reduce_all(~trajectories.is_boundary()[:-1]) dataset = replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=train_sequence_length+1).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5) # Dataset generates trajectories with shape [Bx2x...] iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_env_steps = env_steps.result() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) episode_steps = env_steps.result() - start_env_steps # TODO(b/152648849) for _ in range(episode_steps): for _ in range(train_steps_per_iteration): train_step()
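The dataset above samples sequences of train_sequence_length + 1 steps and keeps a sequence only when every step except the last is not an episode boundary, so terminal transitions still reach training. A minimal sketch of that rule on plain boolean tensors instead of Trajectory objects (illustrative only):

import tensorflow as tf

def keep_sequence(is_boundary):
  # Mirror of _filter_invalid_transition above: drop the sequence if any step
  # other than the last one is an episode boundary.
  return tf.reduce_all(~is_boundary[:-1])

print(keep_sequence(tf.constant([False, False, True])))   # True: boundary is last
print(keep_sequence(tf.constant([False, True, False])))   # False: boundary mid-sequence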
def train_eval( root_dir, env_name="HalfCheetah-v2", num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=100000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=500000000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, relabel_type=None, num_future_states=4, max_episode_steps=100, random_seed=0, eval_task_list=None, constant_task=None, # Whether to train on a single task clip_critic=None, ): """A simple train and eval for SAC.""" np.random.seed(random_seed) if relabel_type == "none": relabel_type = None assert relabel_type in [None, "future", "last", "soft", "random"] if constant_task: assert relabel_type is None if eval_task_list is None: eval_task_list = [] root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, "train") eval_dir = os.path.join(root_dir, "eval") train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ utils.AverageSuccessMetric(max_episode_steps=max_episode_steps, buffer_size=num_eval_episodes), tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env, task_distribution = utils.get_env(env_name, constant_task=constant_task) eval_tf_env, _ = utils.get_env(env_name, max_episode_steps, constant_task=constant_task) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=utils.normal_projection_net, ) if isinstance(clip_critic, float): output_activation_fn = lambda x: clip_critic * tf.sigmoid(x) elif isinstance(clip_critic, tuple): assert len(clip_critic) == 2 min_val, max_val = clip_critic output_activation_fn = ( lambda x: # pylint: disable=g-long-lambda (max_val - min_val) * tf.sigmoid(x) + min_val) else: output_activation_fn = None critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, output_activation_fn=output_activation_fn, ) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), 
alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, ) tf_agent.initialize() # Make the replay buffer. replay_buffer = relabelling_replay_buffer.GoalRelabellingReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity, task_distribution=task_distribution, actor=actor_net, critic=critic_net, gamma=gamma, relabel_type=relabel_type, sample_batch_size=batch_size, num_parallel_calls=tf.data.experimental.AUTOTUNE, num_future_states=num_future_states, ) env_steps = tf_metrics.EnvironmentSteps(prefix="Train") train_metrics = [ tf_metrics.NumberOfEpisodes(prefix="Train"), env_steps, utils.AverageSuccessMetric( prefix="Train", max_episode_steps=max_episode_steps, buffer_size=num_eval_episodes, ), tf_metrics.AverageReturnMetric( prefix="Train", buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, ), tf_metrics.AverageEpisodeLengthMetric( prefix="Train", buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, ), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, "train_metrics"), ) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, "policy"), policy=eval_policy, global_step=global_step, ) train_checkpointer.initialize_or_restore() data_collector = utils.DataCollector( tf_env, tf_agent.collect_policy, replay_buffer, max_episode_steps=max_episode_steps, observers=train_metrics, ) if use_tf_functions: tf_agent.train = common.function(tf_agent.train) else: tf.config.experimental_run_functions_eagerly(True) # Save the config string as late as possible to catch # as many object instantiations as possible. config_str = gin.operative_config_str() logging.info(config_str) with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"), "w") as f: f.write(config_str) # Collect initial replay data. 
logging.info( "Initializing replay buffer by collecting experience for %d steps with " "a random policy.", initial_collect_steps, ) for _ in range(initial_collect_steps): data_collector.step(initial_collect_policy) data_collector.reset() logging.info("Replay buffer initial size: %d", replay_buffer.num_frames()) logging.info("Computing initial eval metrics") for task in [None] + eval_task_list: with utils.FixedTask(eval_tf_env, task): prefix = "Metrics" if task is None else "Metrics-%s" % str( task) metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix=prefix, ) metric_utils.log_metrics(eval_metrics) time_acc = 0 env_time_acc = 0 train_time_acc = 0 env_steps_before = env_steps.result().numpy() if use_tf_functions: tf_agent.train = common.function(tf_agent.train) logging.info("Starting training") for _ in range(num_iterations): start_time = time.time() data_collector.step() env_time_acc += time.time() - start_time train_time_start = time.time() for _ in range(train_steps_per_iteration): experience = replay_buffer.get_batch() train_loss = tf_agent.train(experience) total_loss = train_loss.loss train_time_acc += time.time() - train_time_start time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info("step = %d, loss = %f", global_step.numpy(), total_loss) combined_steps_per_sec = (env_steps.result().numpy() - env_steps_before) / time_acc train_steps_per_sec = (env_steps.result().numpy() - env_steps_before) / train_time_acc env_steps_per_sec = (env_steps.result().numpy() - env_steps_before) / env_time_acc logging.info( "%.3f combined steps / sec: %.3f env steps/sec, %.3f train steps/sec", combined_steps_per_sec, env_steps_per_sec, train_steps_per_sec, ) tf.compat.v2.summary.scalar( name="combined_steps_per_sec", data=combined_steps_per_sec, step=env_steps.result(), ) tf.compat.v2.summary.scalar( name="env_steps_per_sec", data=env_steps_per_sec, step=env_steps.result(), ) tf.compat.v2.summary.scalar( name="train_steps_per_sec", data=train_steps_per_sec, step=env_steps.result(), ) time_acc = 0 env_time_acc = 0 train_time_acc = 0 env_steps_before = env_steps.result().numpy() for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: for task in [None] + eval_task_list: with utils.FixedTask(eval_tf_env, task): prefix = "Metrics" if task is None else "Metrics-%s" % str( task) logging.info(prefix) metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix=prefix, ) metric_utils.log_metrics(eval_metrics) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) return train_loss
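The logging block above derives combined, environment, and train steps-per-second from wall-clock accumulators and writes them as scalar summaries keyed by the environment-step count. A minimal self-contained sketch of that bookkeeping with toy numbers; the writer path and the fake workload are assumptions:

import time

import tensorflow as tf

writer = tf.summary.create_file_writer('/tmp/throughput_example')  # assumed path
env_steps_before = 0
time_acc = 0.0

for iteration in range(1, 4):
  start_time = time.time()
  time.sleep(0.01)              # stand-in for collect + train work
  env_steps = iteration * 100   # stand-in for env_steps.result().numpy()
  time_acc += time.time() - start_time

  steps_per_sec = (env_steps - env_steps_before) / time_acc
  with writer.as_default():
    tf.summary.scalar('combined_steps_per_sec', steps_per_sec, step=env_steps)
  env_steps_before = env_steps
  time_acc = 0.0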
def train_eval( root_dir, experiment_name, # experiment name env_name='carla-v0', agent_name='sac', # agent's name num_iterations=int(1e7), actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), model_network_ctor_type='non-hierarchical', # model net input_names=['camera', 'lidar'], # names for inputs mask_names=['birdeye'], # names for masks preprocessing_combiner=tf.keras.layers.Add(), # takes a flat list of tensors and combines them actor_lstm_size=(40,), # lstm size for actor critic_lstm_size=(40,), # lstm size for critic actor_output_fc_layers=(100,), # lstm output critic_output_fc_layers=(100,), # lstm output epsilon_greedy=0.1, # exploration parameter for DQN q_learning_rate=1e-3, # q learning rate for DQN ou_stddev=0.2, # exploration paprameter for DDPG ou_damping=0.15, # exploration parameter for DDPG dqda_clipping=None, # for DDPG exploration_noise_std=0.1, # exploration paramter for td3 actor_update_period=2, # for td3 # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, replay_buffer_capacity=int(1e5), # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, initial_model_train_steps=100000, # initial model training batch_size=256, model_batch_size=32, # model training batch size sequence_length=4, # number of timesteps to train model actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, model_learning_rate=1e-4, # learning rate for model training td_errors_loss_fn=tf.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for summaries and logging num_images_per_summary=1, # images for each summary train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, gpu_allow_growth=True, # GPU memory growth gpu_memory_limit=None, # GPU memory limit action_repeat=1): # Name of single observation channel, ['camera', 'lidar', 'birdeye'] # Setup GPU gpus = tf.config.experimental.list_physical_devices('GPU') if gpu_allow_growth: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpu_memory_limit: for gpu in gpus: tf.config.experimental.set_virtual_device_configuration( gpu, [tf.config.experimental.VirtualDeviceConfiguration( memory_limit=gpu_memory_limit)]) # Get train and eval direction root_dir = os.path.expanduser(root_dir) root_dir = os.path.join(root_dir, env_name, experiment_name) # Get summary writers summary_writer = tf.summary.create_file_writer( root_dir, flush_millis=summaries_flush_secs * 1000) summary_writer.set_as_default() # Eval metrics eval_metrics = [ tf_metrics.AverageReturnMetric( name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric( name='AverageEpisodeLengthEvalPolicy', buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() # Whether to record for summary with tf.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create Carla environment if agent_name == 'latent_sac': py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names+mask_names, action_repeat=action_repeat) elif agent_name == 'dqn': py_env, eval_py_env = load_carla_env(env_name='carla-v0', discrete=True, 
obs_channels=input_names, action_repeat=action_repeat) else: py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names, action_repeat=action_repeat) tf_env = tf_py_environment.TFPyEnvironment(py_env) eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env) fps = int(np.round(1.0 / (py_env.dt * action_repeat))) # Specs time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() ## Make tf agent if agent_name == 'latent_sac': # Get model network for latent sac if model_network_ctor_type == 'hierarchical': model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical elif model_network_ctor_type == 'non-hierarchical': model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical else: raise NotImplementedError model_net = model_network_ctor(input_names, input_names+mask_names) # Get the latent spec latent_size = model_net.latent_size latent_observation_spec = tensor_spec.TensorSpec((latent_size,), dtype=tf.float32) latent_time_step_spec = ts.time_step_spec(observation_spec=latent_observation_spec) # Get actor and critic net actor_net = actor_distribution_network.ActorDistributionNetwork( latent_observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (latent_observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) # Build the inner SAC agent based on latent space inner_agent = sac_agent.SacAgent( latent_time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) inner_agent.initialize() # Build the latent sac agent tf_agent = latent_sac_agent.LatentSACAgent( time_step_spec, action_spec, inner_agent=inner_agent, model_network=model_net, model_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=model_learning_rate), model_batch_size=model_batch_size, num_images_per_summary=num_images_per_summary, sequence_length=sequence_length, gradient_clipping=gradient_clipping, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, fps=fps) else: # Set up preprosessing layers for dictionary observation inputs preprocessing_layers = collections.OrderedDict() for name in input_names: preprocessing_layers[name] = Preprocessing_Layer(32,256) if len(input_names) < 2: preprocessing_combiner = None if agent_name == 'dqn': q_rnn_net = q_rnn_network.QRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) tf_agent = dqn_agent.DqnAgent( time_step_spec, action_spec, q_network=q_rnn_net, epsilon_greedy=epsilon_greedy, n_step_update=1, 
target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=q_learning_rate), td_errors_loss_fn=common.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) elif agent_name == 'ddpg' or agent_name == 'td3': actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers) critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork( (observation_spec, action_spec), preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) if agent_name == 'ddpg': tf_agent = ddpg_agent.DdpgAgent( time_step_spec, action_spec, actor_network=actor_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), ou_stddev=ou_stddev, ou_damping=ou_damping, target_update_tau=target_update_tau, target_update_period=target_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=None, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) elif agent_name == 'td3': tf_agent = td3_agent.Td3Agent( time_step_spec, action_spec, actor_network=actor_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), exploration_noise_std=exploration_noise_std, target_update_tau=target_update_tau, target_update_period=target_update_period, actor_update_period=actor_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=None, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) elif agent_name == 'sac': actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers, continuous_projection_net=normal_projection_net) critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork( (observation_spec, action_spec), preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_distribution_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), 
critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=tf.math.squared_difference, # make critic loss dimension compatible gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) else: raise NotImplementedError tf_agent.initialize() # Get replay buffer replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, # No parallel environments max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] # Train metrics env_steps = tf_metrics.EnvironmentSteps() average_return = tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size) train_metrics = [ tf_metrics.NumberOfEpisodes(), env_steps, average_return, tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] # Get policies # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) eval_policy = tf_agent.policy initial_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) collect_policy = tf_agent.collect_policy # Checkpointers train_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(root_dir, 'train'), agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), max_to_keep=2) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(root_dir, 'policy'), policy=eval_policy, global_step=global_step, max_to_keep=2) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(root_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() # Collect driver initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) # Optimize the performance by using tf functions initial_collect_driver.run = common.function(initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. if (env_steps.result() == 0 or replay_buffer.num_frames() == 0): logging.info( 'Initializing replay buffer by collecting experience for %d steps' 'with a random policy.', initial_collect_steps) initial_collect_driver.run() if agent_name == 'latent_sac': compute_summaries( eval_metrics, eval_tf_env, eval_policy, train_step=global_step, summary_writer=summary_writer, num_episodes=1, num_episodes_to_render=1, model_net=model_net, fps=10, image_keys=input_names+mask_names) else: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=1, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) metric_utils.log_metrics(eval_metrics) # Dataset generates trajectories with shape [Bxslx...] 
dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=sequence_length + 1).prefetch(3) iterator = iter(dataset) # Get train step def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) train_step = common.function(train_step) if agent_name == 'latent_sac': def train_model_step(): experience, _ = next(iterator) return tf_agent.train_model(experience) train_model_step = common.function(train_model_step) # Training initializations time_step = None time_acc = 0 env_steps_before = env_steps.result().numpy() # Start training for iteration in range(num_iterations): start_time = time.time() if agent_name == 'latent_sac' and iteration < initial_model_train_steps: train_model_step() else: # Run collect time_step, _ = collect_driver.run(time_step=time_step) # Train an iteration for _ in range(train_steps_per_iteration): train_step() time_acc += time.time() - start_time # Log training information if global_step.numpy() % log_interval == 0: logging.info('env steps = %d, average return = %f', env_steps.result(), average_return.result()) env_steps_per_sec = (env_steps.result().numpy() - env_steps_before) / time_acc logging.info('%.3f env steps/sec', env_steps_per_sec) tf.summary.scalar( name='env_steps_per_sec', data=env_steps_per_sec, step=env_steps.result()) time_acc = 0 env_steps_before = env_steps.result().numpy() # Get training metrics for train_metric in train_metrics: train_metric.tf_summaries(train_step=env_steps.result()) # Evaluation if global_step.numpy() % eval_interval == 0: # Log evaluation metrics if agent_name == 'latent_sac': compute_summaries( eval_metrics, eval_tf_env, eval_policy, train_step=global_step, summary_writer=summary_writer, num_episodes=num_eval_episodes, num_episodes_to_render=num_images_per_summary, model_net=model_net, fps=10, image_keys=input_names+mask_names) else: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) metric_utils.log_metrics(eval_metrics) # Save checkpoints global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
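When agent_name is 'latent_sac', the loop above spends the first initial_model_train_steps iterations only on train_model_step before switching to collection plus train_step. A minimal sketch of that two-phase schedule with counters standing in for the real training callables (all names here are stand-ins):

def run_schedule(num_iterations, initial_model_train_steps,
                 train_model_step, collect_and_train_step):
  """Runs model-only training first, then regular collect-and-train."""
  for iteration in range(num_iterations):
    if iteration < initial_model_train_steps:
      train_model_step()
    else:
      collect_and_train_step()

calls = {'model': 0, 'agent': 0}
run_schedule(
    num_iterations=10,
    initial_model_train_steps=3,
    train_model_step=lambda: calls.__setitem__('model', calls['model'] + 1),
    collect_and_train_step=lambda: calls.__setitem__('agent', calls['agent'] + 1))
assert calls == {'model': 3, 'agent': 7}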
def train_eval( root_dir, env_name=ENV_NAME, num_iterations=ITERATIONS, fc_layer_params=LAYER_PARAMETERS, # Parameters for collect initial_collect_steps=COLLECT_STEPS, collect_steps_per_iteration=COLLECT_STEPS, epsilon_greedy=GREEDY, replay_buffer_capacity=BUFFER, # Parameters for target update target_update_tau=TAU, target_update_period=UPDATE_PERIOD, # Parameters for train train_steps_per_iteration=TRAIN_ITERATIONS, batch_size=BATCH_SIZE, learning_rate=LEARN_RATE, gamma=GAMMA, reward_scale_factor=SCALE, gradient_clipping=GRADIENT, # Parameters for eval num_eval_episodes=10, eval_interval=1000, # Parameters for checkpoints, summaries, and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, agent_class=dqn_agent.DqnAgent, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DQN.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name)) eval_py_env = suite_gym.load(env_name) q_net = q_network.QNetwork(tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=fc_layer_params) tf_agent = agent_class( tf_env.time_step_spec(), tf_env.action_spec(), q_network=q_net, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate), epsilon_greedy=epsilon_greedy, target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=dqn_agent.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] replay_observer = [replay_buffer.add_batch] initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps).run() collect_policy = tf_agent.collect_policy collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run() # Dataset generates trajectories with shape [Bx2x...] 
dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = tf.compat.v1.data.make_initializable_iterator(dataset) experience, _ = iterator.get_next() train_op = common.function(tf_agent.train)(experience=experience) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) init_agent_op = tf_agent.initialize() with tf.compat.v1.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) sess.run(initial_collect_op) global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) collect_call = sess.make_callable(collect_op) global_step_call = sess.make_callable(global_step) train_step_call = sess.make_callable([train_op, summary_ops]) timed_at_step = global_step_call() collect_time = 0 train_time = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): # Train/collect/eval. start_time = time.time() collect_call() collect_time += time.time() - start_time start_time = time.time() for _ in range(train_steps_per_iteration): loss_info_value, _ = train_step_call() train_time += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = ((global_step_val - timed_at_step) / (collect_time + train_time)) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) logging.info('%.3f steps/sec', steps_per_sec) logging.info( '%s', 'collect_time = {}, train_time = {}'.format( collect_time, train_time)) timed_at_step = global_step_val collect_time = 0 train_time = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, )
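A train_eval like the one above is typically launched from an absl entry point that supplies root_dir. A hypothetical launcher follows; the flag name and default path are assumptions, not part of the original module:

from absl import app
from absl import flags

flags.DEFINE_string('root_dir', '/tmp/dqn_cartpole',
                    'Root directory for checkpoints and summaries.')
FLAGS = flags.FLAGS

def main(_):
  # train_eval here refers to the function defined above.
  train_eval(FLAGS.root_dir)

if __name__ == '__main__':
  app.run(main)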
def train_eval( root_dir, env_name='HalfCheetah-v2', eval_env_name=None, env_load_fn=suite_mujoco.load, num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] eval_summary_flush_op = eval_summary_writer.flush() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create the environment. tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name)) eval_env_name = eval_env_name or env_name eval_py_env = env_load_fn(eval_env_name) # Get the data specs from the environment time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] eval_py_policy = py_tf_policy.PyTFPolicy( greedy_policy.GreedyPolicy(tf_agent.policy)) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] collect_policy = tf_agent.collect_policy initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run() # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset( sample_batch_size=5 * batch_size, num_steps=2).apply(tf.data.experimental.unbatch()).filter( _filter_invalid_transition).batch(batch_size).prefetch( batch_size * 5) dataset_iterator = tf.compat.v1.data.make_initializable_iterator(dataset) trajectories, unused_info = dataset_iterator.get_next() train_op = tf_agent.train(trajectories) summary_ops = [] for train_metric in train_metrics: summary_ops.append(train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) with tf.compat.v1.Session() as sess: # Initialize graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # Initialize training. sess.run(dataset_iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) global_step_val = sess.run(global_step) if global_step_val == 0: # Initial eval of randomly initialized policy metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) # Run initial collect. logging.info('Global step %d: Running initial collect op.', global_step_val) sess.run(initial_collect_op) # Checkpoint the initial replay buffer contents. 
rb_checkpointer.save(global_step=global_step_val) logging.info('Finished initial collect.') else: logging.info('Global step %d: Skipping initial collect op.', global_step_val) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_ops]) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): total_loss, _ = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run( steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
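The dataset above oversamples 2-step trajectories (5 * batch_size at a time), unbatches them, drops invalid transitions, and rebatches to batch_size before prefetching. A minimal sketch of that sample-unbatch-filter-rebatch shape on a toy integer dataset, using the non-experimental unbatch(); the filter predicate here is arbitrary, whereas the real one checks trajectory boundaries:

import tensorflow as tf

batch_size = 4
dataset = tf.data.Dataset.range(40).batch(5 * batch_size)  # oversampled batches
dataset = dataset.unbatch()                                # back to single items
dataset = dataset.filter(lambda x: x % 7 != 0)             # drop "invalid" items
dataset = dataset.batch(batch_size).prefetch(batch_size)   # rebatch for training

for batch in dataset.take(2):
  print(batch.numpy())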
def train_eval( root_dir, env_name='CartPole-v0', num_iterations=100000, train_sequence_length=1, # Params for QNetwork fc_layer_params=(100, ), # Params for QRnnNetwork input_fc_layer_params=(50, ), lstm_size=(20, ), output_fc_layer_params=(20, ), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, epsilon_greedy=0.1, replay_buffer_capacity=100000, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, learning_rate=1e-3, n_step_update=1, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=10, eval_interval=1000, # Params for checkpoints train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, # Params for summaries and logging log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DQN.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name)) eval_tf_env = tf_py_environment.TFPyEnvironment( suite_gym.load(env_name)) if train_sequence_length != 1 and n_step_update != 1: raise NotImplementedError( 'train_eval does not currently support n-step updates with stateful ' 'networks (i.e., RNNs)') action_spec = tf_env.action_spec() num_actions = action_spec.maximum - action_spec.minimum + 1 if train_sequence_length > 1: q_net = create_recurrent_network(input_fc_layer_params, lstm_size, output_fc_layer_params, num_actions) else: q_net = create_feedforward_network(fc_layer_params, num_actions) train_sequence_length = n_step_update # TODO(b/127301657): Decay epsilon based on global step, cf. 
cl/188907839 tf_agent = dqn_agent.DqnAgent( tf_env.time_step_spec(), tf_env.action_spec(), q_network=q_net, epsilon_greedy=epsilon_greedy, n_step_update=n_step_update, target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate), td_errors_loss_fn=common.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=collect_steps_per_iteration) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() if use_tf_functions: # To speed up collect use common.function. collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=initial_collect_steps).run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] 
dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=train_sequence_length + 1).prefetch(3) iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): train_loss = train_step() time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss.loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step.numpy()) if global_step.numpy() % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step.numpy()) if global_step.numpy() % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step.numpy()) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) return train_loss
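# The train_eval above calls create_feedforward_network and create_recurrent_network
# without defining them. The sketches below follow the layer parameters it passes and
# the tf_agents.networks.sequential style; they are assumptions, not the source's own
# helpers.
from tf_agents.keras_layers import dynamic_unroll_layer
from tf_agents.networks import sequential


def create_feedforward_network(fc_layer_units, num_actions):
    # Dense trunk followed by one linear logit per action.
    return sequential.Sequential(
        [tf.keras.layers.Dense(units, activation='relu') for units in fc_layer_units]
        + [tf.keras.layers.Dense(num_actions, activation=None)])


def create_recurrent_network(input_fc_layer_units, lstm_size,
                             output_fc_layer_units, num_actions):
    # Dense encoder, LSTM core, dense decoder, then per-action logits.
    rnn_cell = tf.keras.layers.StackedRNNCells(
        [tf.keras.layers.LSTMCell(size) for size in lstm_size])
    return sequential.Sequential(
        [tf.keras.layers.Dense(units, activation='relu') for units in input_fc_layer_units]
        + [dynamic_unroll_layer.DynamicUnroll(rnn_cell)]
        + [tf.keras.layers.Dense(units, activation='relu') for units in output_fc_layer_units]
        + [tf.keras.layers.Dense(num_actions, activation=None)])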
def train_eval( root_dir, env_name='CartPole-v0', num_iterations=100000, fc_layer_params=(100, ), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, epsilon_greedy=0.1, replay_buffer_capacity=100000, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, learning_rate=1e-3, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=1000, # Params for checkpoints, summaries, and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DQN.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.contrib.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] # TODO(kbanoop): Figure out if it is possible to avoid the with block. with tf.contrib.summary.record_summaries_every_n_global_steps( summary_interval): tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name)) eval_py_env = suite_gym.load(env_name) q_net = q_network.QNetwork(tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=fc_layer_params) tf_agent = dqn_agent.DqnAgent( tf_env.time_step_spec(), tf_env.action_spec(), q_network=q_net, optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate), # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839 epsilon_greedy=epsilon_greedy, target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=dqn_agent.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec(), batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy()) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] global_step = tf.train.get_or_create_global_step() replay_observer = [replay_buffer.add_batch] initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer, num_steps=initial_collect_steps).run() collect_policy = tf_agent.collect_policy() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run() # Dataset generates trajectories with shape [Bx2x...] 
dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = dataset.make_initializable_iterator() trajectories, _ = iterator.get_next() train_op = tf_agent.train(experience=trajectories, train_step_counter=global_step) train_checkpointer = common_utils.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=tf.contrib.checkpoint.List(train_metrics)) policy_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy(), global_step=global_step) rb_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) for train_metric in train_metrics: train_metric.tf_summaries(step_metrics=train_metrics[:2]) summary_op = tf.contrib.summary.all_summary_ops() with eval_summary_writer.as_default(), \ tf.contrib.summary.always_record_summaries(): for eval_metric in eval_metrics: eval_metric.tf_summaries() init_agent_op = tf_agent.initialize() with tf.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) # TODO(sguada) Remove once Periodically can be saved. common_utils.initialize_uninitialized_variables(sess) sess.run(init_agent_op) tf.contrib.summary.initialize(session=sess) sess.run(initial_collect_op) global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable( [train_op, summary_op, global_step]) timed_at_step = sess.run(global_step) collect_time = 0 train_time = 0 steps_per_second_ph = tf.placeholder(tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) for _ in range(num_iterations): # Train/collect/eval. start_time = time.time() collect_call() collect_time += time.time() - start_time start_time = time.time() for _ in range(train_steps_per_iteration): loss_info_value, _, global_step_val = train_step_call() train_time += time.time() - start_time if global_step_val % log_interval == 0: tf.logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = ((global_step_val - timed_at_step) / (collect_time + train_time)) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) tf.logging.info('%.3f steps/sec' % steps_per_sec) tf.logging.info( 'collect_time = {}, train_time = {}'.format( collect_time, train_time)) timed_at_step = global_step_val collect_time = 0 train_time = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, )
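# The graph-mode train_eval above assumes imports along these lines (best-guess module
# paths from TF-Agents examples of that era, not copied from the source).
import os
import time

import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import py_tf_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common as common_utils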
def train_eval( root_dir, environment_name="broken_reacher", num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), initial_collect_steps=10000, real_initial_collect_steps=10000, collect_steps_per_iteration=1, real_collect_interval=10, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, classifier_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.math.squared_difference, gamma=0.99, reward_scale_factor=0.1, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=True, summarize_grads_and_vars=False, train_on_real=False, delta_r_warmup=0, random_seed=0, checkpoint_dir=None, ): """A simple train and eval for SAC.""" np.random.seed(random_seed) tf.random.set_seed(random_seed) root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, "train") eval_dir = os.path.join(root_dir, "eval") train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) if environment_name == "broken_reacher": get_env_fn = darc_envs.get_broken_reacher_env elif environment_name == "half_cheetah_obstacle": get_env_fn = darc_envs.get_half_cheetah_direction_env elif environment_name.startswith("broken_joint"): base_name = environment_name.split("broken_joint_")[1] get_env_fn = functools.partial(darc_envs.get_broken_joint_env, env_name=base_name) elif environment_name.startswith("falling"): base_name = environment_name.split("falling_")[1] get_env_fn = functools.partial(darc_envs.get_falling_env, env_name=base_name) else: raise NotImplementedError("Unknown environment: %s" % environment_name) eval_name_list = ["sim", "real"] eval_env_list = [get_env_fn(mode) for mode in eval_name_list] eval_metrics_list = [] for name in eval_name_list: eval_metrics_list.append([ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, name="AverageReturn_%s" % name), ]) global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env_real = get_env_fn("real") if train_on_real: tf_env = get_env_fn("real") else: tf_env = get_env_fn("sim") time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=( tanh_normal_projection_network.TanhNormalProjectionNetwork), ) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer="glorot_uniform", last_kernel_initializer="glorot_uniform", ) classifier = classifiers.build_classifier(observation_spec, action_spec) tf_agent = darc_agent.DarcAgent( 
time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, classifier=classifier, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), classifier_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=classifier_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, ) tf_agent.initialize() # Make the replay buffer. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity, ) replay_observer = [replay_buffer.add_batch] real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity, ) real_replay_observer = [real_replay_buffer.add_batch] sim_train_metrics = [ tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"), tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"), tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, name="AverageReturnSim", ), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, name="AverageEpisodeLengthSim", ), ] real_train_metrics = [ tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"), tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"), tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, name="AverageReturnReal", ), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, name="AverageEpisodeLengthReal", ), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup( sim_train_metrics + real_train_metrics, "train_metrics"), ) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, "policy"), policy=eval_policy, global_step=global_step, ) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, "replay_buffer"), max_to_keep=1, replay_buffer=(replay_buffer, real_replay_buffer), ) if checkpoint_dir is not None: checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) assert checkpoint_path is not None train_checkpointer._load_status = train_checkpointer._checkpoint.restore( # pylint: disable=protected-access checkpoint_path) train_checkpointer._load_status.initialize_or_restore() # pylint: disable=protected-access else: train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() if replay_buffer.num_frames() == 0: initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + sim_train_metrics, num_steps=initial_collect_steps, ) real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env_real, initial_collect_policy, observers=real_replay_observer + real_train_metrics, 
num_steps=real_initial_collect_steps, ) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + sim_train_metrics, num_steps=collect_steps_per_iteration, ) real_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env_real, collect_policy, observers=real_replay_observer + real_train_metrics, num_steps=collect_steps_per_iteration, ) config_str = gin.operative_config_str() logging.info(config_str) with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"), "w") as f: f.write(config_str) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) real_initial_collect_driver.run = common.function( real_initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) real_collect_driver.run = common.function(real_collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. if replay_buffer.num_frames() == 0: logging.info( "Initializing replay buffer by collecting experience for %d steps with " "a random policy.", initial_collect_steps, ) initial_collect_driver.run() real_initial_collect_driver.run() for eval_name, eval_env, eval_metrics in zip(eval_name_list, eval_env_list, eval_metrics_list): metric_utils.eager_compute( eval_metrics, eval_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix="Metrics-%s" % eval_name, ) metric_utils.log_metrics(eval_metrics) time_step = None real_time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = (replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=2).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5)) real_dataset = (real_replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=2).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5)) # Dataset generates trajectories with shape [Bx2x...] iterator = iter(dataset) real_iterator = iter(real_dataset) def train_step(): experience, _ = next(iterator) real_experience, _ = next(real_iterator) return tf_agent.train(experience, real_experience=real_experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) assert not policy_state # We expect policy_state == (). 
if (global_step.numpy() % real_collect_interval == 0 and global_step.numpy() >= delta_r_warmup): real_time_step, policy_state = real_collect_driver.run( time_step=real_time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): train_loss = train_step() time_acc += time.time() - start_time global_step_val = global_step.numpy() if global_step_val % log_interval == 0: logging.info("step = %d, loss = %f", global_step_val, train_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info("%.3f steps/sec", steps_per_sec) tf.compat.v2.summary.scalar(name="global_steps_per_sec", data=steps_per_sec, step=global_step) timed_at_step = global_step_val time_acc = 0 for train_metric in sim_train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=sim_train_metrics[:2]) for train_metric in real_train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=real_train_metrics[:2]) if global_step_val % eval_interval == 0: for eval_name, eval_env, eval_metrics in zip( eval_name_list, eval_env_list, eval_metrics_list): metric_utils.eager_compute( eval_metrics, eval_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix="Metrics-%s" % eval_name, ) metric_utils.log_metrics(eval_metrics) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
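# Why the boundary filter inside the DARC train_eval above matters: as_dataset(num_steps=2)
# samples pairs of adjacent replay rows, and a pair whose first row ends an episode would
# mix the last step of one episode with the first step of the next. A hypothetical
# standalone helper restating the same pipeline in unflattened form (identical calls):
def make_transition_dataset(buffer, batch_size):
    def _keep(trajectories, unused_info):
        # Keep a sampled pair only if its first step is not an episode boundary.
        return ~trajectories.is_boundary()[0]

    return (buffer.as_dataset(sample_batch_size=batch_size, num_steps=2)
            .unbatch()
            .filter(_keep)
            .batch(batch_size)
            .prefetch(5))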
env.close()

################################################################################
"""
Another way to run the environment is by using TF-Agents.
"""
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_tf_policy

py_env = suite_gym.load("gym_ctf:ctf-v0")
env = tf_py_environment.TFPyEnvironment(py_env)

# This creates a randomly initialized policy that the agent will follow.
# Similar to just taking random actions in the environment.
policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(), env.action_spec())

time_step = env.reset()
while not time_step.is_last():
    action_step = policy.action(time_step)
    time_step = env.step(action_step.action)
    py_env.render('human')  # The default mode for this render is rgb_array.

py_env.close()
################################################################################
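# The loop above plays a single episode with the random policy. Several of the later
# snippets call a compute_avg_return helper without defining it; the version below is
# the standard TF-Agents tutorial reconstruction (a hedged sketch, not the projects'
# own code) and also works with the random policy defined above.
def compute_avg_return(environment, policy, num_episodes=10):
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]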
def simulate(): # Set up the environments for the agent to train and test its performance envTrain = ComputerSnake.Snake() envEval = ComputerSnake.Snake(persistence=True) # Convert and wrap in TFPyEnvironment training and evaluation environments train_env = tf_py_environment.TFPyEnvironment(envTrain) eval_env = tf_py_environment.TFPyEnvironment(envEval) # Set up q network with necessary parameters fc_layer_params = (100, ) q_net = q_network.QNetwork(train_env.observation_spec(), train_env.action_spec(), fc_layer_params=fc_layer_params) optimizer = tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate) # look up train_step_counter = tf.Variable(0) # Set up and initialize the DQN learning agent. It takes in the time_step spec, # action spec, the q network, the optimizer, a loss function, and train_step_counter agent = dqn_agent.DqnAgent( train_env.time_step_spec(), train_env.action_spec(), q_network=q_net, optimizer=optimizer, # look up td_errors_loss_fn=common.element_wise_squared_loss, train_step_counter=train_step_counter) agent.initialize() # Set up policies the agent can use eval_policy = agent.policy collect_policy = agent.collect_policy # Policy which randomly selects actions for each step random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec()) #Buffer to store previous states replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=agent.collect_data_spec, batch_size=train_env.batch_size, max_length=replay_buffer_max_length) # Dataset generates trajectories with shape [Bx2x...] This is so that the agent has access to both the current # and previous state to compute loss. Parallel calls and prefetching are used to optimize process. dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) # (Optional) Optimize by wrapping some of the code in a graph using TF function. agent.train = common.function(agent.train) # Reset the train step agent.train_step_counter.assign(0) # Evaluate the agent's policy once before training. avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes) # We initially fill the replay buffer with 100 trajectories to help the assistant collect_data(train_env, random_policy, replay_buffer, steps=5000) train_env.reset() # Here, we run the simulation to train the agent scores_list = [] num_steps_arr = [] for currStep in range(num_iterations): # Collect a few steps using collect_policy and save to the replay buffer. for _ in range(collect_steps_per_iteration): collect_step(train_env, agent.collect_policy, replay_buffer) # Sample a batch of data from the buffer and update the agent's network. 
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss

        # Number of training steps so far
        step = agent.train_step_counter.numpy()

        # Prints progress every `log_interval` steps made by the training agent
        if step % log_interval == 0:
            print('Moves made = {0}'.format(step))

        # Evaluates the agent's policy every `eval_interval` steps, prints the results,
        # and saves them for later so they can be plotted
        if step % eval_interval == 0:
            avg_return = 0
            for i in range(num_eval_episodes):
                curr_return = compute_avg_return(eval_env, agent.policy, 1)
                scores_list.append(curr_return)
                num_steps_arr.append(currStep)
                avg_return += curr_return
            avg_return = avg_return / num_eval_episodes
            print('step = {0}: Average Return = {1}'.format(step, avg_return))

    plt.scatter(num_steps_arr, scores_list)
    plt.xlabel('Number of Steps Trained')
    plt.ylabel('Score')
    plt.title('Snake Reinforcement Learning')
    plt.show()
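# simulate() above relies on collect_step and collect_data helpers that are not shown
# in this excerpt; the definitions below follow the standard TF-Agents DQN tutorial
# pattern and are a hedged reconstruction, not the project's own code.
from tf_agents.trajectories import trajectory


def collect_step(environment, policy, buffer):
    # Take one step with the given policy and store the resulting transition.
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    buffer.add_batch(traj)


def collect_data(environment, policy, buffer, steps):
    for _ in range(steps):
        collect_step(environment, policy, buffer)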
def train_eval( root_dir, env_name='gym_orbital_system:solarsystem-v0', eval_env_name=None, env_load_fn=suite_gym.load, # The SAC paper reported: # Hopper and Cartpole results up to 1000000 iters, # Humanoid results up to 10000000 iters, # Other mujoco tasks up to 3000000 iters. num_iterations=3000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py # HalfCheetah and Ant take 10000 initial collection steps. # Other mujoco tasks take 1000. # Different choices roughly keep the initial episodes about the same. initial_collect_steps=1000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.math.squared_difference, gamma=0.99, reward_scale_factor=0.1, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=50000, policy_checkpoint_interval=50000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'td3_eval/train') eval_dir = os.path.join(root_dir, 'td3_eval/eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name)) eval_env_name = eval_env_name or env_name eval_tf_env = tf_py_environment.TFPyEnvironment( env_load_fn(eval_env_name)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=tanh_normal_projection_network. 
TanhNormalProjectionNetwork) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) if replay_buffer.num_frames() == 0: # Collect initial replay data. 
logging.info( 'Initializing replay buffer by collecting experience for %d steps ' 'with a random policy.', initial_collect_steps) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=2).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5) # Dataset generates trajectories with shape [Bx2x...] iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) global_step_val = global_step.numpy() while global_step_val < num_iterations: start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): train_loss = train_step() time_acc += time.time() - start_time global_step_val = global_step.numpy() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, train_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step_val time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if global_step_val % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step_val) metric_utils.log_metrics(eval_metrics) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
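# Not part of the snippet above: in addition to the Checkpointer, the greedy evaluation
# policy could be exported as a SavedModel once training finishes, e.g. inside
# train_eval just before `return train_loss`. The helper name and export path below are
# assumptions.
from tf_agents.policies import policy_saver


def export_eval_policy(eval_policy, export_dir):
    policy_saver.PolicySaver(eval_policy).save(export_dir)

# Example (inside train_eval, where eval_policy and train_dir are in scope):
# export_eval_policy(eval_policy, os.path.join(train_dir, 'policy_saved_model'))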
def main(): # Create train and evaluation environments for Tensorflow train_py_env = Environment.Environment() train_env = tf_py_environment.TFPyEnvironment(train_py_env) eval_py_env = Environment.Environment() eval_env = tf_py_environment.TFPyEnvironment(eval_py_env) # utils.validate_py_environment(train_py_env, episodes=5) # Set up an agent # Decide on layers of a network fc_layer_params = (50, 200, 25, 6) #conv_layer_params = [(4, 4, 1), (8, 4, 2)] # QNetwork predicts QValues (expected returns) for all actions based on observation on the given environment q_net = q_network.QNetwork(train_env.observation_spec(), train_env.action_spec(), #conv_layer_params=conv_layer_params, fc_layer_params=fc_layer_params) # Initialize DQN Agent on the train environment steps, actions, QNetwork, Adam Optimizer, loss function & train step counter optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) # Variable maintains shared, persistent state manipulated by a program. # 0 is the initial value. # After construction, the type and shape of the variable are fixed. train_step_counter = tf.Variable(0) agent = dqn_agent.DqnAgent( train_env.time_step_spec(), train_env.action_spec(), q_network=q_net, optimizer=optimizer, epsilon_greedy=0.4, #TODO tune this td_errors_loss_fn=common.element_wise_squared_loss, train_step_counter=train_step_counter, #boltzmann_temperature=0.1, summarize_grads_and_vars=True) agent.initialize() # Policies """A policy defines the way an agent acts in an environment. Typically, the goal of RL is to train the underlying model until the policy produces the desired outcome. Agents contain two policies: agent.policy — The main policy that is used for evaluation and deployment. agent.collect_policy — A second policy that is used for data collection. """ # tf_agents.policies.random_tf_policy creates a policy which will randomly select an action for each time_step (independent of agent) random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec()) # Baseline average return of the moves based on random_policy (random actions of an agent) print(compute_avg_return(eval_env, random_policy, num_eval_episodes)) # Replay buffer # The replay buffer keeps track of data collected from the environment. # This tutorial uses tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer, as it is the most common. # The constructor requires the specs for the data it will be collecting. # This is available from the agent using the collect_data_spec method. # The batch size and maximum buffer length are also required. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=agent.collect_data_spec, batch_size=1, max_length=replay_buffer_max_length, dataset_window_shift=1) # The agent needs access to the replay buffer. # This is provided by creating an iterable tf.data.Dataset pipeline which will feed data to the agent. # Each row of the replay buffer only stores a single observation step. # But since the DQN Agent needs both the current and next observation to compute the loss, # the dataset pipeline will sample two adjacent rows for each item in the batch (num_steps=2). # This dataset is also optimized by running parallel calls and prefetching data. 
dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, single_deterministic_pass=False, num_steps=2).prefetch(3) iterator = iter(dataset) # Train the agent agent.train = common.function(agent.train) # Reset the train step agent.train_step_counter.assign(0) collect_data(train_env, random_policy, replay_buffer, steps=10000) for _ in range(num_iterations): # Collect a few steps using collect_policy and save to the replay buffer. for _ in range(collect_steps_per_iteration): collect_step(train_env, agent.collect_policy, replay_buffer) # Sample a batch of data from the buffer and update the agent's network. experience, unused_info = next(iterator) train_loss = agent.train(experience).loss step = agent.train_step_counter.numpy() if step % log_interval == 0: avg_return = compute_avg_return(eval_env, agent.policy, 3) #TODO if not os.path.exists("eval_data"): os.makedirs("eval_data") path = os.path.join("eval_data", f'Eval_data.step{step // log_interval}.txt') with open(path, 'w') as f: for move in eval_py_env.all_moves: print(str(move), file=f) eval_py_env.all_moves = [] print('step = {0}: loss = {1}, Average Return: {2}'.format(step, train_loss, avg_return))
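# The surrounding snippets read these module-level hyperparameters without defining
# them in this excerpt; the values below are placeholders (assumptions), not taken from
# the original projects.
num_iterations = 20000
collect_steps_per_iteration = 1
replay_buffer_max_length = 100000
batch_size = 64
learning_rate = 1e-3
log_interval = 200
num_eval_episodes = 10
eval_interval = 1000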
def test(binance, model): symbol = "BTCUSDT" df = pd.read_csv("..\\Data\\" + symbol + "_data.csv", index_col=0, parse_dates=True) df = make_dataset.make_reinforcement_dataset(df) train_env_py = TradingEnv(df) train_env_py.set_init_balance(1000) eval_env_py = TradingEnv(df) eval_env_py.set_init_balance(1000) #train_env_py = suite_gym.load('CartPole-v0') #eval_env_py = suite_gym.load('CartPole-v0') train_env = tf_py_environment.TFPyEnvironment(train_env_py) eval_env = tf_py_environment.TFPyEnvironment(eval_env_py) #q_net = q_rnn_network.QRnnNetwork( # train_env.observation_spec(), # train_env.action_spec(), # input_fc_layer_params=(128,64,16), # output_fc_layer_params=(128,64,16), # lstm_size=(128,64,16)) q_net = q_network.QNetwork(train_env.observation_spec(), train_env.action_spec(), fc_layer_params=(256, 128, 64, 32, 16)) global_step = tf.Variable(0, name="global_step", trainable=False) optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001) agent = dqn_agent.DqnAgent( train_env.time_step_spec(), train_env.action_spec(), q_network=q_net, optimizer=optimizer, td_errors_loss_fn=common.element_wise_squared_loss, train_step_counter=global_step, epsilon_greedy=0.2) agent.initialize() eval_policy = agent.policy collect_policy = agent.collect_policy random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec()) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=agent.collect_data_spec, batch_size=train_env.batch_size, max_length=replay_buffer_max_length) collect_data(train_env, random_policy, replay_buffer, steps=10000) policy_checkpointer = common.Checkpointer( ckpt_dir="ReinforcementLearnData/Checkpoint", agent=agent, policy=agent.policy, replay_buffer=replay_buffer, global_step=global_step) policy_checkpointer.initialize_or_restore() tf_policy_saver = policy_saver.PolicySaver(agent.policy) # Dataset generates trajectories with shape [Bx2x...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) print("1:Train 2:Evaluate 3:Simulation") mode = input(">") if mode == "1": pass elif mode == "2": evaluate(agent, eval_env) return elif mode == "3": simulation(binance, agent) return # (Optional) Optimize by wrapping some of the code in a graph using TF function. agent.train = common.function(agent.train) # Evaluate the agent's policy once before training. avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes) returns = [avg_return] for _ in range(num_iterations): # Collect a few steps using collect_policy and save to the replay buffer. for _ in range(collect_steps_per_iteration): collect_step(train_env, agent.collect_policy, replay_buffer) # Sample a batch of data from the buffer and update the agent's network. experience, unused_info = next(iterator) train_loss = agent.train(experience).loss step = agent.train_step_counter.numpy() if step % log_interval == 0: print("step = {0}: loss = {1}".format(step, train_loss)) if step % eval_interval == 0: avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes) print("step = {0}: Average Return = {1}".format(step, avg_return)) returns.append(avg_return) #save agent policy_checkpointer.save(global_step) tf_policy_saver.save("ReinforcementLearnData/Policy") x = range(0, num_iterations + 1, eval_interval) plt.plot(x, returns) plt.ylabel('Average Return') plt.xlabel('Iterations') plt.show() print("Result:") print(compute_avg_return(eval_env, agent.policy, num_eval_episodes))
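# Hedged sketch (not in the source): a separate inference-only script could reload the
# SavedModel written by tf_policy_saver above and query actions without rebuilding the
# agent or replay buffer, using an environment constructed the same way as eval_env.
def act_with_saved_policy(environment, policy_dir="ReinforcementLearnData/Policy"):
    saved_policy = tf.saved_model.load(policy_dir)
    time_step = environment.current_time_step()
    return saved_policy.action(time_step).action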
def train_level(level, consecutive_wins_flag=5, collect_random_steps=True, max_iterations=num_iterations): """ create DQN agent to train a level of the game :param level: level of the game :param consecutive_wins_flag: number of consecutive wins in evaluation signifying the training is done :param collect_random_steps: whether to collect random steps at the beginning, always set to 'True' when the global step is 0. :param max_iterations: stop the training when it reaches the max iteration regardless of the result """ global saving_time cells = query_level(level) size = len(cells) env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells)) eval_env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells)) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) fc_layer_params = (neuron_num_mapper[size], ) q_net = q_network.QNetwork(env.observation_spec()[0], env.action_spec(), fc_layer_params=fc_layer_params, activation_fn=tf.keras.activations.relu) global_step = tf.compat.v1.train.get_or_create_global_step() agent = dqn_agent.DdqnAgent( env.time_step_spec(), env.action_spec(), q_network=q_net, optimizer=optimizer, td_errors_loss_fn=common.element_wise_squared_loss, train_step_counter=global_step, observation_and_action_constraint_splitter=GameEnv. obs_and_mask_splitter) agent.initialize() replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=agent.collect_data_spec, batch_size=env.batch_size, max_length=replay_buffer_max_length) # drivers collect_driver = dynamic_step_driver.DynamicStepDriver( env, policy=agent.collect_policy, observers=[replay_buffer.add_batch], num_steps=collect_steps_per_iteration) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] eval_driver = dynamic_episode_driver.DynamicEpisodeDriver( eval_env, policy=agent.policy, observers=eval_metrics, num_episodes=num_eval_episodes) # checkpointer of the replay buffer and policy train_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( dir_path, 'trained_policies/train_lv{0}'.format(level)), max_to_keep=1, agent=agent, policy=agent.policy, global_step=global_step, replay_buffer=replay_buffer) # policy saver tf_policy_saver = policy_saver.PolicySaver(agent.policy) train_checkpointer.initialize_or_restore() # optimize by wrapping some of the code in a graph using TF function agent.train = common.function(agent.train) collect_driver.run = common.function(collect_driver.run) eval_driver.run = common.function(eval_driver.run) # collect initial replay data if collect_random_steps: initial_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec=env.time_step_spec(), action_spec=env.action_spec(), observation_and_action_constraint_splitter=GameEnv. obs_and_mask_splitter) dynamic_step_driver.DynamicStepDriver( env, initial_collect_policy, observers=[replay_buffer.add_batch], num_steps=initial_collect_steps).run() # Dataset generates trajectories with shape [Bx2x...] 
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # train the model until `consecutive_wins_flag` consecutive evaluations have an
    # average return greater than 10, or until `max_iterations` is reached
    consecutive_eval_win = 0
    train_iterations = 0
    while consecutive_eval_win < consecutive_wins_flag and train_iterations < max_iterations:
        collect_driver.run()
        for _ in range(collect_steps_per_iteration):
            experience, _ = next(iterator)
            train_loss = agent.train(experience).loss

        # evaluate the training at intervals
        step = global_step.numpy()
        if step % eval_interval == 0:
            eval_driver.run()
            average_return = eval_metrics[0].result().numpy()
            average_len = eval_metrics[1].result().numpy()
            print("level: {0} step: {1} AverageReturn: {2} AverageLen: {3}".format(
                level, step, average_return, average_len))
            # evaluate consecutive wins
            if average_return > 10:
                consecutive_eval_win += 1
            else:
                consecutive_eval_win = 0

        if step % save_interval == 0:
            start = time.time()
            train_checkpointer.save(global_step=step)
            saving_time += time.time() - start

        train_iterations += 1

    # save the policy
    train_checkpointer.save(global_step=global_step.numpy())
    tf_policy_saver.save(
        os.path.join(dir_path, 'trained_policies/policy_lv{0}'.format(level)))
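# Hedged sketch (not in the source): reloading the per-level policy exported by
# tf_policy_saver above and playing a single episode on a freshly built environment
# for the same level. query_level, GameEnv, and dir_path come from the snippet above.
def play_level(level):
    cells = query_level(level)
    env = tf_py_environment.TFPyEnvironment(GameEnv(len(cells), cells))
    policy = tf.saved_model.load(
        os.path.join(dir_path, 'trained_policies/policy_lv{0}'.format(level)))
    policy_state = policy.get_initial_state(batch_size=env.batch_size)
    time_step = env.reset()
    while not time_step.is_last():
        policy_step = policy.action(time_step, policy_state)
        policy_state = policy_step.state
        time_step = env.step(policy_step.action)
    return time_step.reward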
def main( grid_size=options.grid_size, sensing_locations_amount=options.sensing_locations_amount, drones_amount=options.drones_amount, drone_max_speed=options.drone_max_speed, drone_bandwidth=options.drone_bandwidth, total_location_data=options.total_location_data, cycles_num=options.cycles_num, random=options.random, ): sensing_locations = (np.random.rand(sensing_locations_amount, 2) * grid_size) - ( grid_size / 2 ) options.sensing_locations = sensing_locations options.sensing_locations_amount = sensing_locations_amount options.grid_size = grid_size options.drones_amount = drones_amount options.drone_max_speed = drone_max_speed options.drone_bandwidth = drone_bandwidth options.cycles_num = cycles_num options.random = random # -- BEGIN of the jupyter notebook's code (please refer to it for a more readable version) # Link to the online notebook: https://github.com/AlessioLuciani/distributed-uav-rl-protocol/blob/main/simulation.ipynb options.drones_locations = np.zeros( (options.drones_amount, 2), dtype=float) options.sensing_data_amounts = np.zeros(options.drones_amount, dtype=float) options.aois = np.zeros(options.sensing_locations_amount, dtype=int) options.chosen_locations = np.zeros(options.drones_amount, dtype=int) options.cycle_stages = np.zeros(options.drones_amount, dtype=int) options.data_transmission_cycle = options.drone_bandwidth * options.cycle_length options.aois_vec = np.zeros( (cycles_num, options.sensing_locations_amount), dtype=int ) options.drones_vec = np.zeros( (cycles_num, options.drones_amount, 2), dtype=float) options.chosen_loc_vec = np.zeros( ((cycles_num, options.drones_amount)), dtype=int) options.cycle_stages_vec = np.zeros( (cycles_num, options.drones_amount), dtype=int) class DurpEnv(py_environment.PyEnvironment): def __init__(self, drone): self._drone = drone self._action_spec = array_spec.BoundedArraySpec( shape=(), dtype=np.int32, minimum=0, maximum=options.sensing_locations_amount - 1, name="action", ) self._observation_spec = array_spec.BoundedArraySpec( shape=(2,), minimum=(-options.grid_size / 2), maximum=(options.grid_size / 2), dtype=np.float64, ) def action_spec(self): return self._action_spec def observation_spec(self): return self._observation_spec def _reset(self): return ts.restart(np.array([0.0, 0.0], dtype=np.float64)) def _step(self, action): chosen_location_index = int(action) accumulated_aoi = get_accumulated_aoi(options.current_cycle[0]) options.aois[chosen_location_index] = options.current_cycle[0] new_accumulated_aoi = get_accumulated_aoi(options.current_cycle[0]) aoi_multiplier = 0.05 normalized_diff_aoi_component = ( ( 1 / ( 1 + np.exp( -(accumulated_aoi - new_accumulated_aoi) * aoi_multiplier ) ) ) - 0.5 ) * 2.0 chosen_location = options.sensing_locations[chosen_location_index] drone_location = options.drones_locations[self._drone] distance = np.linalg.norm(chosen_location - drone_location) distance_multiplier = 0.03 normalized_location_distance = ( 1 - (1 / (1 + np.exp(-distance * distance_multiplier))) ) * 2.0 aoi_weight = 0.7 distance_weight = 0.3 reward = ( normalized_diff_aoi_component * aoi_weight + normalized_location_distance * distance_weight + np.random.random() * (1.0 - aoi_weight - distance_weight) ) return ts.transition(chosen_location, reward=reward) learning_rate = 0.001 optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) environments = [] agents = [] for drone in range(options.drones_amount): durp_env = DurpEnv(drone) train_env = tf_py_environment.TFPyEnvironment(durp_env) q_net = q_network.QNetwork( 
train_env.observation_spec(), train_env.action_spec() ) train_step_counter = tf.Variable(0) agent = dqn_agent.DqnAgent( train_env.time_step_spec(), train_env.action_spec(), q_network=q_net, optimizer=optimizer, train_step_counter=train_step_counter, ) agent.initialize() agents.append(agent) environments.append(train_env) num_iterations = 5 intermediate_iterations = 5 eval_interval = 10 initial_collect_steps = 1 collect_steps_per_iteration = 1 batch_size = 64 def compute_avg_return(environment, policy, num_episodes=10): total_return = 0.0 for _ in range(num_episodes): time_step = environment.reset() episode_return = 0.0 for m in range(10): action_step = policy.action(time_step) time_step = environment.step(action_step.action) episode_return += time_step.reward total_return += episode_return avg_return = total_return / num_episodes return avg_return.numpy()[0] def collect_step(environment, policy, buffer, drone): time_step = environment.current_time_step() action_step = policy.action(time_step) next_time_step = environment.step(action_step.action) options.drones_locations[drone] = next_time_step.observation traj = trajectory.from_transition( time_step, action_step, next_time_step) buffer.add_batch(traj) def collect_data(env, policy, buffer, steps, drone): for step in range(1, steps + 1): options.current_cycle[0] = step collect_step(env, policy, buffer, drone) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=agents[0].collect_data_spec, batch_size=environments[0].batch_size ) random_policy = random_tf_policy.RandomTFPolicy( environments[0].time_step_spec(), environments[0].action_spec() ) reset_aois() reset_drones_locations() collect_data( environments[0], random_policy, replay_buffer, initial_collect_steps, 0 ) reset_aois() reset_drones_locations() dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2 ).prefetch(3) iterator = iter(dataset) # Reset the train step returns = np.zeros( (options.drones_amount, (num_iterations // eval_interval) + 1), dtype=np.float64 ) for k in range(len(agents)): avg_return = compute_avg_return(environments[k], agents[k].policy) returns[k][0] = avg_return for i in range(num_iterations): reset_aois() reset_drones_locations() if i % 100 == 0: print("----------------", i) for j in range(intermediate_iterations): for k in range(len(agents)): agent = agents[k] env = environments[k] # Collect a few steps using collect_policy and save to the replay buffer. collect_data( env, agent.collect_policy, replay_buffer, collect_steps_per_iteration, k, ) for k in range(len(agents)): agent = agents[k] # Sample a batch of data from the buffer and update the agent's network. 
experience, unused_info = next(iterator) train_loss = agent.train(experience).loss print("Loss:", train_loss) if i % eval_interval == 0: for k in range(len(agents)): agent = agents[k] avg_return = compute_avg_return(environments[k], agent.policy) returns[k][(num_iterations // eval_interval)] = avg_return time_steps = [] for drone in range(options.drones_amount): time_steps.append(environments[drone].reset()) reset_aois() reset_drones_locations() # -- END of the jupyter notebook's code # -- Start of the simulation for cycle in range(1, options.cycles_num + 1): if cycle % 10 == 1: print(get_accumulated_aoi(cycle)) print(options.aois) print("----------------") for drone in range(options.drones_amount): if options.cycle_stages[drone] == 0: agent = agents[drone] env = environments[drone] chosen_location_index = -1 if random: chosen_location_index = randrange( options.sensing_locations_amount) options.aois[chosen_location_index] = cycle else: options.current_cycle[0] = cycle policy_step = agent.policy.action(time_steps[drone]).replace( action=tf.constant( [get_best_action(drone)], dtype=np.int32) ) new_step = env.step(policy_step.action) time_steps[drone] = new_step chosen_location_index = int(policy_step.action) options.chosen_locations[drone] = chosen_location_index options.cycle_stages[drone] = 1 elif ( options.drones_locations[drone] != options.sensing_locations[options.chosen_locations[drone]] ).all(): traj = get_trajectory(drone) new_location = options.drones_locations[drone] + traj options.drones_locations[drone] = new_location elif options.sensing_data_amounts[drone] == 0.0: options.cycle_stages[drone] = 2 options.sensing_data_amounts[drone] = options.total_location_data options.cycle_stages[drone] = 3 else: options.sensing_data_amounts[drone] = np.max( [ options.sensing_data_amounts[drone] - options.data_transmission_cycle, 0.0, ] ) if options.sensing_data_amounts[drone] == 0.0: options.cycle_stages[drone] = 0 options.cycle_stages_vec[cycle - 1] = options.cycle_stages options.aois_vec[cycle - 1] = [ get_location_aoi(cycle, index) for index in range(options.sensing_locations_amount) ] options.drones_vec[cycle - 1] = options.drones_locations options.chosen_loc_vec[cycle - 1] = options.chosen_locations app = QApplication(sys.argv) screen_w, screen_h = get_screen_resolution(app) simulator = Simulator(screen_w, screen_h) simulator.show() sys.exit(app.exec())
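# The per-step reward used in DurpEnv._step above, restated compactly with the same
# constants (this helper is illustrative only). Note that because aoi_weight and
# distance_weight sum to 1.0, the trailing random term in the original contributes
# essentially nothing.
import numpy as np


def durp_reward(old_accumulated_aoi, new_accumulated_aoi, distance,
                aoi_weight=0.7, distance_weight=0.3):
    sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
    # AoI improvement squashed to (-1, 1); larger reductions in accumulated AoI score higher.
    aoi_term = (sigmoid((old_accumulated_aoi - new_accumulated_aoi) * 0.05) - 0.5) * 2.0
    # Distance penalty squashed to (0, 2); nearer sensing locations score higher.
    dist_term = (1.0 - sigmoid(distance * 0.03)) * 2.0
    return aoi_term * aoi_weight + dist_term * distance_weight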
def train_eval(
        root_dir,
        env_name='gym_solventx-v0',
        eval_env_name='gym_solventx-v0',
        env_load_fn=suite_gym.load,
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=_DEFAULT_REWARD_SCALE,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC.

    All hyperparameters come from the original SAC paper
    (https://arxiv.org/pdf/1801.01290.pdf).
    """
    if reward_scale_factor == _DEFAULT_REWARD_SCALE:
        # Use value recommended by https://arxiv.org/abs/1801.01290
        if env_name.startswith('Humanoid'):
            reward_scale_factor = 20.0
        else:
            reward_scale_factor = 5.0

    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # create training environment
        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)

        # create evaluation environment
        eval_env_name = eval_env_name or env_name
        eval_py_env = env_load_fn(eval_env_name)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, _RUN_DIR, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, _RUN_DIR, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, _RUN_DIR, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='env_steps_per_sec',
                    data=env_steps_per_sec,
                    step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=env_steps.result(),
                    summary_writer=summary_writer,
                    summary_prefix='Eval',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, env_steps.result())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)
            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
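# A minimal sketch of how train_eval might be invoked from a command-line entry point.
# The flag names and the default root_dir are illustrative assumptions, not part of the
# original module; `tf` and `train_eval` are assumed to come from the surrounding file.
from absl import app, flags

flags.DEFINE_string('root_dir', '/tmp/sac_solventx',
                    'Directory for summaries and checkpoints.')
flags.DEFINE_integer('num_iterations', 100000,
                     'Total number of training iterations.')
FLAGS = flags.FLAGS


def main(_):
    tf.compat.v1.enable_v2_behavior()
    train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations)


if __name__ == '__main__':
    app.run(main)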
def train_implementation(self, train_context: core.TrainContext):
    """Tf-Agents Dqn implementation of the train loop.

    The implementation follows
    https://colab.research.google.com/github/tensorflow/agents/blob/master/tf_agents/colabs/1_dqn_tutorial.ipynb
    """
    assert isinstance(train_context, core.StepsTrainContext)
    dc: core.StepsTrainContext = train_context
    train_env = self._create_env(discount=dc.reward_discount_gamma)
    observation_spec = train_env.observation_spec()
    action_spec = train_env.action_spec()
    timestep_spec = train_env.time_step_spec()

    # Set up optimizer, networks and DqnAgent.
    self.log_api('AdamOptimizer', '()')
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=dc.learning_rate)

    self.log_api('QNetwork', '()')
    q_net = q_network.QNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=self.model_config.fc_layers)

    self.log_api('DqnAgent', '()')
    tf_agent = dqn_agent.DqnAgent(
        timestep_spec,
        action_spec,
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss)

    self.log_api('tf_agent.initialize', '()')
    tf_agent.initialize()
    self._trained_policy = tf_agent.policy

    # Set up data collection and buffering.
    self.log_api('TFUniformReplayBuffer', '()')
    replay_buffer = TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=dc.max_steps_in_buffer)

    self.log_api('RandomTFPolicy', '()')
    random_policy = random_tf_policy.RandomTFPolicy(timestep_spec, action_spec)

    self.log_api('replay_buffer.add_batch', '(trajectory)')
    for _ in range(dc.num_steps_buffer_preload):
        self.collect_step(env=train_env,
                          policy=random_policy,
                          replay_buffer=replay_buffer)

    # Train
    tf_agent.train = common.function(tf_agent.train, autograph=False)
    self.log_api(
        'replay_buffer.as_dataset',
        '(num_parallel_calls=3, '
        f'sample_batch_size={dc.num_steps_sampled_from_buffer}, num_steps=2).prefetch(3)')
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=dc.num_steps_sampled_from_buffer,
        num_steps=2).prefetch(3)
    iter_dataset = iter(dataset)

    self.log_api('for each iteration')
    self.log_api(' replay_buffer.add_batch', '(trajectory)')
    self.log_api(' tf_agent.train', '(experience=trajectory)')
    while True:
        self.on_train_iteration_begin()
        for _ in range(dc.num_steps_per_iteration):
            self.collect_step(env=train_env,
                              policy=tf_agent.collect_policy,
                              replay_buffer=replay_buffer)
        trajectories, _ = next(iter_dataset)
        tf_loss_info = tf_agent.train(experience=trajectories)
        self.on_train_iteration_end(tf_loss_info.loss)
        if train_context.training_done:
            break
    return
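# train_implementation delegates single-step collection to a collect_step helper defined
# elsewhere in the class. A minimal sketch, assuming it follows the standard TF-Agents
# pattern of stepping the environment once with the given policy and writing the
# resulting transition into the replay buffer; shown here as a free function rather
# than the original method.
from tf_agents.trajectories import trajectory


def collect_step(env, policy, replay_buffer):
    # Take one action with `policy` and record the transition in `replay_buffer`.
    time_step = env.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = env.step(action_step.action)
    replay_buffer.add_batch(
        trajectory.from_transition(time_step, action_step, next_time_step))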