def evaluate(self):
    """Evaluate the agent's policy and report mean reward and mean steps.

    Reads the number of evaluation episodes from
    ``self._params["ML"]["Runner"]["evaluation_steps"]``, runs
    ``metric_utils.eager_compute`` with ``self._eval_metrics`` on
    ``self._runtime`` using ``self._agent._agent.policy``, logs the metrics
    and writes two scalar summaries (``mean_reward`` and ``mean_steps``) at
    the agent's current train step.
    """
    global_iteration = self._agent._agent._train_step_counter.numpy()
    # Resolve the episode count once; it is used for the run and both messages.
    num_episodes = self._params["ML"]["Runner"]["evaluation_steps"]
    logger.info(
        "Evaluating the agent's performance in {} episodes.".format(
            num_episodes))
    metric_utils.eager_compute(
        self._eval_metrics,
        self._runtime,
        self._agent._agent.policy,
        num_episodes=num_episodes)
    metric_utils.log_metrics(self._eval_metrics)

    # The first two evaluation metrics are reported as reward and steps.
    mean_reward = self._eval_metrics[0].result().numpy()
    mean_steps = self._eval_metrics[1].result().numpy()
    tf.summary.scalar("mean_reward", mean_reward, step=global_iteration)
    tf.summary.scalar("mean_steps", mean_steps, step=global_iteration)

    # This is a result summary, not a failure, so log it at info level.
    # Adjacent string literals are used instead of a backslash continuation
    # inside the literal, which would embed source indentation in the message.
    logger.info(
        "The agent achieved on average {} reward and {} steps in "
        "{} episodes.".format(mean_reward, mean_steps, num_episodes))
def Evaluate(self):
    """Evaluate the agent's policy and report mean reward and mean steps.

    Sets ``self._agent._training`` to ``False``, runs
    ``metric_utils.eager_compute`` with ``self._eval_metrics`` on
    ``self._wrapped_env`` using ``self._agent._agent.policy``, logs the
    metrics and writes two scalar summaries (``mean_reward`` and
    ``mean_steps``) at the agent's current train step.
    """
    self._agent._training = False
    global_iteration = self._agent._agent._train_step_counter.numpy()
    # Resolve the episode count once instead of three times.
    # NOTE(review): the tuple key presumably means (name, description,
    # default) for the parameter container -- confirm against its API.
    num_episodes = self._params["ML"]["TFARunner"]["EvaluationSteps", "", 20]
    self._logger.info(
        "Evaluating the agent's performance in {} episodes.".format(
            num_episodes))
    metric_utils.eager_compute(
        self._eval_metrics,
        self._wrapped_env,
        self._agent._agent.policy,
        num_episodes=num_episodes)
    metric_utils.log_metrics(self._eval_metrics)

    # The first two evaluation metrics are reported as reward and steps.
    mean_reward = self._eval_metrics[0].result().numpy()
    mean_steps = self._eval_metrics[1].result().numpy()
    tf.summary.scalar("mean_reward", mean_reward, step=global_iteration)
    tf.summary.scalar("mean_steps", mean_steps, step=global_iteration)

    # Adjacent string literals are used instead of a backslash continuation
    # inside the literal, which would embed source indentation in the message.
    self._logger.info(
        "The agent achieved on average {} reward and {} steps in "
        "{} episodes.".format(mean_reward, mean_steps, num_episodes))
def train_agent(n_iterations):
    """Run ``n_iterations`` collect/train iterations with periodic reporting.

    Each iteration runs ``collect_driver`` once, pulls one batch from
    ``dataset`` and trains ``agent`` on it. Depending on the intervals in
    ``config`` it prints time statistics and metrics, saves the policy, and
    writes the third and fourth entries of ``train_metrics`` as summaries.
    The policy is saved and the summary writer flushed once more at the end.
    """
    env_step = None
    collect_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    batches = iter(dataset)

    for iteration in range(n_iterations):
        env_step, collect_state = collect_driver.run(env_step, collect_state)
        experience, _ = next(batches)
        loss_info = agent.train(experience)

        # Progress line that rewrites itself in place.
        progress = "\r{} loss:{:.5f}".format(iteration,
                                             loss_info.loss.numpy())
        print(progress, end="")

        if iteration % config.TRAINING_LOG_INTERVAL == 0:
            utils.print_time_stats(train_start, iteration)
            print("\r")
            log_metrics(train_metrics)

        if iteration % config.TRAINING_SAVE_POLICY_INTERVAL == 0:
            save_agent_policy()

        if iteration % config.TRAINING_LOG_MEASURES_INTERVAL == 0:
            # Report the average return and the average episode length.
            utils.write_summary("AverageReturnMetric",
                                train_metrics[2].result(), iteration)
            utils.write_summary("AverageEpisodeLengthMetric",
                                train_metrics[3].result(), iteration)
            utils.writer.flush()

    # Final save and flush after the loop has finished.
    save_agent_policy()
    utils.writer.flush()
def train_agent(n_iterations):
    """Run ``n_iterations`` collect/train iterations with CSV-style logging.

    Each iteration runs ``collect_driver`` once, pulls one batch from
    ``dataset`` and trains ``agent`` on it. Every ``ITER_LOG_PERIOD``
    iterations a comma-separated row (train step followed by the results of
    ``training_metrics``) is written to ``fp``; every
    ``ITER_NEWLINE_PERIOD`` iterations the metrics are logged and ``fp`` is
    flushed; every ``ITER_CHECKPOINT_PERIOD`` iterations (except iteration
    0) a training checkpoint is saved.
    """
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    iterator = iter(dataset)
    for iteration in range(n_iterations):
        # Carry the policy state forward. The original assigned it to a
        # misspelled name, so every run restarted from the initial state.
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, _ = next(iterator)
        train_loss = agent.train(trajectories)

        # Read the step counter once; nothing between here and the row
        # written below changes it.
        train_step_value = agent.train_step_counter.value().numpy()
        print("\r{} loss:{:.5f}".format(train_step_value,
                                        train_loss.loss.numpy()), end="")

        if iteration % ITER_LOG_PERIOD == 0:
            row = ["{}".format(train_step_value)]
            row.extend("{}".format(m.result()) for m in training_metrics)
            fp.write(", ".join(row) + "\n")

        if iteration % ITER_NEWLINE_PERIOD == 0:
            print()
            log_metrics(training_metrics + training_metrics_2)
            fp.flush()

        # `iteration and ...` skips the checkpoint at iteration 0.
        if iteration and iteration % ITER_CHECKPOINT_PERIOD == 0:
            train_checkpointer.save(train_step)
            print()
def train_agent(n_iterations):
    """Run ``n_iterations`` collect/train iterations with TensorBoard logging.

    A summary writer is created under ``logs/dqn_agent/<timestamp>/train``.
    Every 1000 iterations the training metrics are logged and the first four
    entries of ``train_metrics`` are written as scalar summaries.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    writer = tf.summary.create_file_writer(
        'logs/dqn_agent/' + stamp + '/train')

    # Summary tag for each of the first four training metrics, in order.
    scalar_tags = ("number_of_episodes", "environment_steps",
                   "average_return", "average_episode_length")

    env_step = None
    collect_state = agent.collect_policy.get_initial_state(
        train_env.batch_size)
    batches = iter(dataset)

    with writer.as_default():
        for iteration in range(n_iterations):
            env_step, collect_state = collect_driver.run(
                env_step, collect_state)
            experience, _ = next(batches)
            loss_info = agent.train(experience)

            # Progress line that rewrites itself in place.
            print("\r{} loss:{:.5f}".format(iteration,
                                            loss_info.loss.numpy()), end="")

            if iteration % 1000 != 0:
                continue
            log_metrics(train_metrics)
            for index, tag in enumerate(scalar_tags):
                tf.summary.scalar(tag, train_metrics[index].result(),
                                  iteration)
            writer.flush()
def train_agent(n_iterations):
    """Run ``n_iterations`` collect/train iterations.

    Each iteration runs ``collect_driver`` once, pulls one batch from
    ``dataset`` and trains ``agent`` on it. The loss is printed on a
    self-overwriting line and the training metrics are logged every 1000
    iterations.
    """
    env_step = None
    collect_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    batches = iter(dataset)

    for iteration in range(n_iterations):
        env_step, collect_state = collect_driver.run(env_step, collect_state)
        experience, _ = next(batches)
        loss_info = agent.train(experience)

        progress = "\r{} loss:{:.5f}".format(iteration,
                                             loss_info.loss.numpy())
        print(progress, end="")

        if iteration % 1000 == 0:
            log_metrics(train_metrics)
def compute():
    """Run the evaluation metrics once and log them.

    Calls ``metric_utils.eager_compute`` with the evaluation metrics,
    environment, policy and episode count, passing ``train_step``,
    ``eval_summary_writer`` and the prefix ``'Metrics - Evaluation'``, then
    logs ``eval_metrics``. Returns ``None``.
    """
    # The computed results were never read, so the call is not bound.
    metric_utils.eager_compute(eval_metrics, eval_tf_env, eval_policy,
                               eval_num_episodes, train_step,
                               eval_summary_writer, 'Metrics - Evaluation')
    metric_utils.log_metrics(eval_metrics)
def _eval(self):
    """Run the evaluation metrics on the evaluation environment and log them.

    Summaries are force-enabled for the duration of the evaluation and are
    written to ``self._eval_summary_writer`` under the ``"Metrics"`` prefix at
    the current global counter value.
    """
    step_counter = get_global_counter()

    def _greedy_action(time_step, state):
        # Transform the time step, then take one greedy prediction step.
        return common.algorithm_step(
            algorithm_step_func=self._algorithm.greedy_predict,
            time_step=self._algorithm.transform_timestep(time_step),
            state=state)

    with tf.summary.record_if(True):
        eager_compute(
            metrics=self._eval_metrics,
            environment=self._eval_env,
            state_spec=self._algorithm.predict_state_spec,
            action_fn=_greedy_action,
            num_episodes=self._num_eval_episodes,
            step_metrics=self._driver.get_step_metrics(),
            train_step=step_counter,
            summary_writer=self._eval_summary_writer,
            summary_prefix="Metrics")
        metric_utils.log_metrics(self._eval_metrics)
def eval_single(self, checkpoint_path):
    """Evaluate the agent restored from one checkpoint.

    Reloads the agent from ``checkpoint_path``, computes the evaluation
    metrics, forwards the results to the optional metrics callback, logs the
    metrics and flushes the evaluation summary writer when one is set.

    Returns:
        The value returned by ``self._reload_agent(checkpoint_path)``, which
        is also used as the train step for the summaries.
    """
    restored_step = self._reload_agent(checkpoint_path)

    eval_results = metric_utils.eager_compute(
        self._eval_metrics,
        self._eval_tf_env,
        self._eval_policy,
        num_episodes=self._num_eval_episodes,
        train_step=restored_step,
        summary_writer=self._eval_summary_writer,
        summary_prefix='Metrics',
    )

    if self._eval_metrics_callback is not None:
        self._eval_metrics_callback(eval_results, restored_step)

    metric_utils.log_metrics(self._eval_metrics)

    if self._eval_summary_writer:
        self._eval_summary_writer.flush()

    return restored_step
def train_agent(n_iterations):
    """Run ``n_iterations`` collect/train iterations, saving the policy.

    Progress is shown with ``tqdm``. Every 1000 iterations the loss is
    printed and the training metrics are logged; every 100000 iterations the
    agent's policy is exported with a ``PolicySaver``.
    """
    policy_saver = PolicySaver(agent.policy, batch_size=tf_env.batch_size)

    env_step = None
    collect_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    batches = iter(dataset)

    for iteration in tqdm(range(n_iterations)):
        env_step, collect_state = collect_driver.run(env_step, collect_state)
        experience, _ = next(batches)
        loss_info = agent.train(experience)

        if iteration % 1000 == 0:
            progress = "\r{} loss:{:.5f}".format(iteration,
                                                 loss_info.loss.numpy())
            print(progress, end="")
            log_metrics(train_metrics)

        # Export the policy every 100000 iterations.
        if iteration % 100000 == 0:
            policy_saver.save('policy_%d' % iteration)
def evaluate():
    """Run one evaluation pass and log the resulting metrics.

    The results are forwarded to ``eval_metrics_callback`` when one is set.
    ``qj`` is called before and after the evaluation.
    """
    # Override outer record_if that may be out of sync with respect to the
    # env_steps.result() value used for the summary step.
    with tf.compat.v2.summary.record_if(True):
        qj(g_step.numpy(), 'Starting eval at step', tic=1)

        eval_results = pisac_metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            histograms=eval_histograms,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
            use_function=drivers_in_graph,
        )

        callback = eval_metrics_callback
        if callback is not None:
            callback(eval_results, env_steps.result())

        tfa_metric_utils.log_metrics(eval_metrics)
        qj(s='Finished eval', toc=1)
def train_agent(n_iterations):
    """Run ``n_iterations`` collect/train iterations with checkpointing.

    Restores from the manager's latest checkpoint first (and reports whether
    one existed). The training metrics are logged every 1000 iterations and a
    checkpoint is written every 5000 iterations.
    """
    checkpoint.restore(manager.latest_checkpoint)
    if not manager.latest_checkpoint:
        print("Initializing from scratch.")
    else:
        print("Restored from {}".format(manager.latest_checkpoint))

    env_step = None
    collect_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    batches = iter(dataset)

    for iteration in range(n_iterations):
        env_step, collect_state = collect_driver.run(env_step, collect_state)
        experience, _ = next(batches)
        loss_info = agent.train(experience)

        progress = "\r{} loss:{:.5f}".format(iteration,
                                             loss_info.loss.numpy())
        print(progress, end="")

        if iteration % 1000 == 0:
            log_metrics(train_metrics)
        if iteration % 5000 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
def train_agent(n_iterations):
    """Run collect/train iterations from ``initial_policy`` to ``n_iterations``.

    The loss and the percentage of ``n_iterations`` reached are printed on a
    self-overwriting line. The training metrics are logged every 1000
    iterations, and the agent's policy is exported with a fresh
    ``PolicySaver`` every 10000 iterations (never at iteration 0).
    """
    env_step = None
    collect_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    batches = iter(dataset)

    for iteration in range(initial_policy, n_iterations):
        env_step, collect_state = collect_driver.run(env_step, collect_state)
        experience, _ = next(batches)
        loss_info = agent.train(experience)

        progress = "\r{} loss:{:.5f} done:{:.5f}".format(
            iteration,
            loss_info.loss.numpy(),
            iteration / n_iterations * 100.0)
        print(progress, end="")

        if iteration % 1000 == 0:
            log_metrics(train_metrics)

        if iteration > 0 and iteration % 10000 == 0:
            PolicySaver(agent.policy).save('policy_' + str(iteration))
def DDPG_Bipedal(root_dir):
    """Train a DDPG agent on ``env_name`` and evaluate it periodically.

    Besides ``root_dir``, every setting used here (``env_name``,
    ``num_iterations``, learning rates, intervals, ``run_id`` and so on) is
    read from names defined outside this function.

    Args:
        root_dir: Base directory for results. ``~`` is expanded, and
            ``train/<run_id>``, ``eval/<run_id>`` and ``vid/<run_id>`` are
            derived from it.

    Returns:
        The loss info returned by the last call to ``agent.train``.

    Raises:
        UnboundLocalError: If no training step ran (for example when
            ``num_iterations`` is 0), because ``train_loss`` is never bound.
    """
    # Setting up directories for results
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train' + '/' + str(run_id))
    eval_dir = os.path.join(root_dir, 'eval' + '/' + str(run_id))
    vid_dir = os.path.join(root_dir, 'vid' + '/' + str(run_id))

    # Set up Summary writer for training and evaluation
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        # Metric to record average return
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        # Metric to record average episode length
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    # Create global step
    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Only record summaries on steps that are a multiple of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Load Environment with different wrappers
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))
        # NOTE(review): eval_py_env is not referenced again in this function
        # -- confirm whether this extra environment can be dropped.
        eval_py_env = suite_gym.load(env_name)

        # Define Actor Network
        actorNN = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=(400, 300),
        )

        # Define Critic Network
        NN_input_specs = (tf_env.time_step_spec().observation,
                          tf_env.action_spec())

        criticNN = critic_network.CriticNetwork(
            NN_input_specs,
            observation_fc_layer_params=(400, ),
            action_fc_layer_params=None,
            joint_fc_layer_params=(300, ),
        )

        # Define & initialize DDPG Agent
        agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actorNN,
            critic_network=criticNN,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
            gamma=gamma,
            train_step_counter=global_step)
        agent.initialize()

        # Determine which train metrics to display with summary writer
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Set policies for evaluation, initial collection
        eval_policy = agent.policy  # Actor policy
        collect_policy = agent.collect_policy  # Actor policy with OUNoise

        # Set up replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Define driver for initial replay buffer filling
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,  # Initializes with random Parameters
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        # Define collect driver for collect steps per iteration
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            agent.train = common.function(agent.train)

        # Fill the replay buffer with initial_collect_steps steps. The
        # message needs a placeholder for the argument it is given: a
        # hard-coded count with a spare argument causes a logging format
        # error and reports the wrong number when the setting is not 1000.
        logging.info(
            'Initializing replay buffer by collecting experience for %d '
            'steps with a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        # Computes Evaluation Metrics
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset outputs steps in batches of 64
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=64,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            # Get experience from dataset (replay buffer)
            experience, _ = next(iterator)
            # Train agent on that experience
            return agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()  # Get start time
            # Collect data for replay buffer
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            # Train on experience
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='iterations_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                # Restart the throughput measurement window.
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)
                # Record a video once the average return reaches 230.
                if results['AverageReturn'].numpy() >= 230.0:
                    create_video(video_dir=vid_dir,
                                 env_name="BipedalWalker-v2",
                                 vid_policy=eval_policy,
                                 video_id=global_step.numpy())

    return train_loss
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG.

    Builds recurrent actor and critic networks, a ``DdpgAgent``, a uniform
    replay buffer and episode-based collect drivers for a
    ``suite_dm_control`` environment, then alternates between collecting
    episodes and training while periodically logging and evaluating.

    Args:
        root_dir: Base directory; ``~`` is expanded and the ``train`` and
            ``eval`` summary directories are created beneath it.
        env_name: Environment name passed to ``suite_dm_control.load``.
        task_name: Task name passed to ``suite_dm_control.load``.
        observations_whitelist: When not ``None``, environments are wrapped
            in ``FlattenObservationsWrapper`` with this single entry.
        num_iterations: Number of collect-then-train iterations.
        actor_fc_layers, actor_output_fc_layers, actor_lstm_size: Layer
            settings forwarded to ``ActorRnnNetwork``.
        critic_obs_fc_layers, critic_action_fc_layers,
        critic_joint_fc_layers, critic_output_fc_layers, critic_lstm_size:
            Layer settings forwarded to ``CriticRnnNetwork``.
        initial_collect_episodes: Episodes collected before training starts.
        collect_episodes_per_iteration: Episodes collected per iteration.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        ou_stddev, ou_damping, target_update_tau, target_update_period,
        dqda_clipping, td_errors_loss_fn, gamma, reward_scale_factor,
        gradient_clipping, debug_summaries, summarize_grads_and_vars:
            Forwarded unchanged to ``DdpgAgent``.
        train_steps_per_iteration: Training steps run per iteration.
        batch_size: ``sample_batch_size`` of the replay dataset.
        train_sequence_length: The dataset samples
            ``train_sequence_length + 1`` steps per element.
        actor_learning_rate, critic_learning_rate: Learning rates of the two
            Adam optimizers.
        use_tf_functions: When true, the driver ``run`` methods and
            ``tf_agent.train`` are wrapped with ``common.function``.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the evaluation metrics.
        eval_interval: Evaluate when the global step is a multiple of this.
        log_interval: Log loss and throughput when the global step is a
            multiple of this.
        summary_interval: Summaries are recorded only when the global step
            is a multiple of this.
        summaries_flush_secs: Flush period of the summary writers, seconds.
        eval_metrics_callback: Optional callable invoked as
            ``callback(results, global_step_value)`` after each evaluation.

    Returns:
        The value returned by the last ``tf_agent.train`` call.
        NOTE(review): ``train_loss`` is unbound if no training step runs
        (for example ``num_iterations=0``).
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The train writer becomes the default; the eval writer is passed
    # explicitly to each evaluation call.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Only record summaries on steps that are a multiple of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []

        # Separate environment instances for collection and evaluation.
        tf_env = tf_py_environment.TFPyEnvironment(
            suite_dm_control.load(env_name,
                                  task_name,
                                  env_wrappers=env_wrappers))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_dm_control.load(env_name,
                                  task_name,
                                  env_wrappers=env_wrappers))

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        # The critic takes an (observation, action) pair as input.
        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        # NOTE(review): no train_step_counter is passed here; global_step is
        # instead passed to tf_agent.train() in the loop below -- confirm this
        # matches the installed TF-Agents API.
        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)
        tf_agent.initialize()

        # The first two metrics are later used as step metrics for summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Both drivers write to the replay buffer and update train_metrics.
        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        # NOTE(review): the message says "random policy", but the driver above
        # is built with the agent's collect_policy.
        logging.info(
            'Initializing replay buffer by collecting experience for %d episodes '
            'with a random policy.', initial_collect_episodes)
        initial_collect_driver.run()

        # Evaluate once before any training step.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        # Bookkeeping for the steps-per-second measurement.
        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        for _ in range(num_iterations):
            start_time = time.time()
            # Collect, continuing from the previous time step and state.
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience,
                                            train_step_counter=global_step)
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                # Restart the throughput measurement window.
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

    return train_loss
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Builds a Q-network (recurrent when ``train_sequence_length > 1``), a
    ``DqnAgent``, a uniform replay buffer, a step-based collect driver and
    three checkpointers for a ``suite_gym`` environment, then alternates
    between collecting steps and training while periodically logging,
    checkpointing and evaluating.

    Args:
        root_dir: Base directory; ``~`` is expanded and the ``train`` and
            ``eval`` directories are created beneath it.
        env_name: Environment name passed to ``suite_gym.load``.
        num_iterations: Number of collect-then-train iterations.
        train_sequence_length: Values above 1 select ``QRnnNetwork``;
            otherwise ``QNetwork`` is used and this value is replaced by
            ``n_step_update``.
        fc_layer_params: Layer sizes for ``QNetwork``.
        input_fc_layer_params, lstm_size, output_fc_layer_params: Layer
            settings for ``QRnnNetwork``.
        initial_collect_steps: Steps collected with a random policy before
            training starts.
        collect_steps_per_iteration: Steps collected per iteration.
        epsilon_greedy, n_step_update, target_update_tau,
        target_update_period, gamma, reward_scale_factor, gradient_clipping,
        debug_summaries, summarize_grads_and_vars: Forwarded unchanged to
            ``DqnAgent``.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        train_steps_per_iteration: Training steps run per iteration.
        batch_size: ``sample_batch_size`` of the replay dataset.
        learning_rate: Learning rate of the Adam optimizer.
        use_tf_functions: When true, ``collect_driver.run``,
            ``tf_agent.train`` and the local ``train_step`` are wrapped with
            ``common.function``.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the evaluation metrics.
        eval_interval: Evaluate when the global step is a multiple of this.
        train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Save the respective checkpointer when the
            global step is a multiple of the interval.
        log_interval: Log loss and throughput when the global step is a
            multiple of this.
        summary_interval: Summaries are recorded only when the global step
            is a multiple of this.
        summaries_flush_secs: Flush period of the summary writers, seconds.
        eval_metrics_callback: Optional callable invoked as
            ``callback(results, global_step_value)`` after each evaluation.

    Returns:
        The value returned by the last ``train_step()`` call.
        NOTE(review): ``train_loss`` is unbound if no training step runs
        (for example ``num_iterations=0``).

    Raises:
        NotImplementedError: If both ``train_sequence_length`` and
            ``n_step_update`` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The train writer becomes the default; the eval writer is passed
    # explicitly to each evaluation call.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Only record summaries on steps that are a multiple of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Separate environment instances for collection and evaluation.
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(tf_env.observation_spec(),
                                       tf_env.action_spec(),
                                       fc_layer_params=fc_layer_params)
            # Without an RNN, the sampled sequence length follows the n-step
            # setting (the dataset below samples this value plus one steps).
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # The first two metrics are later used as step metrics for summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        # Three checkpointers: agent + metrics, evaluation policy, and the
        # replay buffer (which keeps only its latest checkpoint).
        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        # NOTE(review): policy_checkpointer is only ever saved, never
        # restored here -- confirm this is intended.
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Evaluate once before any training step.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        # Bookkeeping for the steps-per-second measurement.
        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            # Pull one batch from the replay dataset and train on it.
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            # Collect, continuing from the previous time step and state.
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                # Restart the throughput measurement window.
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            # Each checkpointer saves on its own interval.
            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

    return train_loss
def _sample_mixed_reset_candidates(forward_buffer, reverse_buffer,
                                   num_candidates):
  """Samples reset candidates half from each replay buffer.

  Draws `num_candidates // 2` items from `forward_buffer` and the same number
  from `reverse_buffer`, then concatenates the two samples along axis 0,
  tensor by tensor.

  Args:
    forward_buffer: replay buffer of the forward agent.
    reverse_buffer: replay buffer of the reverse agent.
    num_candidates: total number of candidates requested (an odd value yields
      one candidate fewer because of the integer halving).

  Returns:
    A nest with the same structure as a sample from `forward_buffer`, holding
    the concatenated forward and reverse samples.
  """
  half = num_candidates // 2
  forward_sample = forward_buffer.get_next(sample_batch_size=half)[0]
  reverse_sample = reverse_buffer.get_next(sample_batch_size=half)[0]
  merged = [
      tf.concat([fwd, rev], axis=0)
      for fwd, rev in zip(
          tf.nest.flatten(forward_sample), tf.nest.flatten(reverse_sample))
  ]
  return tf.nest.pack_sequence_as(forward_sample, merged)


def _write_train_metric_summaries(metrics, train_step, include_heatmaps):
  """Writes summaries for `metrics`, optionally skipping heatmap metrics.

  Args:
    metrics: list of train metrics; the first two are used as step metrics.
    train_step: step variable passed to each metric's `tf_summaries`.
    include_heatmaps: when False, metrics whose name contains 'Heatmap' are
      skipped.
  """
  step_metrics = metrics[:2]
  for metric in metrics:
    if 'Heatmap' in metric.name and not include_heatmaps:
      continue
    metric.tf_summaries(train_step=train_step, step_metrics=step_metrics)


def train_eval(
    root_dir,
    random_seed=None,
    env_name='sawyer_push',
    eval_env_name=None,
    env_load_fn=get_env,
    max_episode_steps=1000,
    eval_episode_steps=1000,
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    reset_goal_frequency=1000,  # virtual episode size for reset-free training
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    # reset-free parameters
    use_minimum=True,
    reset_lagrange_learning_rate=3e-4,
    value_threshold=None,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    # Td3 parameters
    actor_update_period=1,
    exploration_noise_std=0.1,
    target_policy_noise=0.1,
    target_policy_noise_clip=0.1,
    dqda_clipping=None,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    # video recording for the environment
    video_record_interval=10000,
    num_videos=0,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """Trains a forward and a reverse agent in a reset-free environment.

  Two agents with separate replay buffers, metrics and checkpoints share one
  training environment wrapped in `ResetFreeWrapper`. On every iteration the
  environment's `_forward_or_reset_goal` attribute selects which agent
  collects experience; both agents are then trained. Besides the arguments
  below, the function reads `FLAGS.agent_type`, `FLAGS.num_action_samples`,
  `FLAGS.use_reset_goals`, `FLAGS.num_reset_candidates` and
  `FLAGS.reset_goal_frequency`.

  Args:
    root_dir: base directory ('~' is expanded); summaries and checkpoints go
      to its `train` and `eval` subdirectories.
    random_seed: if not None, passed to `tf.compat.v1.set_random_seed`.
    env_name: environment name passed to `env_load_fn` for training and for
      video recording.
    eval_env_name: environment name for evaluation; falls back to `env_name`.
    env_load_fn: callable building an environment. Its result is indexed /
      unpacked as `(env, train_metrics, eval_metrics, aux_info)`.
    max_episode_steps: `full_reset_frequency` of the reset-free wrapper and
      step limit of the video-recording environment.
    eval_episode_steps: step limit of the evaluation environment and maximum
      length of recorded videos.
    num_iterations: number of collect/train iterations.
    actor_fc_layers: fully connected layer sizes of the actor networks.
    critic_obs_fc_layers: observation-branch layer sizes of the critics.
    critic_action_fc_layers: action-branch layer sizes of the critics.
    critic_joint_fc_layers: joint layer sizes of the critics.
    initial_collect_steps: number of single-step random-policy collections
      performed when the forward replay buffer is empty.
    collect_steps_per_iteration: `num_steps` of the collect drivers.
    replay_buffer_capacity: `max_length` of each replay buffer.
    target_update_tau: passed to the agents.
    target_update_period: passed to the agents.
    reset_goal_frequency: `reset_goal_frequency` of the reset-free wrapper.
      NOTE(review): the candidate-refresh schedule below uses
      `FLAGS.reset_goal_frequency` instead -- confirm both are kept in sync.
    train_steps_per_iteration: train steps per iteration, for each agent.
    batch_size: batch size of the training datasets.
    actor_learning_rate: Adam learning rate of the actor optimizers.
    critic_learning_rate: Adam learning rate of the critic optimizers.
    alpha_learning_rate: Adam learning rate of the SAC alpha optimizers.
    use_minimum: passed to `ResetGoalGenerator`.
    reset_lagrange_learning_rate: if > 0, a `ResetGoalGenerator` optimized
      with Adam at this rate is used and its Lagrange multipliers are
      updated during training; otherwise a `FixedResetGoal` is used.
    value_threshold: passed to `ResetGoalGenerator`.
    td_errors_loss_fn: passed to the agents.
    gamma: passed to the agents.
    reward_scale_factor: passed to the agents.
    actor_update_period: TD3 only, passed to `Td3Agent`.
    exploration_noise_std: TD3 only, passed to `Td3Agent`.
    target_policy_noise: TD3 only, passed to `Td3Agent`.
    target_policy_noise_clip: TD3 only, passed to `Td3Agent`.
    dqda_clipping: TD3 only, passed to `Td3Agent`.
    gradient_clipping: passed to the agents.
    use_tf_functions: if True, wraps drivers, train functions and related
      callables with `common.function`.
    num_eval_episodes: episodes per evaluation; also the metric buffer size.
    eval_interval: evaluate whenever the global step is a multiple of this.
    train_checkpoint_interval: global-step interval for agent checkpoints.
    policy_checkpoint_interval: global-step interval for policy checkpoints.
    rb_checkpoint_interval: global-step interval for replay-buffer
      checkpoints.
    video_record_interval: global-step interval for video recording.
    num_videos: number of videos recorded each time.
    log_interval: global-step interval for loss / throughput logging.
    summary_interval: global-step interval at which summaries are recorded.
    summaries_flush_secs: flush period of the summary writers, in seconds.
    debug_summaries: passed to the agents.
    summarize_grads_and_vars: passed to the agents.
    eval_metrics_callback: optional callable invoked after every evaluation
      with `(results, global_step_value)`.

  Returns:
    The loss info returned by the last forward-agent train step, or None if
    no train step was executed.

  Raises:
    NotImplementedError: if the selected agent type provides no reverse agent
      (currently the case for 'td3').
    ValueError: if `FLAGS.agent_type` is neither 'sac' nor 'td3'.
  """
  start_time = time.time()

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  video_dir = os.path.join(eval_dir, 'videos')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)

    # The training environment never times out on its own
    # (max_episode_steps=None); the wrapper receives the reset frequencies.
    env, env_train_metrics, env_eval_metrics, aux_info = env_load_fn(
        name=env_name,
        max_episode_steps=None,
        gym_env_wrappers=(functools.partial(
            reset_free_wrapper.ResetFreeWrapper,
            reset_goal_frequency=reset_goal_frequency,
            full_reset_frequency=max_episode_steps),))
    tf_env = tf_py_environment.TFPyEnvironment(env)

    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(name=eval_env_name,
                    max_episode_steps=eval_episode_steps)[0])
    eval_metrics += env_eval_metrics

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    def build_sac_agent(direction):
      """Builds actor, critic and SAC agent named '<direction>_*'."""
      actor = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=functools.partial(
              tanh_normal_projection_network.TanhNormalProjectionNetwork,
              std_transform=std_clip_transform),
          name=direction + '_actor')
      critic = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform',
          name=direction + '_critic')
      return SacAgent(
          time_step_spec,
          action_spec,
          num_action_samples=FLAGS.num_action_samples,
          actor_network=actor,
          critic_network=critic,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          name=direction + '_agent')

    # Only the 'sac' branch builds a reverse agent.
    tf_agent_rev = None
    if FLAGS.agent_type == 'sac':
      tf_agent = build_sac_agent('forward')
      tf_agent_rev = build_sac_agent('reverse')
    elif FLAGS.agent_type == 'td3':
      actor_net = actor_network.ActorNetwork(
          tf_env.time_step_spec().observation,
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
      )
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform')
      tf_agent = Td3Agent(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          exploration_noise_std=exploration_noise_std,
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          actor_update_period=actor_update_period,
          dqda_clipping=dqda_clipping,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          target_policy_noise=target_policy_noise,
          target_policy_noise_clip=target_policy_noise_clip,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
      )
    else:
      raise ValueError(
          'Unsupported agent_type: {!r}'.format(FLAGS.agent_type))

    if tf_agent_rev is None:
      # Everything below needs a reverse agent. Previously this path failed
      # with a NameError on `tf_agent_rev`; fail with a clear message instead.
      raise NotImplementedError(
          'agent_type {!r} does not build a reverse agent, which reset-free '
          'training requires.'.format(FLAGS.agent_type))

    tf_agent.initialize()
    tf_agent_rev.initialize()

    if FLAGS.use_reset_goals:
      # distance to initial state distribution
      initial_state_distance = state_distribution_distance.L2Distance(
          initial_state_shape=aux_info['reset_state_shape'])
      initial_state_distance.update(
          tf.constant(aux_info['reset_states'], dtype=tf.float32),
          update_type='complete')

      if use_tf_functions:
        initial_state_distance.distance = common.function(
            initial_state_distance.distance)
        tf_agent.compute_value = common.function(tf_agent.compute_value)

      # initialize reset / practice goal proposer
      if reset_lagrange_learning_rate > 0:
        reset_goal_generator = ResetGoalGenerator(
            goal_dim=aux_info['reset_state_shape'][0],
            num_reset_candidates=FLAGS.num_reset_candidates,
            compute_value_fn=tf_agent.compute_value,
            distance_fn=initial_state_distance,
            use_minimum=use_minimum,
            value_threshold=value_threshold,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=reset_lagrange_learning_rate),
            name='reset_goal_generator')
      else:
        reset_goal_generator = FixedResetGoal(
            distance_fn=initial_state_distance)

      # if use_tf_functions:
      #   reset_goal_generator.get_reset_goal = common.function(
      #       reset_goal_generator.get_reset_goal)

      # modify the reset-free wrapper to use the reset goal generator
      tf_env.pyenv.envs[0].set_reset_goal_fn(
          reset_goal_generator.get_reset_goal)

    # Make the replay buffers (one per agent).
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    replay_buffer_rev = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent_rev.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer_rev = [replay_buffer_rev.add_batch]

    # initialize metrics and observers
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]
    train_metrics += env_train_metrics

    train_metrics_rev = [
        tf_metrics.NumberOfEpisodes(name='NumberOfEpisodesRev'),
        tf_metrics.EnvironmentSteps(name='EnvironmentStepsRev'),
        tf_metrics.AverageReturnMetric(
            name='AverageReturnRev',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthRev',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]
    train_metrics_rev += aux_info['train_metrics_rev']

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    eval_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_agent.policy, use_tf_function=True)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_policy_rev = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy
    collect_policy_rev = tf_agent_rev.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'forward'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'forward', 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # reverse policy savers
    train_checkpointer_rev = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'reverse'),
        agent=tf_agent_rev,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics_rev,
                                          'train_metrics_rev'))
    rb_checkpointer_rev = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer_rev'),
        max_to_keep=1,
        replay_buffer=replay_buffer_rev)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()
    train_checkpointer_rev.initialize_or_restore()
    rb_checkpointer_rev.initialize_or_restore()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)
    collect_driver_rev = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy_rev,
        observers=replay_observer_rev + train_metrics_rev,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      collect_driver.run = common.function(collect_driver.run)
      collect_driver_rev.run = common.function(collect_driver_rev.run)
      tf_agent.train = common.function(tf_agent.train)
      tf_agent_rev.train = common.function(tf_agent_rev.train)

    if replay_buffer.num_frames() == 0:
      # Single-step drivers so the active direction can be re-checked after
      # every environment step.
      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + train_metrics,
          num_steps=1)
      initial_collect_driver_rev = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy_rev,
          observers=replay_observer_rev + train_metrics_rev,
          num_steps=1)
      # does not work for some reason
      if use_tf_functions:
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        initial_collect_driver_rev.run = common.function(
            initial_collect_driver_rev.run)

      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      for iter_idx_initial in range(initial_collect_steps):
        if tf_env.pyenv.envs[0]._forward_or_reset_goal:
          initial_collect_driver.run()
        else:
          initial_collect_driver_rev.run()

        if (FLAGS.use_reset_goals and
            iter_idx_initial % FLAGS.reset_goal_frequency == 0):
          if replay_buffer_rev.num_frames():
            reset_candidates = _sample_mixed_reset_candidates(
                replay_buffer, replay_buffer_rev, FLAGS.num_reset_candidates)
          else:
            # Reverse buffer still empty: take all candidates from the
            # forward buffer.
            reset_candidates = replay_buffer.get_next(
                sample_batch_size=FLAGS.num_reset_candidates)[0]
          tf_env.pyenv.envs[0].set_reset_candidates(reset_candidates)

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size, num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    dataset_rev = replay_buffer_rev.as_dataset(
        sample_batch_size=batch_size, num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator_rev = iter(dataset_rev)

    def train_step_rev():
      experience_rev, _ = next(iterator_rev)
      return tf_agent_rev.train(experience_rev)

    if use_tf_functions:
      train_step = common.function(train_step)
      train_step_rev = common.function(train_step_rev)

    # manual data save for plotting utils
    np_on_cns_save(os.path.join(eval_dir, 'eval_interval.npy'), eval_interval)
    try:
      average_eval_return = np_on_cns_load(
          os.path.join(eval_dir, 'average_eval_return.npy')).tolist()
      average_eval_success = np_on_cns_load(
          os.path.join(eval_dir, 'average_eval_success.npy')).tolist()
    except Exception:  # pylint: disable=broad-except
      # Best effort: any load failure restarts the history. Unlike a bare
      # `except:`, this lets KeyboardInterrupt / SystemExit propagate.
      logging.info('Could not load previous evaluation history; starting '
                   'with empty lists.')
      average_eval_return = []
      average_eval_success = []

    print('initialization_time:', time.time() - start_time)

    # Defined up front so the return below is valid even when the loop body
    # never runs (e.g. num_iterations == 0).
    train_loss = None
    train_loss_rev = None
    for iter_idx in range(num_iterations):
      start_time = time.time()
      active_driver = (
          collect_driver if tf_env.pyenv.envs[0]._forward_or_reset_goal
          else collect_driver_rev)
      time_step, policy_state = active_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )

      # reset goal generator updates
      if FLAGS.use_reset_goals and iter_idx % (
          FLAGS.reset_goal_frequency * collect_steps_per_iteration) == 0:
        reset_candidates = _sample_mixed_reset_candidates(
            replay_buffer, replay_buffer_rev, FLAGS.num_reset_candidates)
        tf_env.pyenv.envs[0].set_reset_candidates(reset_candidates)
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator.update_lagrange_multipliers()

      for _ in range(train_steps_per_iteration):
        train_loss_rev = train_step_rev()
        train_loss = train_step()

      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, train_loss.loss)
        logging.info('step = %d, loss_rev = %f', global_step_val,
                     train_loss_rev.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      # Heatmap metrics are only written on summary steps; all other metrics
      # are written every iteration.
      on_summary_step = global_step_val % summary_interval == 0
      _write_train_metric_summaries(train_metrics, global_step,
                                    on_summary_step)
      _write_train_metric_summaries(train_metrics_rev, global_step,
                                    on_summary_step)

      if on_summary_step and FLAGS.use_reset_goals:
        reset_goal_generator.update_summaries(step_counter=global_step)

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

        # numpy saves for plotting
        average_eval_return.append(results['AverageReturn'].numpy())
        average_eval_success.append(
            results['EvalSuccessfulEpisodes'].numpy())
        np_on_cns_save(
            os.path.join(eval_dir, 'average_eval_return.npy'),
            average_eval_return)
        np_on_cns_save(
            os.path.join(eval_dir, 'average_eval_success.npy'),
            average_eval_success)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)
        train_checkpointer_rev.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
        rb_checkpointer_rev.save(global_step=global_step_val)

      if global_step_val % video_record_interval == 0:
        for video_idx in range(num_videos):
          video_name = os.path.join(video_dir, str(global_step_val),
                                    'video_' + str(video_idx) + '.mp4')
          record_video(
              lambda: env_load_fn(  # pylint: disable=g-long-lambda
                  name=env_name, max_episode_steps=max_episode_steps)[0],
              video_name,
              eval_py_policy,
              max_episode_length=eval_episode_steps)

    return train_loss
def run():
  """Trains a DQN agent on the Snake environment with periodic evaluation.

  Builds a training and an evaluation `SnakeEnv`, a Q-network based
  `DqnAgent`, a uniform replay buffer and checkpointers under
  '/tf-logs/snake'. Agent and replay-buffer checkpoints are restored when
  present. The buffer is seeded with random-policy experience if it holds
  fewer than `initial_collect_steps` frames, then the collect/train loop runs
  for `num_iterations` iterations with logging, checkpointing, run capture
  and evaluation at their configured intervals.

  Hyper-parameters (`learning_rate`, `batch_size`, `num_iterations`, the
  various `*_interval` values, ...) are read from module-level names.

  Returns:
    None.
  """
  tf_env = tf_py_environment.TFPyEnvironment(SnakeEnv())
  eval_env = tf_py_environment.TFPyEnvironment(SnakeEnv(step_limit=50))

  q_net = q_network.QNetwork(
      tf_env.observation_spec(),
      tf_env.action_spec(),
      conv_layer_params=(),
      fc_layer_params=(512, 256, 128),
  )
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
  global_counter = tf.compat.v1.train.get_or_create_global_step()

  agent = dqn_agent.DqnAgent(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      q_network=q_net,
      optimizer=optimizer,
      td_errors_loss_fn=common.element_wise_squared_loss,
      train_step_counter=global_counter,
      gamma=0.95,
      epsilon_greedy=0.1,
      n_step_update=1,
  )

  root_dir = os.path.join('/tf-logs', 'snake')
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  agent.initialize()

  train_metrics = [
      tf_metrics.NumberOfEpisodes(),
      tf_metrics.EnvironmentSteps(),
      tf_metrics.AverageReturnMetric(),
      tf_metrics.AverageEpisodeLengthMetric(),
  ]

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=agent.collect_data_spec,
      batch_size=tf_env.batch_size,
      max_length=replay_buffer_max_length,
  )
  collect_observers = [replay_buffer.add_batch] + train_metrics

  collect_driver = dynamic_step_driver.DynamicStepDriver(
      tf_env,
      agent.collect_policy,
      observers=collect_observers,
      num_steps=collect_steps_per_iteration,
  )

  train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir,
      agent=agent,
      global_step=global_counter,
      metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
  )
  policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=agent.policy,
      global_step=global_counter,
  )
  rb_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
      max_to_keep=1,
      replay_buffer=replay_buffer,
  )

  train_checkpointer.initialize_or_restore()
  rb_checkpointer.initialize_or_restore()

  collect_driver.run = common.function(collect_driver.run)
  agent.train = common.function(agent.train)

  random_policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec())

  # Seed the buffer only when the restored checkpoint did not already hold
  # enough experience.
  if replay_buffer.num_frames() >= initial_collect_steps:
    logging.info("We loaded memories, not doing random seed")
  else:
    logging.info("Capturing %d steps to seed with random memories",
                 initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        random_policy,
        observers=collect_observers,
        num_steps=initial_collect_steps).run()

  train_summary_writer = tf.summary.create_file_writer(train_dir)
  train_summary_writer.set_as_default()
  # One writer for all evaluations. Previously a fresh writer was created on
  # every evaluation.
  eval_summary_writer = tf.summary.create_file_writer(eval_dir)

  avg_returns = []
  avg_return_metric = tf_metrics.AverageReturnMetric(
      buffer_size=num_eval_episodes)
  eval_metrics = [
      avg_return_metric,
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  def evaluate():
    """Evaluates `agent.policy`, logs metrics and records the avg return."""
    metric_utils.eager_compute(
        eval_metrics,
        eval_env,
        agent.policy,
        num_episodes=num_eval_episodes,
        train_step=global_counter,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    # Flush explicitly: the shared writer is no longer discarded after each
    # evaluation, so nothing else forces the events out promptly.
    eval_summary_writer.flush()
    metric_utils.log_metrics(eval_metrics)
    avg_returns.append(
        (global_counter.numpy(), avg_return_metric.result().numpy()))

  logging.info("Running initial evaluation")
  evaluate()

  time_step = None
  policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)

  timed_at_step = global_counter.numpy()
  time_acc = 0

  dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                     sample_batch_size=batch_size,
                                     num_steps=2).prefetch(3)
  iterator = iter(dataset)

  @common.function
  def train_step():
    experience, _ = next(iterator)
    return agent.train(experience)

  for _ in range(num_iterations):
    start_time = time.time()
    time_step, policy_state = collect_driver.run(
        time_step=time_step,
        policy_state=policy_state,
    )
    for _ in range(train_steps_per_iteration):
      train_loss = train_step()
    time_acc += time.time() - start_time

    step = global_counter.numpy()

    if step % log_interval == 0:
      logging.info("step = %d, loss = %f", step, train_loss.loss)
      steps_per_sec = (step - timed_at_step) / time_acc
      logging.info("%.3f steps/sec", steps_per_sec)
      timed_at_step = step
      time_acc = 0

    for train_metric in train_metrics:
      train_metric.tf_summaries(train_step=global_counter,
                                step_metrics=train_metrics[:2])

    if step % train_checkpoint_interval == 0:
      train_checkpointer.save(global_step=step)

    if step % policy_checkpoint_interval == 0:
      policy_checkpointer.save(global_step=step)

    if step % rb_checkpoint_interval == 0:
      rb_checkpointer.save(global_step=step)

    if step % capture_interval == 0:
      print("Capturing run:")
      capture_run(os.path.join(root_dir, "snake" + str(step) + ".mp4"),
                  eval_env, agent.policy)

    if step % eval_interval == 0:
      print("EVALUTION TIME:")
      evaluate()
def train_eval(
        root_dir,
        experiment_name,  # experiment name
        env_name='carla-v0',
        agent_name='sac',  # agent's name
        num_iterations=int(1e7),
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        model_network_ctor_type='non-hierarchical',  # model net
        # NOTE: the list defaults below are only read, never mutated, inside
        # this function, so sharing them across calls is safe.
        input_names=['camera', 'lidar'],  # names for inputs
        mask_names=['birdeye'],  # names for masks
        # takes a flat list of tensors and combines them
        preprocessing_combiner=tf.keras.layers.Add(),
        actor_lstm_size=(40, ),  # lstm size for actor
        critic_lstm_size=(40, ),  # lstm size for critic
        actor_output_fc_layers=(100, ),  # lstm output
        critic_output_fc_layers=(100, ),  # lstm output
        epsilon_greedy=0.1,  # exploration parameter for DQN
        q_learning_rate=1e-3,  # q learning rate for DQN
        ou_stddev=0.2,  # exploration parameter for DDPG
        ou_damping=0.15,  # exploration parameter for DDPG
        dqda_clipping=None,  # for DDPG
        exploration_noise_std=0.1,  # exploration parameter for td3
        actor_update_period=2,  # for td3
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=int(1e5),
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        initial_model_train_steps=100000,  # initial model training
        batch_size=256,
        model_batch_size=32,  # model training batch size
        sequence_length=4,  # number of timesteps to train model
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        model_learning_rate=1e-4,  # learning rate for model training
        td_errors_loss_fn=tf.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for summaries and logging
        num_images_per_summary=1,  # images for each summary
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        gpu_allow_growth=True,  # GPU memory growth
        gpu_memory_limit=None,  # GPU memory limit
        action_repeat=1
):  # Name of single observation channel, ['camera', 'lidar', 'birdeye']
    """Train and periodically evaluate an agent on the CARLA environment.

    The function configures the GPUs, builds the train/eval environments,
    constructs the agent selected by ``agent_name`` ('latent_sac', 'dqn',
    'ddpg', 'td3' or 'sac'), creates a uniform replay buffer, checkpointers
    and collect drivers, seeds the buffer with a random policy when it is
    empty, and then runs ``num_iterations`` iterations of collection and
    training. Logging, evaluation and checkpointing are triggered when the
    global step is a multiple of the corresponding interval.

    Args:
      root_dir: Base directory; outputs go to
        ``root_dir/env_name/experiment_name``.
      experiment_name: Sub-directory name for this run.
      env_name: Used to build the output directory. The environment itself is
        always loaded as 'carla-v0'.
      agent_name: Which agent to build; unknown names raise
        ``NotImplementedError``.
      num_iterations: Number of outer training iterations.
      actor_fc_layers, critic_obs_fc_layers, critic_action_fc_layers,
        critic_joint_fc_layers, actor_lstm_size, critic_lstm_size,
        actor_output_fc_layers, critic_output_fc_layers: Layer sizes forwarded
        to the actor / critic / Q networks.
      model_network_ctor_type: 'hierarchical' or 'non-hierarchical'; selects
        the sequential latent model used by 'latent_sac'.
      input_names: Observation channels fed to the networks.
      mask_names: Extra observation channels requested for 'latent_sac'.
      preprocessing_combiner: Combiner for the per-input preprocessing layers;
        replaced by ``None`` when fewer than two inputs are used.
      epsilon_greedy, q_learning_rate: DQN parameters.
      ou_stddev, ou_damping, dqda_clipping: DDPG parameters (``dqda_clipping``
        is also forwarded to TD3).
      exploration_noise_std, actor_update_period: TD3 parameters.
      initial_collect_steps: Random-policy steps used to seed the buffer.
      collect_steps_per_iteration: Environment steps collected per iteration.
      replay_buffer_capacity: ``max_length`` of the replay buffer.
      target_update_tau, target_update_period: Target network update settings.
      train_steps_per_iteration: Agent train steps per iteration.
      initial_model_train_steps: For 'latent_sac', number of leading
        iterations that only train the model network.
      batch_size: Batch size sampled from the replay buffer.
      model_batch_size, sequence_length, model_learning_rate,
        num_images_per_summary: Forwarded to the latent SAC agent;
        ``sequence_length + 1`` is also the number of steps per sample.
      actor_learning_rate, critic_learning_rate, alpha_learning_rate: Adam
        learning rates.
      td_errors_loss_fn: Loss used by the inner SAC agent of 'latent_sac'.
      gamma, reward_scale_factor, gradient_clipping: Forwarded to the agent.
      num_eval_episodes: Episodes per periodic evaluation.
      eval_interval, train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval, log_interval, summary_interval: Intervals in
        global steps.
      summaries_flush_secs: Summary writer flush period in seconds.
      debug_summaries, summarize_grads_and_vars: Forwarded to the agent.
      gpu_allow_growth: Enable memory growth on every visible GPU.
      gpu_memory_limit: If truthy, per-GPU virtual device memory limit.
      action_repeat: Forwarded to the environment loader; also used to derive
        the frames-per-second value handed to the latent SAC agent.

    Raises:
      NotImplementedError: For an unknown ``agent_name`` or
        ``model_network_ctor_type``.
    """
    # Setup GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    # Get train and eval directories
    root_dir = os.path.expanduser(root_dir)
    root_dir = os.path.join(root_dir, env_name, experiment_name)

    # Get summary writers
    summary_writer = tf.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    # Eval metrics
    eval_metrics = [
        tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Whether to record for summary
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create Carla environment
        if agent_name == 'latent_sac':
            py_env, eval_py_env = load_carla_env(
                env_name='carla-v0',
                obs_channels=input_names + mask_names,
                action_repeat=action_repeat)
        elif agent_name == 'dqn':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 discrete=True,
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)
        else:
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)

        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
        fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

        # Specs
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        ## Make tf agent
        if agent_name == 'latent_sac':
            # Get model network for latent sac
            if model_network_ctor_type == 'hierarchical':
                model_network_ctor = (
                    sequential_latent_network.SequentialLatentModelHierarchical)
            elif model_network_ctor_type == 'non-hierarchical':
                model_network_ctor = (
                    sequential_latent_network
                    .SequentialLatentModelNonHierarchical)
            else:
                raise NotImplementedError(
                    'Unknown model_network_ctor_type: %r' %
                    (model_network_ctor_type, ))
            model_net = model_network_ctor(input_names,
                                           input_names + mask_names)

            # Get the latent spec
            latent_size = model_net.latent_size
            latent_observation_spec = tensor_spec.TensorSpec(
                (latent_size, ), dtype=tf.float32)
            latent_time_step_spec = ts.time_step_spec(
                observation_spec=latent_observation_spec)

            # Get actor and critic net
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                latent_observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=normal_projection_net)
            critic_net = critic_network.CriticNetwork(
                (latent_observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers)

            # Build the inner SAC agent based on latent space
            inner_agent = sac_agent.SacAgent(
                latent_time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            inner_agent.initialize()

            # Build the latent sac agent
            tf_agent = latent_sac_agent.LatentSACAgent(
                time_step_spec,
                action_spec,
                inner_agent=inner_agent,
                model_network=model_net,
                model_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=model_learning_rate),
                model_batch_size=model_batch_size,
                num_images_per_summary=num_images_per_summary,
                sequence_length=sequence_length,
                gradient_clipping=gradient_clipping,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                fps=fps)

        else:
            # Set up preprocessing layers for dictionary observation inputs
            preprocessing_layers = collections.OrderedDict()
            for name in input_names:
                preprocessing_layers[name] = Preprocessing_Layer(32, 256)
            # A combiner is only meaningful with at least two inputs.
            if len(input_names) < 2:
                preprocessing_combiner = None

            if agent_name == 'dqn':
                q_rnn_net = q_rnn_network.QRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_rnn_net,
                    epsilon_greedy=epsilon_greedy,
                    n_step_update=1,
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=q_learning_rate),
                    td_errors_loss_fn=common.element_wise_squared_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            elif agent_name == 'ddpg' or agent_name == 'td3':
                actor_rnn_net = (
                    multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
                        observation_spec,
                        action_spec,
                        preprocessing_layers=preprocessing_layers,
                        preprocessing_combiner=preprocessing_combiner,
                        input_fc_layer_params=actor_fc_layers,
                        lstm_size=actor_lstm_size,
                        output_fc_layer_params=actor_output_fc_layers))

                critic_rnn_net = (
                    multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                        (observation_spec, action_spec),
                        preprocessing_layers=preprocessing_layers,
                        preprocessing_combiner=preprocessing_combiner,
                        action_fc_layer_params=critic_action_fc_layers,
                        joint_fc_layer_params=critic_joint_fc_layers,
                        lstm_size=critic_lstm_size,
                        output_fc_layer_params=critic_output_fc_layers))

                if agent_name == 'ddpg':
                    tf_agent = ddpg_agent.DdpgAgent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        ou_stddev=ou_stddev,
                        ou_damping=ou_damping,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        # Passed explicitly, like every other agent here, so
                        # the interval checks below follow this agent's
                        # training progress.
                        train_step_counter=global_step)
                elif agent_name == 'td3':
                    tf_agent = td3_agent.Td3Agent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        exploration_noise_std=exploration_noise_std,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        actor_update_period=actor_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        train_step_counter=global_step)

            elif agent_name == 'sac':
                actor_distribution_rnn_net = (
                    actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                        observation_spec,
                        action_spec,
                        preprocessing_layers=preprocessing_layers,
                        preprocessing_combiner=preprocessing_combiner,
                        input_fc_layer_params=actor_fc_layers,
                        lstm_size=actor_lstm_size,
                        output_fc_layer_params=actor_output_fc_layers,
                        continuous_projection_net=normal_projection_net))

                critic_rnn_net = (
                    multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                        (observation_spec, action_spec),
                        preprocessing_layers=preprocessing_layers,
                        preprocessing_combiner=preprocessing_combiner,
                        action_fc_layer_params=critic_action_fc_layers,
                        joint_fc_layer_params=critic_joint_fc_layers,
                        lstm_size=critic_lstm_size,
                        output_fc_layer_params=critic_output_fc_layers))

                tf_agent = sac_agent.SacAgent(
                    time_step_spec,
                    action_spec,
                    actor_network=actor_distribution_rnn_net,
                    critic_network=critic_rnn_net,
                    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=actor_learning_rate),
                    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=critic_learning_rate),
                    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=alpha_learning_rate),
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    # make critic loss dimension compatible
                    td_errors_loss_fn=tf.math.squared_difference,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            else:
                raise NotImplementedError('Unknown agent_name: %r' %
                                          (agent_name, ))

        tf_agent.initialize()

        # Get replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,  # No parallel environments
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Train metrics
        env_steps = tf_metrics.EnvironmentSteps()
        average_return = tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        # Get policies
        # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        eval_policy = tf_agent.policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        collect_policy = tf_agent.collect_policy

        # Checkpointers
        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step,
            max_to_keep=2)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Collect driver
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        # Optimize the performance by using tf functions
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if (env_steps.result() == 0 or replay_buffer.num_frames() == 0):
            logging.info(
                'Initializing replay buffer by collecting experience for %d '
                'steps with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        # Initial evaluation on a single episode.
        if agent_name == 'latent_sac':
            compute_summaries(eval_metrics,
                              eval_tf_env,
                              eval_policy,
                              train_step=global_step,
                              summary_writer=summary_writer,
                              num_episodes=1,
                              num_episodes_to_render=1,
                              model_net=model_net,
                              fps=10,
                              image_keys=input_names + mask_names)
        else:
            metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=1,
                train_step=env_steps.result(),
                summary_writer=summary_writer,
                summary_prefix='Eval',
            )
            metric_utils.log_metrics(eval_metrics)

        # Dataset generates trajectories with shape [Bxslx...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        # Get train step
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        if agent_name == 'latent_sac':

            def train_model_step():
                experience, _ = next(iterator)
                return tf_agent.train_model(experience)

            train_model_step = common.function(train_model_step)

        # Training initializations
        time_step = None
        # Carry the collect policy's state between driver calls; otherwise a
        # recurrent / latent policy restarts from its initial state on every
        # call (i.e. every `collect_steps_per_iteration` environment steps).
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Start training
        for iteration in range(num_iterations):
            start_time = time.time()

            if (agent_name == 'latent_sac'
                    and iteration < initial_model_train_steps):
                # Model-only warm-up phase: no collection, no agent update.
                train_model_step()
            else:
                # Run collect
                time_step, policy_state = collect_driver.run(
                    time_step=time_step,
                    policy_state=policy_state,
                )

                # Train an iteration
                for _ in range(train_steps_per_iteration):
                    train_step()

            time_acc += time.time() - start_time

            # Log training information
            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.summary.scalar(name='env_steps_per_sec',
                                  data=env_steps_per_sec,
                                  step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            # Get training metrics
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            # Evaluation
            if global_step.numpy() % eval_interval == 0:
                # Log evaluation metrics
                if agent_name == 'latent_sac':
                    compute_summaries(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        train_step=global_step,
                        summary_writer=summary_writer,
                        num_episodes=num_eval_episodes,
                        num_episodes_to_render=num_images_per_summary,
                        model_net=model_net,
                        fps=10,
                        image_keys=input_names + mask_names)
                else:
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    metric_utils.log_metrics(eval_metrics)

            # Save checkpoints
            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_allowlist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        critic_learning_rate=3e-4,
        train_sequence_length=20,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control.

    Loads the DM control task, builds recurrent actor / critic networks and a
    SAC agent, seeds a uniform replay buffer with a random policy when it is
    empty, then repeatedly collects whole episodes and runs
    ``train_steps_per_iteration`` train steps per collected environment step.
    Logging, evaluation and checkpointing fire when the global step is a
    multiple of the corresponding interval.

    Args:
      root_dir: Directory for summaries and checkpoints.
      env_name: DM control domain used for training.
      task_name: DM control task name.
      observations_allowlist: If not None, the single observation key kept by
        the flatten-observations wrapper.
      eval_env_name: Domain used for evaluation; defaults to ``env_name``.
      num_iterations: Number of collect/train iterations.
      actor_fc_layers, actor_output_fc_layers, actor_lstm_size,
        critic_obs_fc_layers, critic_action_fc_layers,
        critic_joint_fc_layers, critic_output_fc_layers, critic_lstm_size:
        Layer sizes forwarded to the actor and critic RNN networks.
      num_parallel_environments: When not 1, environments are wrapped in a
        ``ParallelPyEnvironment`` of that size.
      initial_collect_episodes: Random-policy episodes used to seed the buffer.
      collect_episodes_per_iteration: Episodes collected per iteration.
      replay_buffer_capacity: ``max_length`` of the replay buffer.
      target_update_tau, target_update_period: Target network update settings.
      train_steps_per_iteration: Train steps per collected environment step.
      batch_size: Number of sequences per training batch.
      train_sequence_length: Sequences of ``train_sequence_length + 1`` steps
        are sampled from the buffer.
      actor_learning_rate, critic_learning_rate, alpha_learning_rate: Adam
        learning rates.
      td_errors_loss_fn, gamma, reward_scale_factor, gradient_clipping,
        debug_summaries, summarize_grads_and_vars: Forwarded to the agent.
      use_tf_functions: Wrap driver runs and train steps in
        ``common.function``.
      num_eval_episodes: Episodes per evaluation.
      eval_interval, train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval, log_interval, summary_interval: Intervals in
        global steps.
      summaries_flush_secs: Summary writer flush period in seconds.
      eval_metrics_callback: Optional callable invoked after each evaluation
        with ``(results, env_steps.result())``.
    """
    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_allowlist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_allowlist=[observations_allowlist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork))

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d episodes '
                'with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Reduce filter_fn over full trajectory sampled. The sequence is
            # kept only if all elements except for the last one pass the
            # filter. This is to allow training on terminal steps.
            return tf.reduce_all(~trajectories.is_boundary()[:-1])

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            # TODO(b/152648849)
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                time_acc += time.time() - start_time

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        # `env_steps` is a metric object; its value is read
                        # with `.result()`, exactly as in the initial
                        # evaluation above.
                        eval_metrics_callback(results, env_steps.result())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
def train(self,
          training_iterations=TRAINING_ITERATIONS,
          training_stock_list=None):
    """Train the wrapped TF agent and record evaluation reports.

    Resets the trainer with ``training_stock_list``, evaluates a random policy
    and the untrained agent as baselines, seeds a uniform replay buffer with
    random experience, then alternates collection and training for up to
    ``training_iterations`` iterations. Every ``EVAL_INTERVAL`` global steps
    the agent is evaluated; training stops early when the evaluated average
    daily percentage equals the most recently recorded returns entry.
    Afterwards checkpoints are written, the training report is updated and
    saved, and the trainer is saved and reset.

    Args:
      training_iterations: Maximum number of collect/train iterations.
      training_stock_list: Forwarded unchanged to ``self.reset``.
    """
    self.reset(training_stock_list)

    train_dir = 'training_data_progress/train-' + self.name

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=self.tf_agent.collect_data_spec,
        batch_size=self.tf_training_env.batch_size,
        max_length=MAX_BUFFER_SIZE)

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=NUM_EVAL_EPISODES),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=NUM_EVAL_EPISODES)
    ]

    global_step = self.tf_agent.train_step_counter
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % LOG_INTERVAL, 0)):
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(
                buffer_size=NUM_EVAL_EPISODES,
                batch_size=self.tf_training_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=NUM_EVAL_EPISODES,
                batch_size=self.tf_training_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(self.tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            self.tf_training_env.time_step_spec(),
            self.tf_training_env.action_spec())
        collect_policy = self.tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=self.tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver_random = dynamic_step_driver.DynamicStepDriver(
            self.tf_training_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=INIT_COLLECT_STEPS)
        initial_collect_driver_random.run = common.function(
            initial_collect_driver_random.run)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            self.tf_training_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=STEP_ITERATIONS)

        collect_driver.run = common.function(collect_driver.run)
        self.tf_agent.train = common.function(self.tf_agent.train)

        # Baseline 1: a random policy. The random collect policy built above
        # is reused instead of constructing an identical second one.
        avg_return, avg_return_per_step, avg_daily_percentage = (
            self.compute_avg_return(initial_collect_policy))
        print(
            'Random:\n\tAverage Return = {0}\n\tAverage Return Per Step = {1}\n\tPercent = {2}%'
            .format(avg_return, avg_return_per_step, avg_daily_percentage))
        self.gym_training_env.save_feature_distribution(self.name)

        # Baseline 2: the agent before any training in this call.
        avg_return, avg_return_per_step, avg_daily_percentage = (
            self.compute_avg_return(self.tf_agent.policy))
        print(
            'Agent :\n\tAverage Return = {0}\n\tAverage Return Per Step = {1}\n\tPercent = {2}%'
            .format(avg_return, avg_return_per_step, avg_daily_percentage))
        self.eval_env.reset()
        self.eval_env.run_and_save_evaluation(str(0))
        self.gym_training_env.save_feature_distribution(self.name)

        # Per-evaluation histories; entry 0 is the pre-training evaluation.
        evaluations = [self.get_evaluation()]
        returns = [self.eval_env.returns]
        actions_over_time_list = [self.eval_env.action_sets_over_time]

        # Collect initial replay data.
        print(
            'Initializing replay buffer by collecting experience for {} steps with '
            'a random policy.'.format(INIT_COLLECT_STEPS))
        initial_collect_driver_random.run()

        metric_utils.eager_compute(
            eval_metrics,
            self.tf_training_env,
            eval_policy,
            num_episodes=NUM_EVAL_EPISODES,
            train_step=global_step,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(
            self.tf_training_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=BATCH_SIZE, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(BATCH_SIZE).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def _train_step():
            # Deliberate best-effort: a failing train step is reported and
            # replaced by a tiny placeholder loss instead of aborting the run.
            try:
                experience, _ = next(iterator)
                return self.tf_agent.train(experience)
            except Exception as e:
                print("Caught Exception:", e)
                return 1e-20

        train_step = common.function(_train_step)

        for _ in range(training_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(STEP_ITERATIONS):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # `_train_step` may return its bare placeholder instead of a loss
            # structure, so read `.loss` defensively.
            loss_value = getattr(train_loss, 'loss', train_loss)

            self.global_step_val = global_step.numpy()

            if self.global_step_val % LOG_INTERVAL == 0:
                steps_per_sec = (self.global_step_val -
                                 timed_at_step) / time_acc
                print(
                    self.name,
                    '\nstep = {0:d}:\n\tloss = {1:f}\n\t{2:.3f} steps/sec'.
                    format(self.global_step_val, loss_value, steps_per_sec))
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = self.global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if self.global_step_val % EVAL_INTERVAL == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    self.tf_training_env,
                    eval_policy,
                    num_episodes=NUM_EVAL_EPISODES,
                    train_step=global_step,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)

                avg_return, avg_return_per_step, avg_daily_percentage = (
                    self.compute_avg_return(self.tf_agent.policy))
                print(
                    self.name,
                    '\nstep = {0}:\n\tloss = {1}\n\tAverage Return = {2}\n\tAverage Return Per Step = {3}\n\tPercent = {4}%'
                    .format(self.global_step_val, loss_value, avg_return,
                            avg_return_per_step, avg_daily_percentage))
                self.eval_env.reset()
                self.eval_env.run_and_save_evaluation(
                    str(self.global_step_val // EVAL_INTERVAL))
                self.gym_training_env.save_feature_distribution(self.name)

                if avg_daily_percentage == returns[-1]:
                    # This message used to be a bare string expression and
                    # was therefore never shown.
                    print("---- Average return did not change since last "
                          "time. Breaking loop.")
                    break

                evaluations.append(self.get_evaluation())
                returns.append(self.eval_env.returns)
                actions_over_time_list.append(
                    self.eval_env.action_sets_over_time)

    train_checkpointer.save(global_step=self.global_step_val)
    policy_checkpointer.save(global_step=self.global_step_val)
    rb_checkpointer.save(global_step=self.global_step_val)

    # Merge this run's results into the persisted training report.
    training_report = util.load_training_report()
    agent_report = training_report.get(self.name, dict())
    agent_report["Training Results"] = returns
    agent_report["Evaluations"] = [max(e, 0.0) for e in evaluations]
    bins = [0.1 * i - 0.0000001 for i in range(11)]
    agent_report["Histograms"] = [
        str(list(map(int,
                     np.histogram(actions, bins, density=True)[0])))
        for actions in actions_over_time_list
    ]
    training_report[self.name] = agent_report
    util.save_training_report(training_report)

    print("---- Average-daily-percentage over training period for",
          self.name)
    print("\t\t", avg_daily_percentage)

    self.save()
    self.reset()
def train(root_dir,
          agent,
          environment,
          training_loops,
          steps_per_loop,
          additional_metrics=(),
          training_data_spec_transformation_fn=None):
    """Run `training_loops` rounds of training, checkpointing after each one.

    Every round runs the training-loop function obtained from
    `get_training_loop_fn`, logs and summarizes the metrics, saves a
    checkpoint and exports the agent's policy under `root_dir`.

    Args:
      root_dir: path to the directory where checkpoints, summaries and saved
        policies are written.
      agent: an instance of `TFAgent`.
      environment: an instance of `TFEnvironment`.
      training_loops: an integer indicating how many training loops should be
        run.
      steps_per_loop: an integer indicating how many driver steps should be
        executed and presented to the trainer during each training loop.
      additional_metrics: Tuple of metric objects to log, in addition to the
        default metrics `NumberOfEpisodes`, `AverageEpisodeLengthMetric` and
        the average-return metric.
      training_data_spec_transformation_fn: Optional function that transforms
        the data items before they get to the replay buffer. When given, it
        is also applied to the policy's trajectory spec to obtain the replay
        buffer's data spec.
    """
    # TODO(b/127641485): create evaluation loop with configurable metrics.
    transform_fn = training_data_spec_transformation_fn
    trajectory_spec = agent.policy.trajectory_spec
    data_spec = (
        trajectory_spec if transform_fn is None
        else transform_fn(trajectory_spec))
    replay_buffer = get_replay_buffer(
        data_spec, environment.batch_size, steps_per_loop)

    # `step_metric` records the number of individual rounds of bandit
    # interaction; that is, (number of trajectories) * batch_size.
    step_metric = tf_metrics.EnvironmentSteps()
    metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=environment.batch_size),
    ]
    metrics.extend(additional_metrics)

    # Dict-structured rewards need the multi-objective return metric.
    reward_spec = environment.reward_spec()
    if isinstance(reward_spec, dict):
        metrics.append(
            tf_metrics.AverageReturnMultiMetric(
                reward_spec=reward_spec,
                batch_size=environment.batch_size))
    else:
        metrics.append(
            tf_metrics.AverageReturnMetric(batch_size=environment.batch_size))

    if transform_fn is None:
        add_batch_fn = replay_buffer.add_batch
    else:

        def add_batch_fn(data):
            # Transform each item before it reaches the replay buffer.
            return replay_buffer.add_batch(transform_fn(data))

    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=agent.collect_policy,
        num_steps=steps_per_loop * environment.batch_size,
        observers=[add_batch_fn, step_metric] + metrics)

    training_loop = get_training_loop_fn(
        driver, replay_buffer, agent, steps_per_loop)
    checkpoint_manager = restore_and_get_checkpoint_manager(
        root_dir, agent, metrics, step_metric)
    saver = policy_saver.PolicySaver(agent.policy)

    summary_writer = tf.summary.create_file_writer(root_dir)
    summary_writer.set_as_default()

    for _ in range(training_loops):
        training_loop()
        metric_utils.log_metrics(metrics)
        for metric in metrics:
            metric.tf_summaries(train_step=step_metric.result())
        checkpoint_manager.save()
        policy_dir = os.path.join(
            root_dir, 'policy_%d' % step_metric.result())
        saver.save(policy_dir)
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=[],
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=[],
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):
    """Train a fresh safety critic offline from a saved agent and replay buffer.

    Loads the agent checkpoint and replay buffer found under
    `load_root_dir/train`, builds a new `SafetyCriticOffline` network and
    trains it for `num_global_steps` steps. Each training batch concatenates
    a sample from the full replay buffer with a sample from a buffer that
    holds only the failure / pre-failure steps. Summaries and checkpoints go
    to `load_root_dir/train/offline/<UTC timestamp>`.

    NOTE: `gym_env_wrappers` and `eval_metrics` use list defaults that are
    shared between calls; this function only reads them, never mutates them.

    Args:
      load_root_dir: directory that contains the `train` checkpoints to load.
      env_load_fn: callable `(env_name, gym_env_wrappers=...)` returning a
        Python environment. Called unconditionally, so it must be provided.
      gym_env_wrappers: wrappers forwarded to `env_load_fn` and applied to
        the monitor environment.
      monitor: if True, roll out the evaluation policy in a monitored gym
        environment every `monitor_interval` steps.
      env_name: name passed to `env_load_fn` and `gym.make`.
      agent_class: agent constructor, or a string key into `ALGOS`.
      train_metrics_callback: optional callable `(results, step)` invoked
        after each evaluation when `run_eval` is True.
      actor_fc_layers: hidden layer sizes of the actor network.
      critic_joint_fc_layers: joint layer sizes of the critic (and of the
        agent's own safety critic for agents in `SAFETY_AGENTS`).
      safety_critic_joint_fc_layers: joint layer sizes of the offline safety
        critic.
      safety_critic_lr: Adam learning rate for the offline safety critic.
      safety_critic_bias_init_val: optional constant for the value-bias
        initializer.
      safety_critic_kernel_scale: optional scale for a truncated-normal
        variance-scaling kernel initializer; otherwise a uniform
        variance-scaling initializer with scale 1/3 is used.
      n_envs: number of parallel evaluation environments (defaults to
        `num_eval_episodes`).
      target_safety: safety threshold for the evaluation policy and the
        first metric threshold.
      fail_weight: optional weight in (0, 1) for failure samples; safe
        samples get `1 - fail_weight`. Both are divided by 0.5.
      num_global_steps: number of offline training steps.
      batch_size: total batch size; half comes from each buffer.
      run_eval: whether to run policy evaluation every `eval_interval` steps.
      eval_metrics: extra Python metrics wrapped as `TFPyMetric` for eval.
      num_eval_episodes: number of episodes per evaluation.
      eval_interval: step interval between evaluations.
      train_checkpoint_interval: step interval between checkpoints.
      summary_interval: step interval at which summaries are recorded.
      monitor_interval: step interval between monitored rollouts.
      summaries_flush_secs: flush period, in seconds, of summary writers.
      debug_summaries: forwarded to `train_step`.
      seed: optional random seed; a falsy value (including 0) disables
        seeding.
    """
    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            # Seeding is best-effort (not every environment supports it), but
            # the failure is reported instead of being silently discarded, and
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception as err:  # pylint: disable=broad-except
                logging.warning(
                    'Could not seed evaluation environments: %s', err)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    # Offline training counts its own steps from zero.
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed:
        tf.compat.v1.set_random_seed(seed)

    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    # The caller's `summaries_flush_secs` is honoured here; it used to be
    # overwritten with a hard-coded 10 just before this point.
    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None

    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)

    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        num_steps=2,
        sample_batch_size=batch_size // 2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    # Select the steps used by the safety-critic evaluation function: the
    # flagged (failure) steps, the steps right before them, the first steps
    # of episodes and the steps right after those.
    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    # Copy failure and pre-failure steps into a dedicated buffer so they can
    # be sampled separately; the mask is padded to the source buffer length.
    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    sc_dataset_neg = failure_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size // 2,
        num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            # Drop transitions whose first step is an episode boundary.
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5,
                                   (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            mean_loss(train_loss)

            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)

            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(
                                name=metric.name,
                                data=metric.result(),
                                step=global_step_val)
                        else:
                            # Thresholded metrics yield one value per entry
                            # of `thresholds`; log each one separately.
                            fmt_str = '_{}'.format(thresholds[0])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[0],
                                step=global_step_val)
                            fmt_str = '_{}'.format(thresholds[1])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[1],
                                step=global_step_val)
                        metric.reset_states()

            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])

            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action = monitor_action.action
                    monitor_policy_state = monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len, time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC.

    Builds a SAC agent for the MuJoCo environment `env_name`, fills the
    replay buffer with `initial_collect_steps` steps of a random policy,
    then alternates collection and training for `num_iterations`
    iterations. Train summaries and checkpoints go to `root_dir/train`,
    eval summaries to `root_dir/eval`.

    Args:
      root_dir: output directory (`~` is expanded).
      env_name: environment name passed to `suite_mujoco.load`.
      num_iterations: number of collect/train iterations.
      actor_fc_layers: hidden layer sizes of the actor network.
      critic_obs_fc_layers, critic_action_fc_layers, critic_joint_fc_layers:
        layer sizes of the critic network.
      initial_collect_steps: steps collected with the random policy.
      collect_steps_per_iteration: environment steps per iteration.
      replay_buffer_capacity: maximum length of the replay buffer.
      target_update_tau, target_update_period: target network update
        settings forwarded to the agent.
      train_steps_per_iteration: agent train calls per iteration.
      batch_size: sample batch size drawn from the replay buffer.
      actor_learning_rate, critic_learning_rate, alpha_learning_rate: Adam
        learning rates.
      td_errors_loss_fn, gamma, reward_scale_factor, gradient_clipping:
        forwarded to the agent.
      use_tf_functions: wrap the drivers and `agent.train` in
        `common.function`.
      num_eval_episodes: number of episodes per evaluation.
      eval_interval: evaluate when the global step is a multiple of this.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: checkpoint intervals in global steps.
      log_interval: logging interval in global steps.
      summary_interval: summaries are recorded when the global step is a
        multiple of this.
      summaries_flush_secs: flush period, in seconds, of summary writers.
      debug_summaries, summarize_grads_and_vars: forwarded to the agent.
      eval_metrics_callback: optional callable `(results, step)` invoked
        after each evaluation.

    Returns:
      The loss info returned by the last `agent.train` call, or None when no
      training step was executed (e.g. `num_iterations` is 0).
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(
            suite_mujoco.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_mujoco.load(env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # The first two metrics double as `step_metrics` for the summaries
        # written below, so their order matters.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = iter(dataset)

        # Defined up front so the final `return` is valid even when the loop
        # body never runs.
        train_loss = None
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience)
            time_acc += time.time() - start_time

            # The step counter only advances inside `tf_agent.train`, so one
            # read serves every check for the rest of this iteration.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
def train_eval(root_dir, tf_env, eval_tf_env, agent, num_iterations,
               initial_collect_steps, collect_steps_per_iteration,
               replay_buffer_capacity, train_steps_per_iteration, batch_size,
               use_tf_functions, num_eval_episodes, eval_interval,
               train_checkpoint_interval, policy_checkpoint_interval,
               rb_checkpoint_interval, log_interval, summary_interval,
               summaries_flush_secs):
    """A simple train and eval for DQN.

    Fills a replay buffer with `initial_collect_steps` steps of a random
    policy, then alternates collection and training for up to
    `num_iterations` iterations. Training stops early as soon as the loss
    contains a NaN or infinite value. The policy is checkpointed only when
    an evaluation yields a new best "AverageReturn".

    Args:
      root_dir: output directory (`~` is expanded); summaries and
        checkpoints go to its `train` and `eval` sub-directories.
      tf_env: environment used for data collection.
      eval_tf_env: environment used for evaluation.
      agent: the agent to train; must expose `policy`, `collect_policy`,
        `collect_data_spec` and `train`.
      num_iterations: maximum number of collect/train iterations.
      initial_collect_steps: steps collected with the random policy.
      collect_steps_per_iteration: environment steps per iteration.
      replay_buffer_capacity: maximum length of the replay buffer.
      train_steps_per_iteration: train steps per iteration.
      batch_size: sample batch size drawn from the replay buffer.
      use_tf_functions: wrap the collect driver, `agent.train` and the
        train step in `common.function`.
      num_eval_episodes: number of episodes per evaluation.
      eval_interval: evaluate when the global step is a multiple of this.
      train_checkpoint_interval: train checkpoint interval in global steps.
      policy_checkpoint_interval: accepted for interface compatibility; not
        used, since the policy is saved on every new best evaluation.
      rb_checkpoint_interval: replay-buffer checkpoint interval.
      log_interval: print interval in global steps.
      summary_interval: summaries are recorded when the global step is a
        multiple of this.
      summaries_flush_secs: flush period, in seconds, of summary writers.

    Returns:
      The loss info returned by the last executed train step.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_agent = agent

        # The first two metrics double as `step_metrics` for the summaries
        # written below, so their order matters.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=1),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            max_to_keep=1,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            max_to_keep=1,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Start from -inf rather than an arbitrary constant so the first
        # in-training evaluation always counts as the best so far, whatever
        # the environment's return scale is.
        best_policy = float('-inf')

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # Abort as soon as training diverges (NaN or +/-inf loss).
            if not np.isfinite(train_loss.loss).all():
                break

            # The step counter only advances inside the train step, so one
            # read serves every check for the rest of this iteration.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                print('step = {0}, loss = {1}'.format(global_step_val,
                                                      train_loss.loss))
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                print('{0} steps/sec'.format(steps_per_sec))
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                average_return = results["AverageReturn"].numpy()
                if average_return > best_policy:
                    print("New best policy found")
                    print(average_return)
                    best_policy = average_return
                    policy_checkpointer.save(global_step=global_step_val)
                metric_utils.log_metrics(eval_metrics)
        return train_loss
def train_eval(
        root_dir,
        env_name='sawyer_reach',
        num_iterations=3000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        gamma=0.99,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=200000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        random_seed=0,
        max_future_steps=50,
        actor_std=None,
        log_subset=None,
):
    """Train and periodically evaluate a `CLearningAgent`.

    Seeds NumPy and TensorFlow with `random_seed`, loads the environments
    for `env_name` via `c_learning_envs.load`, and trains until the global
    step reaches `num_iterations`. Train summaries and checkpoints go to
    `root_dir/train`, eval summaries to `root_dir/eval`, and the operative
    gin config is written to `root_dir/operative.gin`.

    Notable arguments:
      max_future_steps: `num_steps` of each sequence sampled from the replay
        buffer before it is handed to `c_learning_utils.goal_fn`.
      actor_std: when given, the actor's projection network uses this
        constant as its standard deviation.
      log_subset: optional `(start_index, end_index)` pair; when given, an
        extra set of "Subset*" distance metrics restricted to that index
        range is added to both the train and the eval metrics.

    Returns:
      The loss info returned by the last train step.
    """
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    flush_millis = summaries_flush_secs * 1000
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=flush_millis)
    train_summary_writer.set_as_default()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env, eval_tf_env, obs_dim = c_learning_envs.load(env_name)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        proj_net = tanh_normal_projection_network.TanhNormalProjectionNetwork
        if actor_std is not None:
            # Fix the policy's standard deviation to a constant.
            proj_net = functools.partial(
                tanh_normal_projection_network.TanhNormalProjectionNetwork,
                std_transform=lambda t: actor_std * tf.ones_like(t))

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=proj_net)
        critic_net = c_learning_utils.ClassifierCriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        actor_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate)
        critic_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate)
        tf_agent = c_learning_agent.CLearningAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=actor_optimizer,
            critic_optimizer=critic_optimizer,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=bce_loss,
            gamma=gamma,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=flush_millis)

        # Keyword arguments shared by the distance metrics.
        eval_kwargs = dict(buffer_size=num_eval_episodes, obs_dim=obs_dim)
        eval_metrics = [
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
            c_learning_utils.FinalDistance(**eval_kwargs),
            c_learning_utils.MinimumDistance(**eval_kwargs),
            c_learning_utils.DeltaDistance(**eval_kwargs),
        ]

        train_kwargs = dict(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            obs_dim=obs_dim)
        # The first two metrics double as `step_metrics` for the summaries
        # written in the training loop, so their order matters.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
            c_learning_utils.InitialDistance(**train_kwargs),
            c_learning_utils.FinalDistance(**train_kwargs),
            c_learning_utils.MinimumDistance(**train_kwargs),
            c_learning_utils.DeltaDistance(**train_kwargs),
        ]

        if log_subset is not None:
            start_index, end_index = log_subset
            subset_metric_specs = (
                (c_learning_utils.InitialDistance, 'SubsetInitialDistance'),
                (c_learning_utils.FinalDistance, 'SubsetFinalDistance'),
                (c_learning_utils.MinimumDistance, 'SubsetMinimumDistance'),
                (c_learning_utils.DeltaDistance, 'SubsetDeltaDistance'),
            )
            for name, metrics in [('train', train_metrics),
                                  ('eval', eval_metrics)]:
                subset_batch_size = (
                    tf_env.batch_size if name == 'train' else 10)
                metrics.extend([
                    metric_cls(
                        buffer_size=num_eval_episodes,
                        batch_size=subset_batch_size,
                        obs_dim=obs_dim,
                        start_index=start_index,
                        end_index=end_index,
                        name=metric_name)
                    for metric_cls, metric_name in subset_metric_specs
                ])

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=None)
        train_checkpointer.initialize_or_restore()

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        collect_observers = [replay_buffer.add_batch] + train_metrics

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=collect_observers,
            num_steps=initial_collect_steps)
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=collect_observers,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Save the hyperparameters
        operative_filename = os.path.join(root_dir, 'operative.gin')
        with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
            f.write(gin.operative_config_str())
            logging.info(gin.operative_config_str())

        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        def _run_eval():
            # Evaluate the greedy policy and log the resulting metrics.
            metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            metric_utils.log_metrics(eval_metrics)

        _run_eval()

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        def _filter_invalid_transition(trajectories, unused_arg1):
            # Keep only sequences whose first step is not a boundary step.
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=max_future_steps)
        # Filter per sequence, then rebuild fixed-size batches.
        dataset = dataset.unbatch().filter(_filter_invalid_transition)
        dataset = dataset.batch(batch_size, drop_remainder=True)
        goal_fn = functools.partial(
            c_learning_utils.goal_fn,
            batch_size=batch_size,
            obs_dim=obs_dim,
            gamma=gamma)
        dataset = dataset.map(goal_fn)
        dataset = dataset.prefetch(5)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        global_step_val = global_step.numpy()
        while global_step_val < num_iterations:
            iteration_start = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - iteration_start

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                _run_eval()

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

        return train_loss
def train_eval(
    root_dir,
    environment_name="broken_reacher",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    real_initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    real_collect_interval=10,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    classifier_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    train_on_real=False,
    delta_r_warmup=0,
    random_seed=0,
    checkpoint_dir=None,
):
    """Train and periodically evaluate a `DarcAgent` on a sim/real env pair.

    Two environments are built from `environment_name` ("sim" and "real").
    Experience from each is stored in its own replay buffer and every train
    step consumes one batch from both buffers.

    Args:
      root_dir: Output directory; summaries and checkpoints are written to
        its `train` and `eval` sub-directories.
      environment_name: One of "broken_reacher", "half_cheetah_obstacle",
        "inverted_pendulum", or a name starting with "broken_joint" /
        "falling" (the text after "broken_joint_" / "falling_" is forwarded
        as `env_name`).
      num_iterations: Number of collect/train iterations to run.
      actor_fc_layers, critic_obs_fc_layers, critic_action_fc_layers,
        critic_joint_fc_layers: Layer sizes forwarded to the actor and
        critic networks.
      initial_collect_steps, real_initial_collect_steps: Random-policy steps
        collected in the sim / real environment when the sim replay buffer
        is empty.
      collect_steps_per_iteration: Steps collected per driver run.
      real_collect_interval: The real environment is stepped only when the
        global step is a multiple of this value.
      replay_buffer_capacity: `max_length` of both replay buffers.
      target_update_tau, target_update_period, td_errors_loss_fn, gamma,
        reward_scale_factor, gradient_clipping, debug_summaries,
        summarize_grads_and_vars: Forwarded unchanged to `DarcAgent`.
      train_steps_per_iteration: Train steps executed per iteration.
      batch_size: Batch size of the sampled training datasets.
      actor_learning_rate, critic_learning_rate, classifier_learning_rate,
        alpha_learning_rate: Learning rates of the four Adam optimizers.
      use_tf_functions: Wrap the drivers, the agent's `train` and the train
        step with `common.function`.
      num_eval_episodes: Episodes per evaluation (also the metric buffer
        size).
      eval_interval, train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval, log_interval, summary_interval: Frequencies,
        measured in global steps.
      summaries_flush_secs: Flush period of the summary writers.
      train_on_real: Use the "real" environment as the training environment
        instead of "sim".
      delta_r_warmup: Real-environment collection starts once the global
        step reaches this value.
      random_seed: Seed for NumPy and TensorFlow.
      checkpoint_dir: Optional directory whose latest checkpoint is restored
        into the train checkpointer instead of the one in `root_dir/train`.

    Returns:
      The value returned by the last train step, or `None` when
      `num_iterations` is 0.

    Raises:
      NotImplementedError: If `environment_name` is not recognised.
    """
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    if environment_name == "broken_reacher":
        get_env_fn = darc_envs.get_broken_reacher_env
    elif environment_name == "half_cheetah_obstacle":
        get_env_fn = darc_envs.get_half_cheetah_direction_env
    elif environment_name == "inverted_pendulum":
        get_env_fn = darc_envs.get_inverted_pendulum_env
    elif environment_name.startswith("broken_joint"):
        base_name = environment_name.split("broken_joint_")[1]
        get_env_fn = functools.partial(darc_envs.get_broken_joint_env,
                                       env_name=base_name)
    elif environment_name.startswith("falling"):
        base_name = environment_name.split("falling_")[1]
        get_env_fn = functools.partial(darc_envs.get_falling_env,
                                       env_name=base_name)
    else:
        raise NotImplementedError("Unknown environment: %s" % environment_name)

    # One evaluation environment and one metric list per domain.
    eval_name_list = ["sim", "real"]
    eval_env_list = [get_env_fn(mode) for mode in eval_name_list]
    eval_metrics_list = []
    for name in eval_name_list:
        eval_metrics_list.append([
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           name="AverageReturn_%s" % name),
        ])

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env_real = get_env_fn("real")
        if train_on_real:
            tf_env = get_env_fn("real")
        else:
            tf_env = get_env_fn("sim")

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork),
        )
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer="glorot_uniform",
            last_kernel_initializer="glorot_uniform",
        )
        classifier = classifiers.build_classifier(observation_spec,
                                                  action_spec)

        tf_agent = darc_agent.DarcAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            classifier=classifier,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=classifier_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        replay_observer = [replay_buffer.add_batch]

        real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        real_replay_observer = [real_replay_buffer.add_batch]

        sim_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnSim",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthSim",
            ),
        ]
        real_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnReal",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthReal",
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(
                sim_train_metrics + real_train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "replay_buffer"),
            max_to_keep=1,
            replay_buffer=(replay_buffer, real_replay_buffer),
        )

        if checkpoint_dir is not None:
            # Restore from an explicitly supplied directory rather than from
            # the checkpointer's own ckpt_dir.
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            assert checkpoint_path is not None
            train_checkpointer._load_status = train_checkpointer._checkpoint.restore(  # pylint: disable=protected-access
                checkpoint_path)
            train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
        else:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # The random-policy drivers are only needed for an empty buffer. They
        # default to None so the tf.function wrapping below can skip them
        # when training resumes from a populated replay-buffer checkpoint.
        initial_collect_driver = None
        real_initial_collect_driver = None
        if replay_buffer.num_frames() == 0:
            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + sim_train_metrics,
                num_steps=initial_collect_steps,
            )
            real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env_real,
                initial_collect_policy,
                observers=real_replay_observer + real_train_metrics,
                num_steps=real_initial_collect_steps,
            )

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + sim_train_metrics,
            num_steps=collect_steps_per_iteration,
        )
        real_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env_real,
            collect_policy,
            observers=real_replay_observer + real_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        # Save the operative gin configuration next to the results.
        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        if use_tf_functions:
            if initial_collect_driver is not None:
                initial_collect_driver.run = common.function(
                    initial_collect_driver.run)
            if real_initial_collect_driver is not None:
                real_initial_collect_driver.run = common.function(
                    real_initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            real_collect_driver.run = common.function(real_collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if replay_buffer.num_frames() == 0:
            logging.info(
                "Initializing replay buffer by collecting experience for %d steps with "
                "a random policy.",
                initial_collect_steps,
            )
            initial_collect_driver.run()
            real_initial_collect_driver.run()

        for eval_name, eval_env, eval_metrics in zip(eval_name_list,
                                                     eval_env_list,
                                                     eval_metrics_list):
            metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix="Metrics-%s" % eval_name,
            )
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        real_time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = (replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
        real_dataset = (real_replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)
        real_iterator = iter(real_dataset)

        def train_step():
            experience, _ = next(iterator)
            real_experience, _ = next(real_iterator)
            return tf_agent.train(experience, real_experience=real_experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Defined up front so the final `return` is valid even when no
        # iteration runs.
        train_loss = None
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            assert not policy_state  # We expect policy_state == ().
            if (global_step.numpy() % real_collect_interval == 0
                    and global_step.numpy() >= delta_r_warmup):
                real_time_step, policy_state = real_collect_driver.run(
                    time_step=real_time_step,
                    policy_state=policy_state,
                )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info("step = %d, loss = %f", global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info("%.3f steps/sec", steps_per_sec)
                tf.compat.v2.summary.scalar(name="global_steps_per_sec",
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in sim_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=sim_train_metrics[:2])
            for train_metric in real_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=real_train_metrics[:2])

            if global_step_val % eval_interval == 0:
                for eval_name, eval_env, eval_metrics in zip(
                        eval_name_list, eval_env_list, eval_metrics_list):
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix="Metrics-%s" % eval_name,
                    )
                    metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
def train_eval(
        root_dir,
        env_name='Blob2d-v1',
        num_iterations=100000,
        train_sequence_length=1,
        collect_steps_per_iteration=1,
        initial_collect_steps=1500,
        replay_buffer_max_length=10000,
        batch_size=64,
        learning_rate=1e-3,
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for QNetwork
        fc_layer_params=(100, ),
        use_tf_functions=False,
        ## train params
        train_steps_per_iteration=1,
        train_checkpoint_interval=1000,
        policy_checkpoint_interval=1000,
        rb_checkpoint_interval=1000,
        n_step_update=1,
        ## Params for Summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and periodically evaluate a DQN agent on a gym environment.

    Args:
      root_dir: Output directory; summaries and checkpoints are written to
        its `train` and `eval` sub-directories.
      env_name: Gym environment id loaded through `suite_gym.load` for both
        the training and the evaluation environment.
      num_iterations: Number of collect/train iterations to run.
      train_sequence_length: Sampled trajectories have
        `train_sequence_length + 1` steps.
      collect_steps_per_iteration: Environment steps collected per iteration.
      initial_collect_steps: Random-policy steps collected before training.
      replay_buffer_max_length: `max_length` of the replay buffer.
      batch_size: Batch size sampled from the replay buffer.
      learning_rate: Learning rate of the Adam optimizer.
      num_eval_episodes: Episodes per evaluation (also the eval metric buffer
        size).
      eval_interval: Evaluate whenever the global step is a multiple of this.
      fc_layer_params: Fully connected layer sizes of the Q-network.
      use_tf_functions: Wrap the train step with `common.function`.
      train_steps_per_iteration: Train steps executed per iteration.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Checkpoint frequencies in global steps.
      n_step_update: Only validated against `train_sequence_length`; it is
        not forwarded to the agent in this function.
      log_interval: Logging frequency in global steps.
      summary_interval: Summaries are recorded when the global step is a
        multiple of this value.
      summaries_flush_secs: Flush period of the summary writers.
      debug_summaries, summarize_grads_and_vars: Accepted for interface
        compatibility; not used in this function.
      eval_metrics_callback: Optional callable invoked as
        `callback(results, global_step_value)` after each evaluation.

    Returns:
      The value returned by the last train step, or `None` when
      `num_iterations` is 0.

    Raises:
      NotImplementedError: If both `train_sequence_length` and
        `n_step_update` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Both environments honour `env_name`; the previous hard-coded reload
        # of 'Blob2d-v1' (and the hard-coded layer sizes) silently discarded
        # the caller's arguments.
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        q_net = q_network.QNetwork(tf_env.observation_spec(),
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=global_step)
        agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = agent.policy
        collect_policy = agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_max_length)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Defined up front so the final `return` is valid even when no
        # iteration runs.
        train_loss = None

        # Main Training loop.
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
        return train_loss
def train(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        env_name=None,
        num_parallel_environments=1,  # pylint: disable=unused-argument
        agent_class=None,
        initial_collect_random=True,  # pylint: disable=unused-argument
        initial_collect_driver_class=None,
        collect_driver_class=None,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=300,
        online_critic=False,
        # Params for eval
        run_eval=False,
        num_eval_episodes=30,
        eval_interval=1000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval for SC-SAC.

    Args:
      root_dir: Output directory; summaries and checkpoints are written to
        its `train` (and, with `run_eval`, `eval`) sub-directories.
      load_root_dir: Optional directory whose `train` sub-directory is passed
        to `misc.load_pi_ckpt` to load the agent. When given, the train
        checkpointer is not restored.
      env_load_fn: Callable building an environment from `env_name`.
      env_name: Environment name passed to `env_load_fn`.
      agent_class: Callable constructing the agent from the time-step spec,
        the action spec and `train_step_counter` (plus `resample_metric` when
        `online_critic` is set).
      initial_collect_driver_class, collect_driver_class: Callables building
        the initial and the regular collect driver.
      num_global_steps: Training continues while the global step is less
        than or equal to this value.
      train_steps_per_iteration: Agent train steps per loop iteration.
      train_metrics: Optional extra Python metrics; each is wrapped in
        `TFPyMetric` and appended to the train (and eval) metrics.
      train_sc_steps: Safety-critic train steps run every
        `train_sc_interval` global steps when `online_critic` is set.
      train_sc_interval: See `train_sc_steps`.
      online_critic: Enables the online replay buffer, safety-critic
        training, and use of the agent's `_safe_policy` for collection and
        evaluation.
      run_eval: Whether to evaluate before training and every
        `eval_interval` global steps.
      num_eval_episodes: Episodes per evaluation (also metric buffer size).
      eval_interval: See `run_eval`.
      eval_metrics_callback: Optional callable invoked as
        `callback(results, global_step_value)` after each evaluation.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Checkpoint frequencies in global steps.
      keep_rb_checkpoint: When false, replay-buffer checkpoints are cleaned
        up once training ends.
      log_interval: Logging frequency in global steps.
      summary_interval: Summaries are recorded when the global step is a
        multiple of this value.
      summaries_flush_secs: Flush period of the summary writers.
      early_termination_fn: Optional zero-argument callable; training stops
        as soon as it returns a truthy value.

    Returns:
      The mean train loss of the last completed iteration, or `None` if the
      training loop body never ran.

    Raises:
      ValueError: If the loss stayed NaN, infinite, or above `MAX_LOSS` for
        more than `TERMINATE_AFTER_DIVERGED_LOSS_STEPS` consecutive
        iterations.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    train_metrics = train_metrics or []

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = env_load_fn(env_name)
        if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
            tf_env = tf_py_environment.TFPyEnvironment(tf_env)

        if run_eval:
            eval_py_env = env_load_fn(env_name)
            eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        print('obs spec:', observation_spec)
        print('action spec:', action_spec)

        if online_critic:
            # The class is spelled `TFPyMetric` (as used elsewhere in this
            # function); the former `TfPyMetric` spelling raised
            # AttributeError.
            resample_metric = tf_py_metric.TFPyMetric(
                py_metrics.CounterMetric('unsafe_ac_samples'))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step,
                                   resample_metric=resample_metric)
        else:
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.info('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, max_length=1000000)
        logging.info('RB capacity: %i', replay_buffer.capacity)
        logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                collect_data_spec, max_length=10000)

            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = common.function(online_replay_buffer.clear)
            agent_observers.append(online_replay_buffer.add_batch)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        if not online_critic:
            eval_policy = tf_agent.policy
        else:
            eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        if not online_critic:
            collect_policy = tf_agent.collect_policy
        else:
            collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        safety_critic_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'safety_critic'),
            safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
            global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        safety_critic_checkpointer.initialize_or_restore()

        collect_driver = collect_driver_class(tf_env,
                                              collect_policy,
                                              observers=agent_observers +
                                              train_metrics)

        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            common.function(
                initial_collect_driver_class(tf_env,
                                             initial_collect_policy,
                                             observers=agent_observers +
                                             train_metrics).run)()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            tf.print(
                replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
                output_stream=logging.info)

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)
            if FLAGS.viz_pm:
                eval_fig_dir = osp.join(eval_dir, 'figs')
                if not tf.io.gfile.isdir(eval_fig_dir):
                    tf.io.gfile.makedirs(eval_fig_dir)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3, num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                experience, buf_info = next(online_iterator)
                if env_name in [
                        'IndianWell', 'IndianWell2', 'IndianWell3',
                        'DrunkSpider', 'DrunkSpiderShort'
                ]:
                    safe_rew = experience.observation['task_agn_rew']
                else:
                    safe_rew = agents.process_replay_buffer(
                        online_replay_buffer, as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience, safe_rew)
                clear_rb()
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
        if online_critic:
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_samples')
            resample_metric.reset()

        # Defined up front so the final `return` is valid even when the loop
        # body never runs (e.g. a restored global step already past the
        # limit).
        total_loss = None

        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            if online_critic:
                mean_resample_ac(resample_metric.result())
                resample_metric.reset()
                if time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()
                    tf.compat.v2.summary.scalar(name='unsafe_ac_samples',
                                                data=resample_ac_freq,
                                                step=global_step)

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)

            if online_critic:
                if global_step.numpy() % train_sc_interval == 0:
                    for _ in range(train_sc_steps):
                        sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    break
            else:
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                safety_critic_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)

            if run_eval and global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
                if FLAGS.viz_pm:
                    savepath = 'step{}.png'.format(global_step_val)
                    savepath = osp.join(eval_fig_dir, savepath)
                    misc.record_episode_vis_summary(eval_tf_env, eval_policy,
                                                    savepath)

        if not keep_rb_checkpoint:
            misc.cleanup_checkpoints(rb_ckpt_dir)

        if loss_diverged:
            # Raise an error at the very end after the cleanup.
            raise ValueError(
                'Loss diverged to {} at step {}, terminating.'.format(
                    total_loss, global_step.numpy()))

        return total_loss
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=5e5,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(64, 64),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(6, ),
        output_fc_layer_params=(30, ),
        # Params for collect
        initial_collect_steps=2000,
        collect_steps_per_iteration=6,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=6,
        batch_size=32,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=1,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train a DQN agent with a GAT Q-network and evaluate it periodically.

    Reads 'clusters.pickle' and 'graph.gpickle' from the current working
    directory and forwards both objects to the gym environment as
    `gym_kwargs`. After training, the best controllers/reward recorded by
    the environment are printed, followed by the result of stepping the
    environment through its graph-centroid heuristic actions.

    Args:
      root_dir: Base output directory ('~' is expanded). Summaries and
        checkpoints are written below its 'train' and 'eval' sub-directories.
      env_name: Environment name passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations. Converted with
        `int()` so float values such as the default `5e5` are accepted.
      train_sequence_length: Values > 1 trigger construction of the recurrent
        network; otherwise it is reset to `n_step_update`. The replay dataset
        samples `train_sequence_length + 1` steps.
      fc_layer_params: Passed to `create_feedforward_network`.
      input_fc_layer_params: Passed to `create_recurrent_network`.
      lstm_size: Passed to `create_recurrent_network`.
      output_fc_layer_params: Passed to `create_recurrent_network`.
      initial_collect_steps: Steps collected with a random policy before
        training starts.
      collect_steps_per_iteration: Steps the collect driver runs per iteration.
      epsilon_greedy: Forwarded to `DqnAgent`.
      replay_buffer_capacity: `max_length` of the uniform replay buffer.
      target_update_tau: Forwarded to `DqnAgent`.
      target_update_period: Forwarded to `DqnAgent`.
      train_steps_per_iteration: Number of `train_step()` calls per iteration.
      batch_size: Sample batch size of the replay dataset.
      learning_rate: Learning rate of the Adam optimizer.
      n_step_update: Forwarded to `DqnAgent`; must be 1 when
        `train_sequence_length` is not 1.
      gamma: Forwarded to `DqnAgent`.
      reward_scale_factor: Forwarded to `DqnAgent`.
      gradient_clipping: Forwarded to `DqnAgent`.
      use_tf_functions: If true, wrap the collect driver, the agent's train
        method and the local train step in `common.function`.
      num_eval_episodes: Episodes per evaluation run (also the eval metric
        buffer size).
      eval_interval: Evaluate whenever the global step is a multiple of this.
      train_checkpoint_interval: Save the train checkpoint whenever the
        global step is a multiple of this.
      policy_checkpoint_interval: Same, for the policy checkpoint.
      rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
      log_interval: Log loss and steps/sec whenever the global step is a
        multiple of this.
      summary_interval: Summaries are recorded only when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of both summary writers, in seconds.
      debug_summaries: Forwarded to `DqnAgent`.
      summarize_grads_and_vars: Forwarded to `DqnAgent`.
      eval_metrics_callback: Optional callable invoked as
        `callback(results, global_step_value)` after every evaluation.

    Returns:
      The value returned by the last `train_step()` call, or None when
      `num_iterations` is 0.

    Raises:
      NotImplementedError: If both `train_sequence_length` and
        `n_step_update` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # `range()` rejects floats, and the default (5e5) is one.
    num_iterations = int(num_iterations)

    # NOTE: unpickling executes arbitrary code; only load trusted files.
    with open('clusters.pickle', 'rb') as clusters_file:
        clusters = pickle.load(clusters_file)
    graph = nx.read_gpickle('graph.gpickle')
    print(graph.nodes)

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        gym_kwargs = {'graph': graph, 'clusters': clusters}
        tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name, gym_kwargs=gym_kwargs))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name, gym_kwargs=gym_kwargs))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        action_spec = tf_env.action_spec()
        num_actions = action_spec.maximum - action_spec.minimum + 1

        if train_sequence_length > 1:
            q_net = create_recurrent_network(input_fc_layer_params, lstm_size,
                                             output_fc_layer_params,
                                             num_actions)
        else:
            q_net = create_feedforward_network(fc_layer_params, num_actions)
            train_sequence_length = n_step_update

        # NOTE(review): the network built above is discarded; the agent always
        # trains the GAT network. Only the `train_sequence_length` adjustment
        # survives from the branch -- confirm the builder calls can be dropped.
        q_net = GATNetwork(tf_env.observation_spec(), tf_env.action_spec(),
                           graph)

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
            tf_metrics.MaxReturnMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Baseline evaluation before any training step.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Defined up front so the final `return` cannot hit an unbound name
        # when the loop body never runs.
        train_loss = None

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # The step counter only advances inside train_step(), so one read
            # serves every interval check below.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

        gym_env = tf_env.envs[0]._gym_env
        print(gym_env.best_controllers)
        print(gym_env.best_reward)

        # Replay the graph-centroid heuristic on a fresh episode for comparison.
        gym_env.reset()
        centroid_controllers, _ = gym_env.graphCentroidAction()
        print(centroid_controllers)
        # The conversion of heuristic controllers to cluster-local indices
        # (sort, then subtract i * len(clusters[0]) assuming equally sized
        # clusters) is currently disabled, so this second print repeats the
        # first one.
        print(centroid_controllers)

        # Stays None when the heuristic yields no controllers, instead of
        # raising NameError at the print below.
        reward_final = None
        for cont in centroid_controllers:
            (_, reward_final, _, _) = gym_env.step(cont)
        print(gym_env.controllers, reward_final)
        return train_loss