def _restore_checkpoint(self):
    """Create the algorithm checkpointer, restore from it, and store it.

    The checkpointer is rooted at ``<self._train_dir>/algorithm`` and tracks
    the algorithm, the driver's metrics (grouped under the name 'metrics')
    and the global counter. ``initialize_or_restore()`` is invoked before the
    checkpointer is stored on ``self._checkpointer``, so the attribute is
    only set once that call has returned.
    """
    step_counter = get_global_counter()
    ckpt_path = os.path.join(self._train_dir, 'algorithm')
    algorithm = self._algorithm
    driver_metrics = metric_utils.MetricsGroup(
        self._driver.get_metrics(), 'metrics')

    restorer = tfa_common.Checkpointer(
        ckpt_dir=ckpt_path,
        algorithm=algorithm,
        metrics=driver_metrics,
        global_step=step_counter)
    restorer.initialize_or_restore()

    self._checkpointer = restorer
def create_checkpoints(agent, global_step, checkpoint_dir, train_metrics,
                       eval_metrics, max_to_keep=2):
    """Build the collect-policy and eval-policy checkpointers for an agent.

    Args:
      agent: Agent whose state is tracked by both checkpointers; its
        `collect_policy` and `policy` attributes are tracked respectively.
      global_step: Step counter stored in both checkpoints.
      checkpoint_dir: Parent directory; the checkpointers write to the
        'collect_policy' and 'policy' sub-directories.
      train_metrics: Metrics grouped as 'train_metrics' in the first
        checkpointer.
      eval_metrics: Metrics grouped as 'eval_metrics' in the second
        checkpointer.
      max_to_keep: Forwarded to both checkpointers as `max_to_keep`.

    Returns:
      A `(train_checkpointer, policy_checkpointer)` tuple.
    """

    def _build(subdir, policy, metrics, group_name):
        # Both checkpointers differ only in directory, policy and metrics.
        return common.Checkpointer(
            ckpt_dir=os.path.join(checkpoint_dir, subdir),
            max_to_keep=max_to_keep,
            agent=agent,
            policy=policy,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(metrics, group_name))

    for_training = _build(
        'collect_policy', agent.collect_policy, train_metrics,
        'train_metrics')
    for_policy = _build('policy', agent.policy, eval_metrics, 'eval_metrics')
    return for_training, for_policy
def _create_checkpointer(ckpt_dir, member, ckpt='train'): if ckpt == 'train': return common.Checkpointer( ckpt_dir=ckpt_dir, agent=member.agent, global_step=member.step_metrics[FP.IDX_ENV_STEPS], metrics=metric_utils.MetricsGroup( member.step_metrics + member.train_metrics, 'train_metrics')) elif ckpt == 'policy': return common.Checkpointer( ckpt_dir=os.path.join(ckpt_dir, 'policy'), policy=member.agent.policy, global_step=member.step_metrics[FP.IDX_ENV_STEPS]) else: raise ValueError
def train(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        env_name=None,
        num_parallel_environments=1,  # pylint: disable=unused-argument
        agent_class=None,
        initial_collect_random=True,  # pylint: disable=unused-argument
        initial_collect_driver_class=None,
        collect_driver_class=None,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=300,
        online_critic=False,
        # Params for eval
        run_eval=False,
        num_eval_episodes=30,
        eval_interval=1000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval for SC-SAC.

    Args:
      root_dir: Output root (user-expanded). Summaries and checkpoints go to
        its 'train' sub-directory, eval summaries to 'eval'.
      load_root_dir: Optional root of a previous run. When set, the agent is
        loaded via `misc.load_pi_ckpt(<load_root_dir>/train, tf_agent)` and
        the train checkpointer is not restored.
      env_load_fn: Callable invoked as `env_load_fn(env_name)` to build the
        training (and, if `run_eval`, the eval) environment.
      env_name: Name passed to `env_load_fn`; also selects how the safety
        reward is obtained in the critic train step.
      num_parallel_environments: Unused.
      agent_class: Callable invoked as `agent_class(time_step_spec,
        action_spec, train_step_counter=global_step)`, plus
        `resample_metric=...` when `online_critic` is set.
      initial_collect_random: Unused.
      initial_collect_driver_class: Driver class for the initial collection,
        invoked as `cls(tf_env, policy, observers=...)`.
      collect_driver_class: Driver class for per-iteration collection,
        invoked the same way.
      num_global_steps: Training continues while the global step is
        `<= num_global_steps`.
      train_steps_per_iteration: Number of agent train steps per iteration.
      train_metrics: Optional list of py metrics; each is wrapped in
        `TFPyMetric` and appended to the built-in train and eval metrics.
      train_sc_steps: Safety-critic train steps run each time the critic is
        trained.
      train_sc_interval: The critic is trained when the global step is a
        multiple of this value (only with `online_critic`).
      online_critic: If True, use the agent's `_safe_policy` for collection
        and eval, keep a second "online" replay buffer and train the safety
        critic from it.
      run_eval: If True, evaluate before training and every `eval_interval`
        steps.
      num_eval_episodes: Episodes per evaluation; also the metric buffer
        size.
      eval_interval: Evaluate when the global step is a multiple of this.
      eval_metrics_callback: Optional callable invoked as
        `callback(results, global_step_value)` after each evaluation.
      train_checkpoint_interval: Step interval for train checkpoints.
      policy_checkpoint_interval: Step interval for policy and safety-critic
        checkpoints.
      rb_checkpoint_interval: Step interval for replay-buffer checkpoints.
      keep_rb_checkpoint: If False, `misc.cleanup_checkpoints` is called on
        the replay-buffer checkpoint directory after training.
      log_interval: Step interval for logging loss and steps/sec.
      summary_interval: Summaries are recorded when the global step is a
        multiple of this value.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      early_termination_fn: Optional zero-argument callable; training stops
        once it returns a truthy value.
      env_metric_factories: Unused.

    Returns:
      The mean train loss of the last completed iteration, or None if the
      training loop did not run a single iteration.

    Raises:
      ValueError: If the loss was NaN/inf/above `MAX_LOSS` for more than
        `TERMINATE_AFTER_DIVERGED_LOSS_STEPS` consecutive iterations.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    train_metrics = train_metrics or []

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    # Bound up front so the final `return` is well defined even when the
    # training loop below does not execute a single iteration.
    total_loss = None

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = env_load_fn(env_name)
        if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
            tf_env = tf_py_environment.TFPyEnvironment(tf_env)

        if run_eval:
            eval_py_env = env_load_fn(env_name)
            eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        print('obs spec:', observation_spec)
        print('action spec:', action_spec)

        if online_critic:
            # The class is spelled `TFPyMetric`, as used for the train/eval
            # metric wrappers in this function.
            resample_metric = tf_py_metric.TFPyMetric(
                py_metrics.CounterMetric('unsafe_ac_samples'))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step,
                                   resample_metric=resample_metric)
        else:
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.info('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, max_length=1000000)
        logging.info('RB capacity: %i', replay_buffer.capacity)
        logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            online_replay_buffer = (
                tf_uniform_replay_buffer.TFUniformReplayBuffer(
                    collect_data_spec, max_length=10000))

            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = common.function(online_replay_buffer.clear)
            agent_observers.append(online_replay_buffer.add_batch)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        if not online_critic:
            eval_policy = tf_agent.policy
        else:
            eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        if not online_critic:
            collect_policy = tf_agent.collect_policy
        else:
            collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        safety_critic_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'safety_critic'),
            safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
            global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        # The train checkpoint is only restored when no external run was
        # loaded above; the replay buffer and safety critic always are.
        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        safety_critic_checkpointer.initialize_or_restore()

        collect_driver = collect_driver_class(
            tf_env, collect_policy, observers=agent_observers + train_metrics)

        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Skip the initial collection when a replay-buffer checkpoint exists.
        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            common.function(
                initial_collect_driver_class(
                    tf_env,
                    initial_collect_policy,
                    observers=agent_observers + train_metrics).run)()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            tf.print(
                replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
                output_stream=logging.info)

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)
            if FLAGS.viz_pm:
                eval_fig_dir = osp.join(eval_dir, 'figs')
                if not tf.io.gfile.isdir(eval_fig_dir):
                    tf.io.gfile.makedirs(eval_fig_dir)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3, num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                experience, buf_info = next(online_iterator)
                if env_name in [
                        'IndianWell', 'IndianWell2', 'IndianWell3',
                        'DrunkSpider', 'DrunkSpiderShort'
                ]:
                    safe_rew = experience.observation['task_agn_rew']
                else:
                    safe_rew = agents.process_replay_buffer(
                        online_replay_buffer, as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience, safe_rew)
                # The online buffer is emptied after every critic step.
                clear_rb()
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
        if online_critic:
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_samples')
            resample_metric.reset()

        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            if online_critic:
                mean_resample_ac(resample_metric.result())
                resample_metric.reset()
                if time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()
                    tf.compat.v2.summary.scalar(name='unsafe_ac_samples',
                                                data=resample_ac_freq,
                                                step=global_step)

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)

            if online_critic:
                if global_step.numpy() % train_sc_interval == 0:
                    for _ in range(train_sc_steps):
                        sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    break
            else:
                # Only consecutive diverged iterations count.
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                # The first two train metrics (episodes, env steps) serve as
                # the step metrics for the summaries.
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                safety_critic_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)

            if run_eval and global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
                if FLAGS.viz_pm:
                    savepath = 'step{}.png'.format(global_step_val)
                    savepath = osp.join(eval_fig_dir, savepath)
                    misc.record_episode_vis_summary(eval_tf_env, eval_policy,
                                                    savepath)

    if not keep_rb_checkpoint:
        misc.cleanup_checkpoints(rb_ckpt_dir)

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))

    return total_loss
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        lstm_size=(20, ),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO.

    Args:
      root_dir: Output root (user-expanded). Uses the sub-directories
        'train', 'eval' and 'policy_saved_model'.
      env_name: Name passed to `env_load_fn`.
      env_load_fn: Callable building a py environment from `env_name`; used
        for the eval environment and each parallel training environment.
      random_seed: If not None, passed to `tf.compat.v1.set_random_seed`.
      actor_fc_layers: Fully connected layer sizes of the actor network.
      value_fc_layers: Fully connected layer sizes of the value network.
      use_rnns: If True, build RNN actor/value networks instead of
        feed-forward ones.
      lstm_size: LSTM size of the RNN actor network.
      num_environment_steps: Training runs while the environment-steps
        metric is below this value.
      collect_episodes_per_iteration: Episodes collected per iteration.
      num_parallel_environments: Number of parallel training environments;
        also the replay buffer batch size.
      replay_buffer_capacity: `max_length` of the replay buffer.
      num_epochs: Forwarded to the PPO agent.
      learning_rate: Adam learning rate.
      num_eval_episodes: Episodes per evaluation and eval metric buffer size.
      eval_interval: Evaluate when the global step (read at the start of the
        iteration) is a multiple of this value.
      train_checkpoint_interval: Step interval for train checkpoints.
      policy_checkpoint_interval: Step interval for policy checkpoints and
        saved models.
      log_interval: Step interval for logging.
      summary_interval: Summaries are recorded when the global step is a
        multiple of this value.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      use_tf_functions: If True, wrap the collect driver, agent train and
        train step in `common.function`.
      debug_summaries: Forwarded to the PPO agent.
      summarize_grads_and_vars: Forwarded to the PPO agent.

    Raises:
      AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None,
                lstm_size=lstm_size)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # The environment-steps metric doubles as the loop's stop condition.
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        # Only the train checkpointer is restored here.
        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            # Train on everything currently held in the replay buffer.
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            # Read once per iteration, before collecting and training; the
            # interval checks below all use this value.
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            # The buffer is emptied after every train step, so each
            # iteration trains only on freshly collected episodes.
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                # Always record this scalar, regardless of summary_interval.
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                # NOTE(review): the checkpoint checks sit inside the
                # log_interval branch, so checkpoints are only considered on
                # logging iterations -- confirm this nesting is intended.
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    # Saved-model directory name is the step zero-padded to
                    # nine digits, e.g. 'policy_000000500'.
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                # Restart the timing window after each log.
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
def load_agents_and_create_videos(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=NUM_ITERATIONS,
        max_ep_steps=1000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(128, 64, 32),
        # Params for QRnnNetwork
        input_fc_layer_params=(50,),
        lstm_size=(20,),
        output_fc_layer_params=(20,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=10000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        num_random_episodes=1,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None,
        random_metrics_callback=None):
    """Restore a trained DQN agent and record videos of it and a random agent.

    The environment, Q-network, agent and replay buffer are rebuilt from the
    given parameters, the train and replay-buffer checkpoints under
    `<root_dir>/train` are restored, and `create_policy_eval_video` is called
    once for the trained policy and once for a random policy.

    Only the parameters needed to rebuild those objects are read: `root_dir`,
    `env_name`, `max_ep_steps`, `train_sequence_length`, the network layer
    parameters, `epsilon_greedy`, `replay_buffer_capacity`,
    `target_update_tau`, `target_update_period`, `learning_rate`,
    `n_step_update`, `gamma`, `reward_scale_factor`, `gradient_clipping`,
    `num_eval_episodes`, `debug_summaries` and `summarize_grads_and_vars`.
    The remaining parameters are accepted but not used in this function.

    Args:
      root_dir: Root directory of the training run; `~` is expanded so the
        same value given to `train_eval` resolves to the same checkpoints.
      env_name: Gym environment name passed to `suite_gym.load`.
      max_ep_steps: `max_episode_steps` passed to `suite_gym.load`.
      train_sequence_length: Values above 1 select the RNN Q-network.
      n_step_update: Forwarded to the DQN agent; must be 1 when
        `train_sequence_length` is not 1.

    Raises:
      NotImplementedError: If both `train_sequence_length` and
        `n_step_update` differ from 1.
    """
    # Expand `~` like train_eval does when writing checkpoints; otherwise a
    # home-relative root would point at a different (empty) directory and the
    # agent would be freshly initialised instead of restored.
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    # NOTE(review): eval_metrics is not used further down in this function.
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Match the environments used in training
    tf_env = tf_py_environment.TFPyEnvironment(
        suite_gym.load(env_name, max_episode_steps=max_ep_steps))
    eval_py_env = suite_gym.load(env_name, max_episode_steps=max_ep_steps)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    if train_sequence_length != 1 and n_step_update != 1:
        raise NotImplementedError(
            'train_eval does not currently support n-step updates with stateful '
            'networks (i.e., RNNs)')

    if train_sequence_length > 1:
        q_net = q_rnn_network.QRnnNetwork(
            tf_env.observation_spec(),
            tf_env.action_spec(),
            input_fc_layer_params=input_fc_layer_params,
            lstm_size=lstm_size,
            output_fc_layer_params=output_fc_layer_params)
    else:
        q_net = q_network.QNetwork(
            tf_env.observation_spec(),
            tf_env.action_spec(),
            fc_layer_params=fc_layer_params)
        train_sequence_length = n_step_update

    # Match the agents used in training
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    # NOTE(review): the policy checkpointer is created but never restored
    # here; only the train and replay-buffer checkpoints are loaded below.
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # Load the data from training
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    # Define a random policy for comparison
    random_policy = random_tf_policy.RandomTFPolicy(
        eval_tf_env.time_step_spec(), eval_tf_env.action_spec())

    # Make movies of the trained agent and a random agent
    date_string = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')

    # NOTE(review): the two prefixes differ in separator (none vs. a trailing
    # space); left untouched in case other tooling relies on these names.
    trained_filename = "trained-agent" + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, tf_agent.policy,
                             trained_filename)

    random_filename = 'random-agent ' + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, random_policy,
                             random_filename)
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Args:
      root_dir: Output root (user-expanded). Uses the sub-directories 'train'
        (summaries, agent/policy/replay-buffer checkpoints) and 'eval'.
      env_name: Gym environment name passed to `suite_gym.load`.
      num_iterations: Number of collect-then-train iterations.
      train_sequence_length: Values above 1 select the RNN Q-network;
        otherwise it is replaced by `n_step_update`.
      fc_layer_params: Layer sizes of the feed-forward Q-network.
      input_fc_layer_params: Input layer sizes of the RNN Q-network.
      lstm_size: LSTM size of the RNN Q-network.
      output_fc_layer_params: Output layer sizes of the RNN Q-network.
      initial_collect_steps: Steps collected with a random policy before
        training starts.
      collect_steps_per_iteration: Environment steps collected per iteration.
      epsilon_greedy: Forwarded to the DQN agent.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Forwarded to the DQN agent.
      target_update_period: Forwarded to the DQN agent.
      train_steps_per_iteration: Train steps run per iteration.
      batch_size: `sample_batch_size` of the training dataset.
      learning_rate: Adam learning rate.
      n_step_update: Forwarded to the DQN agent; must be 1 when
        `train_sequence_length` is not 1.
      gamma: Forwarded to the DQN agent.
      reward_scale_factor: Forwarded to the DQN agent.
      gradient_clipping: Forwarded to the DQN agent.
      use_tf_functions: If True, wrap the collect driver, agent train and
        train step in `common.function`.
      num_eval_episodes: Episodes per evaluation and eval metric buffer size.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Step interval for train checkpoints.
      policy_checkpoint_interval: Step interval for policy checkpoints.
      rb_checkpoint_interval: Step interval for replay-buffer checkpoints.
      log_interval: Step interval for logging loss and steps/sec.
      summary_interval: Summaries are recorded when the global step is a
        multiple of this value.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      debug_summaries: Forwarded to the DQN agent.
      summarize_grads_and_vars: Forwarded to the DQN agent.
      eval_metrics_callback: Optional callable invoked as
        `callback(results, global_step_value)` after each evaluation.

    Returns:
      The result of the last train step, or None when no train step was run
      (e.g. `num_iterations == 0`).

    Raises:
      NotImplementedError: If both `train_sequence_length` and
        `n_step_update` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    # Bound up front so the final `return` is well defined when the training
    # loop does not run a single train step.
    train_loss = None

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(tf_env.observation_spec(),
                                       tf_env.action_spec(),
                                       fc_layer_params=fc_layer_params)
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # The step counter only advances inside train_step, so one read
            # serves every interval check for this iteration.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                # The first two train metrics (episodes, env steps) serve as
                # the step metrics for the summaries.
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)
        return train_loss
def __init__(
        self,
        root_dir,
        env_name,
        num_iterations=200,
        max_episode_frames=108000,  # ALE frames
        terminal_on_life_loss=False,
        conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)),
        fc_layer_params=(512, ),
        # Params for collect
        initial_collect_steps=80000,  # ALE frames
        epsilon_greedy=0.01,
        epsilon_decay_period=1000000,  # ALE frames
        replay_buffer_capacity=1000000,
        # Params for train
        train_steps_per_iteration=1000000,  # ALE frames
        update_period=16,  # ALE frames
        target_update_tau=1.0,
        target_update_period=32000,  # ALE frames
        batch_size=32,
        learning_rate=2.5e-4,
        n_step_update=2,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        do_eval=True,
        eval_steps_per_iteration=500000,  # ALE frames
        eval_epsilon_greedy=0.001,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None):
    """Set up a simple Atari train and eval for categorical DQN.

    Builds the environment, the categorical DQN agent, the policies, the
    replay buffer, the metrics with their summaries, and the checkpointers,
    and stores the pieces needed later on ``self``.

    Args:
      root_dir: Directory to write log files to.
      env_name: Fully-qualified name of the Atari environment (i.e. Pong-v0).
      num_iterations: Number of train/eval iterations to run.
      max_episode_frames: Maximum length of a single episode, in ALE frames.
      terminal_on_life_loss: Whether to simulate an episode termination when
        a life is lost.
      conv_layer_params: Params for convolutional layers of QNetwork.
      fc_layer_params: Params for fully connected layers of QNetwork.
      initial_collect_steps: Number of ALE frames to process before
        beginning to train. Since this is in ALE frames, there will be
        initial_collect_steps/4 items in the replay buffer when training
        starts.
      epsilon_greedy: Final epsilon value to decay to for training.
      epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
        epsilon_greedy (defined above).
      replay_buffer_capacity: Maximum number of items to store in the replay
        buffer.
      train_steps_per_iteration: Number of ALE frames to run through for
        each iteration of training.
      update_period: Run a train operation every update_period ALE frames.
      target_update_tau: Coefficient for soft target network updates
        (1.0 == hard updates).
      target_update_period: Period, in ALE frames, to copy the live network
        to the target network.
      batch_size: Number of frames to include in each training batch.
      learning_rate: RMS optimizer learning rate.
      n_step_update: The number of steps to consider when computing TD error
        and TD loss. Applies standard single-step updates when set to 1.
      gamma: Discount for future rewards.
      reward_scale_factor: Scaling factor for rewards.
      gradient_clipping: Norm length to clip gradients.
      do_eval: If True, run an eval every iteration. If False, skip eval.
      eval_steps_per_iteration: Number of ALE frames to run through for each
        iteration of evaluation.
      eval_epsilon_greedy: Epsilon value to use for the evaluation policy
        (0 == totally greedy policy).
      log_interval: Log stats to the terminal every log_interval training
        steps.
      summary_interval: Write TF summaries every summary_interval training
        steps.
      summaries_flush_secs: Flush summaries to disk every
        summaries_flush_secs seconds.
      debug_summaries: If True, write additional summaries for debugging
        (see dqn_agent for which summaries are written).
      summarize_grads_and_vars: Include gradients in summaries.
      eval_metrics_callback: A callback function that takes (metric_dict,
        global_step) as parameters. Called after every eval with the results
        of the evaluation.
    """
    # Frame-denominated settings are divided by ATARI_FRAME_SKIP before being
    # stored. NOTE(review): `/` is true division under Python 3, so these
    # attributes are floats -- confirm the consumers expect that.
    self._update_period = update_period / ATARI_FRAME_SKIP
    self._train_steps_per_iteration = (train_steps_per_iteration /
                                       ATARI_FRAME_SKIP)
    self._do_eval = do_eval
    self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP
    self._eval_epsilon_greedy = eval_epsilon_greedy
    self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP
    self._summary_interval = summary_interval
    self._num_iterations = num_iterations
    self._log_interval = log_interval
    self._eval_metrics_callback = eval_metrics_callback

    # Bind the requested terminal_on_life_loss value onto the gin parameter
    # of AtariPreprocessing.
    with gin.unlock_config():
        gin.bind_parameter(('tf_agents.environments.atari_preprocessing.'
                            'AtariPreprocessing.terminal_on_life_loss'),
                           terminal_on_life_loss)

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The train writer is installed as the default summary writer.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()
    self._train_summary_writer = train_summary_writer

    # Eval writer and metrics only exist when evaluation is enabled.
    self._eval_summary_writer = None
    if self._do_eval:
        self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        self._eval_metrics = [
            py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                           buffer_size=np.inf),
            py_metrics.AverageEpisodeLengthMetric(
                name='PhaseAverageEpisodeLength', buffer_size=np.inf),
        ]

    self._global_step = tf.compat.v1.train.get_or_create_global_step()
    # Summaries are recorded only when the global step is a multiple of
    # summary_interval.
    with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
            self._global_step % self._summary_interval, 0)):
        self._env = suite_atari.load(
            env_name,
            max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP,
            gym_env_wrappers=suite_atari.
            DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
        # Wrap the single environment in a batched environment.
        self._env = batched_py_environment.BatchedPyEnvironment(
            [self._env])

        # Tensor versions of the environment specs, used to build the agent.
        observation_spec = tensor_spec.from_spec(
            self._env.observation_spec())
        time_step_spec = ts.time_step_spec(observation_spec)
        action_spec = tensor_spec.from_spec(self._env.action_spec())

        with tf.device('/cpu:0'):
            # Exploration epsilon decays polynomially from 1.0 down to
            # epsilon_greedy as the global step advances.
            epsilon = tf.compat.v1.train.polynomial_decay(
                1.0,
                self._global_step,
                epsilon_decay_period / ATARI_FRAME_SKIP / self._update_period,
                end_learning_rate=epsilon_greedy)

        with tf.device('/gpu:0'):
            optimizer = tf.compat.v1.train.RMSPropOptimizer(
                learning_rate=learning_rate,
                decay=0.95,
                momentum=0.0,
                epsilon=0.00001,
                centered=True)
            categorical_q_net = AtariCategoricalQNetwork(
                observation_spec,
                action_spec,
                conv_layer_params=conv_layer_params,
                fc_layer_params=fc_layer_params)
            agent = categorical_dqn_agent.CategoricalDqnAgent(
                time_step_spec,
                action_spec,
                categorical_q_network=categorical_q_net,
                optimizer=optimizer,
                epsilon_greedy=epsilon,
                n_step_update=n_step_update,
                target_update_tau=target_update_tau,
                target_update_period=(target_update_period /
                                      ATARI_FRAME_SKIP / self._update_period),
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=self._global_step)

            self._collect_policy = py_tf_policy.PyTFPolicy(
                agent.collect_policy)

            if self._do_eval:
                # Evaluation acts epsilon-greedily around the agent's policy
                # with the fixed eval_epsilon_greedy.
                self._eval_policy = py_tf_policy.PyTFPolicy(
                    epsilon_greedy_policy.EpsilonGreedyPolicy(
                        policy=agent.policy,
                        epsilon=self._eval_epsilon_greedy))

            # The replay buffer spec is built from the environment's own
            # specs (not the tensor specs converted above).
            py_observation_spec = self._env.observation_spec()
            py_time_step_spec = ts.time_step_spec(py_observation_spec)
            py_action_spec = policy_step.PolicyStep(
                self._env.action_spec())
            data_spec = trajectory.from_transition(py_time_step_spec,
                                                   py_action_spec,
                                                   py_time_step_spec)
            self._replay_buffer = py_hashed_replay_buffer.PyHashedReplayBuffer(
                data_spec=data_spec, capacity=replay_buffer_capacity)

        with tf.device('/cpu:0'):
            # Each sampled item spans n_step_update + 1 consecutive steps.
            ds = self._replay_buffer.as_dataset(
                sample_batch_size=batch_size, num_steps=n_step_update + 1)
            ds = ds.prefetch(4)
            ds = ds.apply(
                tf.data.experimental.prefetch_to_device('/gpu:0'))

        with tf.device('/gpu:0'):
            self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
            experience = self._ds_itr.get_next()
            self._train_op = agent.train(experience)

            self._env_steps_metric = py_metrics.EnvironmentSteps()
            self._step_metrics = [
                py_metrics.NumberOfEpisodes(),
                self._env_steps_metric,
            ]
            self._train_metrics = self._step_metrics + [
                py_metrics.AverageReturnMetric(buffer_size=10),
                py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
            ]
            # The _train_phase_metrics average over an entire train iteration,
            # rather than the rolling average of the last 10 episodes.
            self._train_phase_metrics = [
                py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                               buffer_size=np.inf),
                py_metrics.AverageEpisodeLengthMetric(
                    name='PhaseAverageEpisodeLength', buffer_size=np.inf),
            ]
            self._iteration_metric = py_metrics.CounterMetric(
                name='Iteration')

            # Summaries written from python should run every time they are
            # generated.
            with tf.compat.v2.summary.record_if(True):
                # Scalar summary driven by a placeholder; presumably fed with
                # the measured steps/sec by the training loop -- confirm.
                self._steps_per_second_ph = tf.compat.v1.placeholder(
                    tf.float32, shape=(), name='steps_per_sec_ph')
                self._steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=self._steps_per_second_ph,
                    step=self._global_step)

                for metric in self._train_metrics:
                    metric.tf_summaries(train_step=self._global_step,
                                        step_metrics=self._step_metrics)

                for metric in self._train_phase_metrics:
                    metric.tf_summaries(
                        train_step=self._global_step,
                        step_metrics=(self._iteration_metric, ))
                self._iteration_metric.tf_summaries(
                    train_step=self._global_step)

                if self._do_eval:
                    # Eval metric summaries go to the eval writer instead of
                    # the default (train) writer.
                    with self._eval_summary_writer.as_default():
                        for metric in self._eval_metrics:
                            metric.tf_summaries(
                                train_step=self._global_step,
                                step_metrics=(self._iteration_metric, ))

            # Three checkpointers: full training state, policy only, and the
            # replay buffer (keeping a single replay-buffer checkpoint).
            self._train_checkpointer = common.Checkpointer(
                ckpt_dir=train_dir,
                agent=agent,
                global_step=self._global_step,
                optimizer=optimizer,
                metrics=metric_utils.MetricsGroup(
                    self._train_metrics + self._train_phase_metrics +
                    [self._iteration_metric], 'train_metrics'))
            self._policy_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'policy'),
                policy=agent.policy,
                global_step=self._global_step)
            self._rb_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
                max_to_keep=1,
                replay_buffer=self._replay_buffer)

            self._init_agent_op = agent.initialize()
def run():
    """Train a DQN agent on ``SnakeEnv`` with periodic checkpoints and evals.

    Builds the train/eval environments, Q-network, agent, replay buffer and
    collect driver, then restores the train and replay-buffer checkpoints
    when they exist. If the replay buffer holds fewer than
    ``initial_collect_steps`` frames it is seeded with random-policy
    experience. Training then alternates one collect-driver run with
    ``train_steps_per_iteration`` train steps for ``num_iterations``
    iterations, logging, checkpointing, capturing videos and evaluating at
    their respective step intervals.

    Hyperparameters (``learning_rate``, ``batch_size``, ``num_iterations``,
    the ``*_interval`` values, ...) are read from names defined outside this
    function.

    All evaluation summaries are written through one file writer that is
    created once, instead of a new writer (and event file) per evaluation.
    """
    tf_env = tf_py_environment.TFPyEnvironment(SnakeEnv())
    eval_env = tf_py_environment.TFPyEnvironment(SnakeEnv(step_limit=50))

    q_net = q_network.QNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        conv_layer_params=(),
        fc_layer_params=(512, 256, 128),
    )
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    global_counter = tf.compat.v1.train.get_or_create_global_step()

    agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_counter,
        gamma=0.95,
        epsilon_greedy=0.1,
        n_step_update=1,
    )

    root_dir = os.path.join('/tf-logs', 'snake')
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    agent.initialize()

    # The first two metrics double as the step metrics for train summaries.
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_max_length,
    )

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration,
    )

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=agent,
        global_step=global_counter,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
    )
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=agent.policy,
        global_step=global_counter,
    )
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer,
    )

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    collect_driver.run = common.function(collect_driver.run)
    agent.train = common.function(agent.train)

    random_policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                                    tf_env.action_spec())

    # Only seed with random experience when the (possibly restored) buffer
    # does not already hold enough frames.
    if replay_buffer.num_frames() >= initial_collect_steps:
        logging.info("We loaded memories, not doing random seed")
    else:
        logging.info("Capturing %d steps to seed with random memories",
                     initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            random_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

    train_summary_writer = tf.summary.create_file_writer(train_dir)
    train_summary_writer.set_as_default()
    # One writer for every evaluation.
    eval_summary_writer = tf.summary.create_file_writer(eval_dir)

    # History of (step, average return) pairs; accumulated here but not
    # returned.
    avg_returns = []
    avg_return_metric = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes)
    eval_metrics = [
        avg_return_metric,
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    def evaluate():
        """Run the eval episodes, log the metrics and record the return."""
        metric_utils.eager_compute(
            eval_metrics,
            eval_env,
            agent.policy,
            num_episodes=num_eval_episodes,
            train_step=global_counter,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)
        avg_returns.append(
            (global_counter.numpy(), avg_return_metric.result().numpy()))

    logging.info("Running initial evaluation")
    evaluate()

    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_counter.numpy()
    time_acc = 0

    # num_steps=2 yields pairs of adjacent steps, matching n_step_update=1.
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
    iterator = iter(dataset)

    @common.function
    def train_step():
        experience, _ = next(iterator)
        return agent.train(experience)

    for _ in range(num_iterations):
        start_time = time.time()
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )
        for _ in range(train_steps_per_iteration):
            train_loss = train_step()
        time_acc += time.time() - start_time

        step = global_counter.numpy()

        if step % log_interval == 0:
            logging.info("step = %d, loss = %f", step, train_loss.loss)
            steps_per_sec = (step - timed_at_step) / time_acc
            logging.info("%.3f steps/sec", steps_per_sec)
            timed_at_step = step
            time_acc = 0

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_counter,
                                      step_metrics=train_metrics[:2])

        if step % train_checkpoint_interval == 0:
            train_checkpointer.save(global_step=step)

        if step % policy_checkpoint_interval == 0:
            policy_checkpointer.save(global_step=step)

        if step % rb_checkpoint_interval == 0:
            rb_checkpointer.save(global_step=step)

        if step % capture_interval == 0:
            print("Capturing run:")
            capture_run(os.path.join(root_dir, "snake" + str(step) + ".mp4"),
                        eval_env, agent.policy)

        if step % eval_interval == 0:
            print("EVALUTION TIME:")
            evaluate()
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=5e5,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(
            64,
            64,
        ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(6, ),
        output_fc_layer_params=(30, ),
        # Params for collect
        initial_collect_steps=2000,
        collect_steps_per_iteration=6,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=6,
        batch_size=32,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=1,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN with a ``GATNetwork`` Q-network.

    Loads ``clusters.pickle`` and ``graph.gpickle`` from the working
    directory and passes them to the gym environment as ``gym_kwargs``. It
    then trains a ``DqnAgent`` whose Q-network is a ``GATNetwork`` built over
    that graph. After training it prints the environment's best controllers
    and reward, then replays the environment's centroid heuristic and prints
    the resulting controllers and final reward.

    Args:
      root_dir: Directory for summaries and checkpoints (``train`` and
        ``eval`` subdirectories are used).
      env_name: Gym environment name passed to ``suite_gym.load``.
      num_iterations: Number of collect/train iterations. May be given as a
        float such as ``5e5``; it is converted with ``int()``.
      train_sequence_length: Sequence length for training. When it is not
        greater than 1 it is replaced by ``n_step_update``.
      fc_layer_params: Unused; the Q-network is always a ``GATNetwork``.
        Kept for interface compatibility.
      input_fc_layer_params: Unused; kept for interface compatibility.
      lstm_size: Unused; kept for interface compatibility.
      output_fc_layer_params: Unused; kept for interface compatibility.
      initial_collect_steps: Steps collected with a random policy before
        training.
      collect_steps_per_iteration: Steps collected per iteration.
      epsilon_greedy: Epsilon passed to the agent.
      replay_buffer_capacity: ``max_length`` of the replay buffer.
      target_update_tau: Passed to the agent.
      target_update_period: Passed to the agent.
      train_steps_per_iteration: Train steps run per iteration.
      batch_size: Sample batch size drawn from the replay buffer.
      learning_rate: Adam learning rate.
      n_step_update: Passed to the agent.
      gamma: Passed to the agent.
      reward_scale_factor: Passed to the agent.
      gradient_clipping: Passed to the agent.
      use_tf_functions: Wrap collect/train calls in ``common.function``.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Step interval for the train checkpoint.
      policy_checkpoint_interval: Step interval for the policy checkpoint.
      rb_checkpoint_interval: Step interval for the replay-buffer checkpoint.
      log_interval: Step interval for loss and steps/sec logging.
      summary_interval: Step interval at which summaries are recorded.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      debug_summaries: Passed to the agent.
      summarize_grads_and_vars: Passed to the agent.
      eval_metrics_callback: Optional callable invoked as
        ``callback(results, step)`` after every evaluation.

    Returns:
      The loss info from the last train step, or ``None`` when no train step
      was run.

    Raises:
      NotImplementedError: If both ``train_sequence_length`` and
        ``n_step_update`` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # NOTE: unpickling executes arbitrary code; only load trusted files.
    # The context manager closes the file handle once loading is done.
    with open('clusters.pickle', 'rb') as clusters_file:
        clusters = pickle.load(clusters_file)
    graph = nx.read_gpickle('graph.gpickle')
    print(graph.nodes)

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        gym_kwargs = {'graph': graph, 'clusters': clusters}
        tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name, gym_kwargs=gym_kwargs))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name, gym_kwargs=gym_kwargs))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        # Without a longer training sequence, the sequence length follows
        # the n-step setting.
        if not train_sequence_length > 1:
            train_sequence_length = n_step_update

        # The Q-network is always the graph-attention network.
        q_net = GATNetwork(tf_env.observation_spec(), tf_env.action_spec(),
                           graph)

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # The first two metrics double as step metrics for train summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
            tf_metrics.MaxReturnMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        def evaluate(step):
            """Evaluate the policy, notify the callback and log metrics."""
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, step)
            metric_utils.log_metrics(eval_metrics)

        evaluate(global_step.numpy())

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Defined up front so the final return works even with zero
        # iterations.
        train_loss = None
        # num_iterations defaults to the float 5e5; range() needs an int.
        for _ in range(int(num_iterations)):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # Nothing below advances the global step, so read it once.
            step = global_step.numpy()

            if step % log_interval == 0:
                logging.info('step = %d, loss = %f', step, train_loss.loss)
                steps_per_sec = (step - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = step
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if step % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=step)

            if step % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=step)

            if step % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=step)

            if step % eval_interval == 0:
                evaluate(step)

        # Report what training found, then replay the centroid heuristic on
        # a freshly reset environment for comparison.
        gym_env = tf_env.envs[0]._gym_env
        print(gym_env.best_controllers)
        print(gym_env.best_reward)
        gym_env.reset()
        centroid_controllers, _ = gym_env.graphCentroidAction()
        # Logged twice in a row; nothing modifies the controllers in between.
        print(centroid_controllers)
        print(centroid_controllers)
        # Stays None when the heuristic yields no controllers.
        reward_final = None
        for cont in centroid_controllers:
            _, reward_final, _, _ = gym_env.step(cont)
        print(gym_env.controllers, reward_final)
    return train_loss
def train_eval(
        root_dir,
        env_name='Blob2d-v1',
        num_iterations=100000,
        train_sequence_length=1,
        collect_steps_per_iteration=1,
        initial_collect_steps=1500,
        replay_buffer_max_length=10000,
        batch_size=64,
        learning_rate=1e-3,
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for QNetwork
        fc_layer_params=(100, ),
        use_tf_functions=False,
        ## train params
        train_steps_per_iteration=1,
        train_checkpoint_interval=1000,
        policy_checkpoint_interval=1000,
        rb_checkpoint_interval=1000,
        n_step_update=1,
        ## Params for Summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and evaluate a DQN agent on a gym environment.

    Both the training and the evaluation environment are loaded from
    ``env_name``, and the Q-network is built from ``fc_layer_params``. With
    the default arguments this gives the same ``'Blob2d-v1'`` environment and
    ``(100, )`` layer layout that were previously hard-coded in the body.

    Args:
      root_dir: Directory for summaries and checkpoints (``train`` and
        ``eval`` subdirectories are used).
      env_name: Gym environment name passed to ``suite_gym.load``.
      num_iterations: Number of collect/train iterations.
      train_sequence_length: The replay dataset samples
        ``train_sequence_length + 1`` consecutive steps per item.
      collect_steps_per_iteration: Steps collected per iteration.
      initial_collect_steps: Steps collected with a random policy before
        training.
      replay_buffer_max_length: ``max_length`` of the replay buffer.
      batch_size: Sample batch size drawn from the replay buffer.
      learning_rate: Adam learning rate.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluate when the global step is a multiple of this.
      fc_layer_params: Fully connected layer sizes of the ``QNetwork``.
      use_tf_functions: Wrap the train step in ``common.function``.
      train_steps_per_iteration: Train steps run per iteration.
      train_checkpoint_interval: Step interval for the train checkpoint.
      policy_checkpoint_interval: Step interval for the policy checkpoint.
      rb_checkpoint_interval: Step interval for the replay-buffer checkpoint.
      n_step_update: Only validated against ``train_sequence_length``; it is
        not forwarded to the agent.
      log_interval: Step interval for loss and steps/sec logging.
      summary_interval: Step interval at which summaries are recorded.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      debug_summaries: Passed to the agent.
      summarize_grads_and_vars: Passed to the agent.
      eval_metrics_callback: Optional callable invoked as
        ``callback(results, step)`` after every evaluation.

    Returns:
      The loss info from the last train step, or ``None`` when no train step
      was run.

    Raises:
      NotImplementedError: If both ``train_sequence_length`` and
        ``n_step_update`` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Train and eval environments both honour env_name.
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        q_net = q_network.QNetwork(tf_env.observation_spec(),
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        agent.initialize()

        # The first two metrics double as step metrics for train summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = agent.policy
        collect_policy = agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_max_length)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        def evaluate(step):
            """Evaluate the policy, notify the callback and log metrics."""
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, step)
            metric_utils.log_metrics(eval_metrics)

        evaluate(global_step.numpy())

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Defined up front so the final return works even with zero
        # iterations.
        train_loss = None
        # Main Training loop.
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # Nothing below advances the global step, so read it once.
            step = global_step.numpy()

            if step % log_interval == 0:
                logging.info('step = %d, loss = %f', step, train_loss.loss)
                steps_per_sec = (step - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = step
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if step % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=step)

            if step % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=step)

            if step % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=step)

            if step % eval_interval == 0:
                evaluate(step)
    return train_loss
def train_eval(
    root_dir,
    env_name='sawyer_reach',
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    gamma=0.99,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=200000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    random_seed=0,
    max_future_steps=50,
    actor_std=None,
    log_subset=None,
):
    """Train and periodically evaluate a C-learning agent.

    Builds the environments, actor/critic networks, `CLearningAgent`, metrics,
    replay buffer and collect drivers, restores the latest train checkpoint
    if one exists, and then alternates between collecting experience and
    running train steps until the global step reaches `num_iterations`.

    Args:
      root_dir: Output directory (`~` is expanded). Summaries go to its
        `train` and `eval` sub-directories, train checkpoints to `train`, and
        the operative gin config to `operative.gin`.
      env_name: Name handed to `c_learning_envs.load`.
      num_iterations: Training stops once the global step reaches this value.
      actor_fc_layers: `fc_layer_params` of the actor network.
      critic_obs_fc_layers: Observation layers of the critic network.
      critic_action_fc_layers: Action layers of the critic network.
      critic_joint_fc_layers: Joint layers of the critic network.
      initial_collect_steps: Steps collected with a random policy when the
        replay buffer is empty.
      collect_steps_per_iteration: Steps collected per training iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Forwarded to the agent.
      target_update_period: Forwarded to the agent.
      train_steps_per_iteration: Train steps run after each collect call.
      batch_size: Batch size of the training dataset.
      actor_learning_rate: Learning rate of the actor Adam optimizer.
      critic_learning_rate: Learning rate of the critic Adam optimizer.
      gamma: Forwarded to the agent and to `c_learning_utils.goal_fn`.
      gradient_clipping: Forwarded to the agent.
      use_tf_functions: If true, wrap the driver `run` methods, the agent's
        `train` and the local train step with `common.function`.
      num_eval_episodes: Episodes per evaluation; also the metric buffer size.
      eval_interval: Evaluate whenever the global step is a multiple of this.
      train_checkpoint_interval: Save the train checkpoint whenever the global
        step is a multiple of this.
      log_interval: Log loss and throughput whenever the global step is a
        multiple of this.
      summary_interval: Summaries are recorded only when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      debug_summaries: Forwarded to the agent.
      summarize_grads_and_vars: Forwarded to the agent.
      random_seed: Seed for NumPy and TensorFlow.
      max_future_steps: `num_steps` of the sequences sampled from the replay
        buffer.
      actor_std: If not None, the actor's projection network uses this
        constant standard deviation instead of its default transform.
      log_subset: Optional `(start_index, end_index)` pair; when given, extra
        `Subset*Distance` metrics restricted to that index range are logged.

    Returns:
      The loss info returned by the last train step, or `None` when no train
      step ran (e.g. the restored global step already reached
      `num_iterations`).
    """
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env, eval_tf_env, obs_dim = c_learning_envs.load(env_name)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        if actor_std is None:
            proj_net = (
                tanh_normal_projection_network.TanhNormalProjectionNetwork)
        else:
            # Replace the std transform with a constant standard deviation.
            proj_net = functools.partial(
                tanh_normal_projection_network.TanhNormalProjectionNetwork,
                std_transform=lambda t: actor_std * tf.ones_like(t))

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=proj_net)
        critic_net = c_learning_utils.ClassifierCriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = c_learning_agent.CLearningAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=bce_loss,
            gamma=gamma,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
            c_learning_utils.FinalDistance(
                buffer_size=num_eval_episodes, obs_dim=obs_dim),
            c_learning_utils.MinimumDistance(
                buffer_size=num_eval_episodes, obs_dim=obs_dim),
            c_learning_utils.DeltaDistance(
                buffer_size=num_eval_episodes, obs_dim=obs_dim),
        ]
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
            c_learning_utils.InitialDistance(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                obs_dim=obs_dim),
            c_learning_utils.FinalDistance(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                obs_dim=obs_dim),
            c_learning_utils.MinimumDistance(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                obs_dim=obs_dim),
            c_learning_utils.DeltaDistance(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                obs_dim=obs_dim),
        ]

        if log_subset is not None:
            start_index, end_index = log_subset
            # (constructor, summary name) pairs, in the order they are logged.
            subset_metric_specs = (
                (c_learning_utils.InitialDistance, 'SubsetInitialDistance'),
                (c_learning_utils.FinalDistance, 'SubsetFinalDistance'),
                (c_learning_utils.MinimumDistance, 'SubsetMinimumDistance'),
                (c_learning_utils.DeltaDistance, 'SubsetDeltaDistance'),
            )
            for split, metrics in (('train', train_metrics),
                                   ('eval', eval_metrics)):
                # NOTE(review): eval metrics use a hard-coded batch size of
                # 10 -- presumably the eval environment's batch size; confirm.
                metric_batch_size = (
                    tf_env.batch_size if split == 'train' else 10)
                metrics.extend([
                    metric_ctor(
                        buffer_size=num_eval_episodes,
                        batch_size=metric_batch_size,
                        obs_dim=obs_dim,
                        start_index=start_index,
                        end_index=end_index,
                        name=metric_name)
                    for metric_ctor, metric_name in subset_metric_specs
                ])

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=None)
        train_checkpointer.initialize_or_restore()

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Save the hyperparameters
        operative_filename = os.path.join(root_dir, 'operative.gin')
        with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
            f.write(gin.operative_config_str())
        logging.info(gin.operative_config_str())

        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        # Evaluate once before training starts.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        def _filter_invalid_transition(trajectories, unused_arg1):
            # Keep only sequences whose first step is not a boundary step.
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=max_future_steps)
        dataset = dataset.unbatch().filter(_filter_invalid_transition)
        dataset = dataset.batch(batch_size, drop_remainder=True)
        goal_fn = functools.partial(
            c_learning_utils.goal_fn,
            batch_size=batch_size,
            obs_dim=obs_dim,
            gamma=gamma)
        dataset = dataset.map(goal_fn)
        dataset = dataset.prefetch(5)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Defined up front so that `return train_loss` cannot hit an unbound
        # local when the loop below never runs (for example when a restored
        # checkpoint is already at or past `num_iterations`).
        train_loss = None

        global_step_val = global_step.numpy()
        while global_step_val < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

        return train_loss
def train_eval(root_dir, tf_env, eval_tf_env, agent, num_iterations,
               initial_collect_steps, collect_steps_per_iteration,
               replay_buffer_capacity, train_steps_per_iteration, batch_size,
               use_tf_functions, num_eval_episodes, eval_interval,
               train_checkpoint_interval, policy_checkpoint_interval,
               rb_checkpoint_interval, log_interval, summary_interval,
               summaries_flush_secs):
    """Train and periodically evaluate a DQN agent, keeping the best policy.

    Collects experience with the agent's collect policy, trains on batches
    sampled from a uniform replay buffer, and stops early when the training
    loss becomes NaN or infinite. The policy checkpoint is written only when
    an evaluation beats the best average return seen so far.

    Args:
      root_dir: Output directory (`~` is expanded); summaries and checkpoints
        go to its `train` and `eval` sub-directories.
      tf_env: Environment used for data collection.
      eval_tf_env: Environment used for evaluation.
      agent: The agent to train.
      num_iterations: Number of collect/train iterations.
      initial_collect_steps: Steps collected with a random policy before
        training starts.
      collect_steps_per_iteration: Steps collected per iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      train_steps_per_iteration: Train steps run per iteration.
      batch_size: Sample batch size of the training dataset.
      use_tf_functions: If true, wrap the collect driver, the agent's `train`
        and the local train step with `common.function`.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluate whenever the global step is a multiple of this.
      train_checkpoint_interval: Save the train checkpoint whenever the global
        step is a multiple of this.
      policy_checkpoint_interval: Unused by this function; the policy is
        checkpointed only when evaluation finds a new best average return.
      rb_checkpoint_interval: Save the replay buffer checkpoint whenever the
        global step is a multiple of this.
      log_interval: Print loss and throughput whenever the global step is a
        multiple of this.
      summary_interval: Summaries are recorded only when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of the summary writers, in seconds.

    Returns:
      The loss info of the last train step, or `None` if `num_iterations` is
      zero and no train step ran.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_agent = agent

        # The first two metrics double as step metrics for the summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=1),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            max_to_keep=1,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            max_to_keep=1,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Best evaluation return seen so far; any evaluation above this
        # threshold triggers a policy checkpoint.
        best_policy = -1000

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Evaluate once before training; this baseline is logged but does
        # not update `best_policy`.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Defined up front so `return train_loss` is always bound, even when
        # `num_iterations` is zero.
        train_loss = None

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # Stop training as soon as the loss diverges (NaN or +/-Inf).
            if not np.isfinite(train_loss.loss).all():
                break

            # Nothing below changes the global step, so read it once.
            step = global_step.numpy()

            if step % log_interval == 0:
                print('step = {0}, loss = {1}'.format(step, train_loss.loss))
                steps_per_sec = (step - timed_at_step) / time_acc
                print('{0} steps/sec'.format(steps_per_sec))
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = step
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if step % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=step)

            if step % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=step)

            if step % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                average_return = results["AverageReturn"].numpy()
                if average_return > best_policy:
                    print("New best policy found")
                    print(average_return)
                    best_policy = average_return
                    policy_checkpointer.save(global_step=step)
                metric_utils.log_metrics(eval_metrics)

    return train_loss
def train_eval(
    root_dir,
    experiment_name,
    train_eval_dir=None,
    universe='gym',
    env_name='HalfCheetah-v2',
    domain_name='cheetah',
    task_name='run',
    action_repeat=1,
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor=model_distribution_network.ModelDistributionNetwork,
    critic_input='state',
    actor_input='state',
    compressor_descriptor='preprocessor_32_3',
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),  # increase if necessary since buffers with images are huge
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    model_train_steps_per_iteration=1,
    initial_model_train_steps=100000,
    batch_size=256,
    model_batch_size=32,
    sequence_length=4,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,
    td_errors_loss_fn=functools.partial(
        tf.compat.v1.losses.mean_squared_error, weights=0.5),
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=0,  # enable if necessary since buffers with images are huge
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=False,
    gpu_memory_limit=None):
    """Train and periodically evaluate a SLAC agent in a TF1 session.

    Builds the whole graph first (environments, model/compressor/actor/critic
    networks, agent, replay buffer, collect and train ops, summaries and
    checkpointers) and then runs it in a `tf.compat.v1.Session`. The first
    `initial_model_train_steps` iterations train the model only; afterwards
    every iteration collects experience and runs
    `train_steps_per_iteration` train steps.

    Args:
      root_dir: Root output directory, used to derive `train_eval_dir` when
        that is not given.
      experiment_name: Forwarded to `get_train_eval_dir`.
      train_eval_dir: Directory holding the `train` and `eval`
        sub-directories; derived via `get_train_eval_dir` when None.
      universe, env_name, domain_name, task_name: Forwarded to
        `load_environments` (and to `get_train_eval_dir`).
      action_repeat: Forwarded to `load_environments`; also scales the
        control timestep used for video summaries.
      num_iterations: Number of iterations run after the initial model
        training phase.
      actor_fc_layers: `fc_layer_params` of the actor network.
      critic_obs_fc_layers, critic_action_fc_layers, critic_joint_fc_layers:
        Layer parameters of the critic network.
      model_network_ctor: Called with `(observation_spec, action_spec)` to
        build the model network.
      critic_input: `'__'`-separated list of critic inputs. Each entry must
        be one of `state`, `latent`, `feature`, `sequence_feature` or
        `sequence_action_feature`.
      actor_input: Same format as `critic_input`, for the actor.
      compressor_descriptor: One of `model`, `preprocessor_<F>_<L>`,
        `compressor_<F>`, `softlearning_<F>_<L>` or `d4pg`.
      initial_collect_steps: Steps collected with a random policy when
        starting from global step 0.
      collect_steps_per_iteration: Steps collected per iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau, target_update_period: Forwarded to the agent.
      train_steps_per_iteration: Train steps run per iteration.
      model_train_steps_per_iteration: Must be 0 or equal to
        `train_steps_per_iteration`; 0 makes the model non-trainable.
      initial_model_train_steps: Number of leading iterations that only
        train the model.
      batch_size: Batch size of the training dataset.
      model_batch_size: Forwarded to the agent.
      sequence_length: Forwarded to the agent; sampled sequences have
        `sequence_length + 1` steps.
      actor_learning_rate, critic_learning_rate, alpha_learning_rate,
        model_learning_rate: Learning rates of the Adam optimizers.
      td_errors_loss_fn, gamma, reward_scale_factor, gradient_clipping:
        Forwarded to the agent.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluate whenever the global step is a multiple of
        this; a falsy value disables evaluation.
      num_images_per_summary: Episodes rendered per evaluation; also
        forwarded to the agent.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Save the matching checkpoint whenever the
        global step is a multiple of the interval; a falsy value disables it.
      log_interval: Log loss and throughput at this interval; a falsy value
        disables it.
      summary_interval: Summaries are recorded only when the global step is
        a multiple of this.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      debug_summaries, summarize_grads_and_vars: Forwarded to the agent.
      gpu_allow_growth: If true, enable memory growth on every GPU.
      gpu_memory_limit: If set, memory limit of the virtual device created
        on every GPU.

    Raises:
      NotImplementedError: If `model_train_steps_per_iteration` is neither 0
        nor `train_steps_per_iteration`, or if `compressor_descriptor`,
        `actor_input` or `critic_input` contains an unsupported value.
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    if train_eval_dir is None:
        train_eval_dir = get_train_eval_dir(root_dir, universe, env_name,
                                            domain_name, task_name,
                                            experiment_name)
    train_dir = os.path.join(train_eval_dir, 'train')
    eval_dir = os.path.join(train_eval_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]
    eval_greedy_metrics = [
        py_metrics.AverageReturnMetric(name='AverageReturnEvalGreedyPolicy',
                                       buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalGreedyPolicy',
            buffer_size=num_eval_episodes),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        trainable_model = model_train_steps_per_iteration != 0
        state_only = (actor_input == 'state' and critic_input == 'state'
                      and not trainable_model
                      and initial_model_train_steps == 0)
        # Save time from unnecessarily rendering observations.
        observations_whitelist = ['state'] if state_only else None
        py_env, eval_py_env = load_environments(
            universe,
            env_name=env_name,
            domain_name=domain_name,
            task_name=task_name,
            observations_whitelist=observations_whitelist,
            action_repeat=action_repeat)
        tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True)

        original_control_timestep = get_control_timestep(eval_py_env)
        control_timestep = original_control_timestep * float(action_repeat)
        fps = int(np.round(1.0 / control_timestep))
        render_fps = int(np.round(1.0 / original_control_timestep))

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        if model_train_steps_per_iteration not in (0,
                                                   train_steps_per_iteration):
            raise NotImplementedError

        model_net = model_network_ctor(observation_spec, action_spec)

        # Match each parametrised descriptor once. Raw strings keep `\d` a
        # regex escape instead of an invalid string escape.
        preprocessor_match = re.match(r'preprocessor_(\d+)_(\d+)',
                                      compressor_descriptor)
        compressor_match = re.match(r'compressor_(\d+)',
                                    compressor_descriptor)
        softlearning_match = re.match(r'softlearning_(\d+)_(\d+)',
                                      compressor_descriptor)
        if compressor_descriptor == 'model':
            compressor_net = model_net.compressor
        elif preprocessor_match:
            filters, n_layers = preprocessor_match.groups()
            compressor_net = compressor_network.Preprocessor(
                int(filters), n_layers=int(n_layers))
        elif compressor_match:
            filters, = compressor_match.groups()
            compressor_net = compressor_network.Compressor(int(filters))
        elif softlearning_match:
            filters, n_layers = softlearning_match.groups()
            compressor_net = compressor_network.SoftlearningPreprocessor(
                int(filters), n_layers=int(n_layers))
        elif compressor_descriptor == 'd4pg':
            compressor_net = compressor_network.D4pgPreprocessor()
        else:
            raise NotImplementedError(compressor_descriptor)

        def _input_state_size(input_descriptor):
            """Return the flat input size for a '__'-separated descriptor."""
            total_size = 0
            for input_name in input_descriptor.split('__'):
                if input_name == 'state':
                    state_size, = observation_spec['state'].shape
                    total_size += state_size
                elif input_name == 'latent':
                    total_size += model_net.state_size
                elif input_name == 'feature':
                    total_size += compressor_net.feature_size
                elif input_name in ('sequence_feature',
                                    'sequence_action_feature'):
                    total_size += (
                        compressor_net.feature_size * sequence_length)
                    if input_name == 'sequence_action_feature':
                        # One action between each pair of consecutive steps.
                        total_size += tf.compat.dimension_value(
                            action_spec.shape[0]) * (sequence_length - 1)
                else:
                    raise NotImplementedError
            return total_size

        actor_state_size = _input_state_size(actor_input)
        actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ),
                                                  dtype=tf.float32)

        critic_state_size = _input_state_size(critic_input)
        critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ),
                                                   dtype=tf.float32)

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            actor_input_spec, action_spec, fc_layer_params=actor_fc_layers)
        critic_net = critic_network.CriticNetwork(
            (critic_input_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = slac_agent.SlacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            model_network=model_net,
            compressor_network=compressor_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            model_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=model_learning_rate),
            sequence_length=sequence_length,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            trainable_model=trainable_model,
            critic_input=critic_input,
            actor_input=actor_input,
            model_batch_size=model_batch_size,
            control_timestep=control_timestep,
            num_images_per_summary=num_images_per_summary,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)
        eval_greedy_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        # The first two metrics double as step metrics for the summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(
                py_metrics.AverageReturnMetric(buffer_size=1)),
            tf_py_metric.TFPyMetric(
                py_metrics.AverageEpisodeLengthMetric(buffer_size=1)),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        initial_policy_state = initial_collect_policy.get_initial_state(1)
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run(
                policy_state=initial_policy_state)

        policy_state = collect_policy.get_initial_state(1)
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run(
                policy_state=policy_state)

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[-2]

        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=sequence_length + 1)
        dataset = dataset.unbatch().filter(_filter_invalid_transition)
        dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(3)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()

        train_op = tf_agent.train(trajectories)
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        if initial_model_train_steps:
            with tf.name_scope('initial'):
                model_train_op = tf_agent.train_model(trajectories)
                # Every v2 summary op that is not a train-metric summary is
                # run together with the model train op.
                model_summary_ops = []
                for summary_op in tf.compat.v1.summary.all_v2_summary_ops():
                    if summary_op not in summary_ops:
                        model_summary_ops.append(summary_op)

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics + eval_greedy_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=train_metrics[:2])
            if eval_interval:
                eval_images_ph = tf.compat.v1.placeholder(dtype=tf.uint8,
                                                          shape=[None] * 5)
                eval_images_summary = gif_utils.gif_summary_v2(
                    'ObservationVideoEvalPolicy', eval_images_ph, 1, fps)
                eval_render_images_summary = gif_utils.gif_summary_v2(
                    'VideoEvalPolicy', eval_images_ph, 1, render_fps)
                eval_greedy_images_summary = gif_utils.gif_summary_v2(
                    'ObservationVideoEvalGreedyPolicy', eval_images_ph, 1,
                    fps)
                eval_greedy_render_images_summary = gif_utils.gif_summary_v2(
                    'VideoEvalGreedyPolicy', eval_images_ph, 1, render_fps)

        train_config_saver = gin.tf.GinConfigSaverHook(
            train_dir, summarize_config=False)
        eval_config_saver = gin.tf.GinConfigSaverHook(
            eval_dir, summarize_config=False)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step,
            max_to_keep=2)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:

            def _evaluate_policies():
                """Evaluate the sampled and greedy policies, then flush."""
                # Only called when `eval_interval` is truthy, which is also
                # the condition under which the image summaries exist.
                eval_configs = (
                    (eval_metrics, eval_py_policy,
                     eval_render_images_summary, eval_images_summary),
                    (eval_greedy_metrics, eval_greedy_py_policy,
                     eval_greedy_render_images_summary,
                     eval_greedy_images_summary),
                )
                for (metrics, py_policy, render_images_summary,
                     images_summary) in eval_configs:
                    compute_summaries(
                        metrics,
                        eval_py_env,
                        py_policy,
                        num_episodes=num_eval_episodes,
                        num_episodes_to_render=num_images_per_summary,
                        images_ph=eval_images_ph,
                        render_images_summary=render_images_summary,
                        images_summary=images_summary)
                sess.run(eval_summary_flush_op)

            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            train_config_saver.after_create_session(sess)
            eval_config_saver.after_create_session(sess)

            global_step_val = sess.run(global_step)

            if global_step_val == 0:
                if eval_interval:
                    # Initial eval of randomly initialized policy
                    _evaluate_policies()

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            policy_state_val = sess.run(policy_state)
            collect_call = sess.make_callable(collect_op,
                                              feed_list=[policy_state])
            train_step_call = sess.make_callable([train_op, summary_ops])
            if initial_model_train_steps:
                model_train_step_call = sess.make_callable(
                    [model_train_op, model_summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # steps_per_second summary should always be recorded since it's
            # only called every log_interval steps
            with tf.compat.v2.summary.record_if(True):
                steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_second_ph,
                    step=global_step)

            for iteration in range(global_step_val,
                                   initial_model_train_steps + num_iterations):
                start_time = time.time()
                if iteration < initial_model_train_steps:
                    # Model-only pre-training phase.
                    total_loss_val, _ = model_train_step_call()
                else:
                    _, policy_state_val = collect_call(policy_state_val)
                    for _ in range(train_steps_per_iteration):
                        total_loss_val, _ = train_step_call()

                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if log_interval and global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss_val.loss)
                    steps_per_sec = (
                        (global_step_val - timed_at_step) / time_acc)
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if (train_checkpoint_interval
                        and global_step_val % train_checkpoint_interval == 0):
                    train_checkpointer.save(global_step=global_step_val)

                # No evaluation or policy / replay buffer checkpoints while
                # only the model is being trained.
                if iteration < initial_model_train_steps:
                    continue

                if eval_interval and global_step_val % eval_interval == 0:
                    _evaluate_policies()

                if (policy_checkpoint_interval
                        and global_step_val % policy_checkpoint_interval == 0):
                    policy_checkpointer.save(global_step=global_step_val)

                if (rb_checkpoint_interval
                        and global_step_val % rb_checkpoint_interval == 0):
                    rb_checkpointer.save(global_step=global_step_val)
def __init__(self):
    """Build the environments, DDQN agent, replay pipeline and checkpointers.

    Everything is configured from class-level settings on
    `T48GymTensorflowContext` (episode length, learning rate, replay buffer
    size, train metrics and train directory). The train, policy and replay
    buffer checkpointers are created and then initialised or restored.
    """
    ctx = T48GymTensorflowContext

    # Separate Python environments for training and evaluation.
    train_py_env = suite_gym.load(
        T48GymEnv.GYM_ENV_NAME, max_episode_steps=ctx.max_episode_steps)
    self._train_py_env = train_py_env
    eval_py_env = suite_gym.load(
        T48GymEnv.GYM_ENV_NAME, max_episode_steps=ctx.max_episode_steps)
    self._eval_py_env = eval_py_env

    train_env = tf_py_environment.TFPyEnvironment(train_py_env)
    self._train_env = train_env
    self._eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    self._global_step = global_step

    q_net = q_network.QNetwork(
        train_env.observation_spec(),
        train_env.action_spec(),
        fc_layer_params=(100,))
    self._q_net = q_net

    agent = dqn_agent.DdqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=ctx.learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_step,
        epsilon_greedy=0.0)
    self._agent = agent
    agent.initialize()
    agent.train = common.function(agent.train)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=ctx.replay_buffer_max_length)
    self._replay_buffer = replay_buffer

    sampled = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=train_env.batch_size,
        num_steps=2)
    self._dataset = sampled.prefetch(3)

    # NOTE(review): the agent was already initialised above; this second
    # call is kept to preserve behaviour -- confirm whether it is needed.
    agent.initialize()
    self._iterator = iter(self._dataset)

    self._RANDOM_POLICY = random_tf_policy.RandomTFPolicy(
        train_env.time_step_spec(), train_env.action_spec())
    self._collect_policy = agent.collect_policy
    self._eval_policy = agent.policy

    observers = [replay_buffer.add_batch] + ctx.train_metrics
    self._collect_driver = dynamic_step_driver.DynamicStepDriver(
        train_env,
        self._collect_policy,
        observers=observers,
        num_steps=2)

    self._train_checkpointer = common.Checkpointer(
        ckpt_dir=ctx.train_dir,
        global_step=global_step,
        agent=agent,
        metrics=metric_utils.MetricsGroup(ctx.train_metrics,
                                          'train_metrics'))
    self._policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(ctx.train_dir, 'policy'),
        global_step=global_step,
        policy=self._eval_policy)
    self._rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(ctx.train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    self._tf_policy_saver = policy_saver.PolicySaver(agent.policy)

    # Restore previous state (or initialise fresh) in the original order.
    for checkpointer in (self._train_checkpointer,
                         self._policy_checkpointer,
                         self._rb_checkpointer):
        checkpointer.initialize_or_restore()
def train(self, training_iterations=TRAINING_ITERATIONS, training_stock_list=None):
    """Train the agent, evaluate it periodically and store a training report.

    The method resets the agent, fills a fresh replay buffer with
    ``INIT_COLLECT_STEPS`` random-policy steps, then alternates between
    running the collect driver and ``STEP_ITERATIONS`` training steps.
    Whenever the global step is a multiple of ``LOG_INTERVAL`` it prints the
    loss and throughput; whenever it is a multiple of ``EVAL_INTERVAL`` it
    evaluates the policy and stops early if ``avg_daily_percentage`` equals
    the last recorded entry in ``returns``. Afterwards it saves the three
    checkpoints, updates the training report via ``util`` and calls
    ``self.save()`` and ``self.reset()``.

    Args:
        training_iterations: Maximum number of collect/train iterations.
        training_stock_list: Forwarded unchanged to ``self.reset``.
    """
    self.reset(training_stock_list)
    train_dir = 'training_data_progress/train-' + self.name

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=self.tf_agent.collect_data_spec,
        batch_size=self.tf_training_env.batch_size,
        max_length=MAX_BUFFER_SIZE)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=NUM_EVAL_EPISODES),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=NUM_EVAL_EPISODES),
    ]
    global_step = self.tf_agent.train_step_counter

    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % LOG_INTERVAL, 0)):
        replay_observer = [replay_buffer.add_batch]
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(
                buffer_size=NUM_EVAL_EPISODES,
                batch_size=self.tf_training_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=NUM_EVAL_EPISODES,
                batch_size=self.tf_training_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(self.tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            self.tf_training_env.time_step_spec(),
            self.tf_training_env.action_spec())
        collect_policy = self.tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=self.tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver_random = dynamic_step_driver.DynamicStepDriver(
            self.tf_training_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=INIT_COLLECT_STEPS)
        initial_collect_driver_random.run = common.function(
            initial_collect_driver_random.run)
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            self.tf_training_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=STEP_ITERATIONS)
        collect_driver.run = common.function(collect_driver.run)
        self.tf_agent.train = common.function(self.tf_agent.train)

        # Baseline: score a random policy before any training.
        random_policy = random_tf_policy.RandomTFPolicy(
            self.tf_training_env.time_step_spec(),
            self.tf_training_env.action_spec())
        avg_return, avg_return_per_step, avg_daily_percentage = (
            self.compute_avg_return(random_policy))
        print(
            'Random:\n\tAverage Return = {0}\n\tAverage Return Per Step = {1}\n\tPercent = {2}%'
            .format(avg_return, avg_return_per_step, avg_daily_percentage))
        self.gym_training_env.save_feature_distribution(self.name)

        # Baseline: score the agent's policy before this training run.
        avg_return, avg_return_per_step, avg_daily_percentage = (
            self.compute_avg_return(self.tf_agent.policy))
        print(
            'Agent :\n\tAverage Return = {0}\n\tAverage Return Per Step = {1}\n\tPercent = {2}%'
            .format(avg_return, avg_return_per_step, avg_daily_percentage))
        self.eval_env.reset()
        self.eval_env.run_and_save_evaluation(str(0))
        self.gym_training_env.save_feature_distribution(self.name)

        evaluations = [self.get_evaluation()]
        returns = [self.eval_env.returns]
        actions_over_time_list = [self.eval_env.action_sets_over_time]

        # Collect initial replay data.
        print(
            'Initializing replay buffer by collecting experience for {} steps with '
            'a random policy.'.format(INIT_COLLECT_STEPS))
        initial_collect_driver_random.run()

        metric_utils.eager_compute(
            eval_metrics,
            self.tf_training_env,
            eval_policy,
            num_episodes=NUM_EVAL_EPISODES,
            train_step=global_step,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(
            self.tf_training_env.batch_size)
        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=BATCH_SIZE,
            num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(BATCH_SIZE).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def _train_step():
            # Deliberately best-effort: a failing step is reported and
            # replaced by a tiny placeholder loss instead of aborting.
            try:
                experience, _ = next(iterator)
                return self.tf_agent.train(experience)
            except Exception as e:
                print("Caught Exception:", e)
                return 1e-20

        train_step = common.function(_train_step)

        for _ in range(training_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(STEP_ITERATIONS):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # The fallback path of _train_step returns a bare value without
            # a ``.loss`` attribute, so read the loss defensively.
            loss_value = getattr(train_loss, 'loss', train_loss)

            self.global_step_val = global_step.numpy()
            if self.global_step_val % LOG_INTERVAL == 0:
                steps_per_sec = (
                    self.global_step_val - timed_at_step) / time_acc
                print(
                    self.name,
                    '\nstep = {0:d}:\n\tloss = {1:f}\n\t{2:.3f} steps/sec'.
                    format(self.global_step_val, loss_value, steps_per_sec))
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = self.global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if self.global_step_val % EVAL_INTERVAL == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    self.tf_training_env,
                    eval_policy,
                    num_episodes=NUM_EVAL_EPISODES,
                    train_step=global_step,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)
                avg_return, avg_return_per_step, avg_daily_percentage = (
                    self.compute_avg_return(self.tf_agent.policy))
                print(
                    self.name,
                    '\nstep = {0}:\n\tloss = {1}\n\tAverage Return = {2}\n\tAverage Return Per Step = {3}\n\tPercent = {4}%'
                    .format(self.global_step_val, loss_value, avg_return,
                            avg_return_per_step, avg_daily_percentage))
                self.eval_env.reset()
                self.eval_env.run_and_save_evaluation(
                    str(self.global_step_val // EVAL_INTERVAL))
                self.gym_training_env.save_feature_distribution(self.name)
                if avg_daily_percentage == returns[-1]:
                    # This message used to be a bare string expression and
                    # was therefore never shown.
                    print("---- Average return did not change since last time. Breaking loop.")
                    break
                evaluations.append(self.get_evaluation())
                returns.append(self.eval_env.returns)
                actions_over_time_list.append(
                    self.eval_env.action_sets_over_time)

    train_checkpointer.save(global_step=self.global_step_val)
    policy_checkpointer.save(global_step=self.global_step_val)
    rb_checkpointer.save(global_step=self.global_step_val)

    # Merge this run's results into the persisted training report.
    training_report = util.load_training_report()
    agent_report = training_report.get(self.name, dict())
    agent_report["Training Results"] = returns
    agent_report["Evaluations"] = [max(e, 0.0) for e in evaluations]
    bins = [0.1 * i - 0.0000001 for i in range(11)]
    agent_report["Histograms"] = [
        str(list(map(int, np.histogram(actions, bins, density=True)[0])))
        for actions in actions_over_time_list
    ]
    training_report[self.name] = agent_report
    util.save_training_report(training_report)

    print("---- Average-daily-percentage over training period for", self.name)
    print("\t\t", avg_daily_percentage)
    self.save()
    self.reset()
def train_eval(
        root_dir,
        experiment_name,  # experiment name
        env_name='carla-v0',
        num_iterations=int(1e7),
        model_network_ctor_type='non-hierarchical',  # model net
        input_names=['camera', 'lidar'],  # names for inputs
        reconstruct_names=['roadmap'],  # names for masks
        pixor_names=['vh_clas', 'vh_regr', 'pixor_state'],  # names for pixor outputs
        reconstruct_pixor_state=True,  # whether to reconstruct pixor_state
        extra_names=['state'],  # extra inputs
        obs_size=64,  # size of observation image
        pixor_size=64,  # size of pixor output image
        perception_weight=1.0,  # weight of perception part loss
        # Params for collect
        initial_collect_steps=1000,
        replay_buffer_capacity=int(5e4+1),
        # Params for train
        training=True,  # whether to train, or just evaluate
        model_batch_size=32,  # model training batch size
        sequence_length=10,  # number of timesteps to train model
        model_learning_rate=1e-4,  # learning rate for model training
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=2000,
        # Params for summaries and logging
        num_images_per_summary=1,  # images for each summary
        train_checkpoint_interval=2000,
        log_interval=200,
        summary_interval=2000,
        summaries_flush_secs=10,
        summarize_grads_and_vars=False,
        gpu_allow_growth=True,  # GPU memory growth
        gpu_memory_limit=None,  # GPU memory limit
        action_repeat=1):
    """A simple train and eval for SLAC."""
    # Configure every visible GPU.
    gpu_devices = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for device in gpu_devices:
            tf.config.experimental.set_memory_growth(device, True)
    if gpu_memory_limit:
        for device in gpu_devices:
            limit_cfg = tf.config.experimental.VirtualDeviceConfiguration(
                memory_limit=gpu_memory_limit)
            tf.config.experimental.set_virtual_device_configuration(
                device, [limit_cfg])

    # All outputs live under <root_dir>/<env_name>/<experiment_name>.
    root_dir = os.path.join(
        os.path.expanduser(root_dir), env_name, experiment_name)

    summary_writer = tf.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(
            name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    def _should_record():
        return tf.math.equal(global_step % summary_interval, 0)

    with tf.summary.record_if(_should_record):
        # Carla environments exposing every channel any component needs.
        channels = list(
            set(input_names + reconstruct_names + pixor_names + extra_names))
        train_py_env, eval_py_env = load_carla_env(
            env_name='carla-v0',
            lidar_bin=32/obs_size,
            pixor_size=pixor_size,
            obs_channels=channels,
            action_repeat=action_repeat)
        tf_env = tf_py_environment.TFPyEnvironment(train_py_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
        fps = int(np.round(1.0 / (train_py_env.dt * action_repeat)))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # Only the hierarchical model network is available here.
        if model_network_ctor_type != 'hierarchical':
            raise NotImplementedError
        model_net = sequential_latent_pixor_network.PixorSLMHierarchical(
            input_names,
            reconstruct_names,
            obs_size=obs_size,
            pixor_size=pixor_size,
            reconstruct_pixor_state=reconstruct_pixor_state,
            perception_weight=perception_weight)

        # Perception agent driven by a state-based heuristic actor.
        actor_network = (
            state_based_heuristic_actor_network.StateBasedHeuristicActorNetwork(
                observation_spec['state'],
                action_spec,
                desired_speed=9))
        model_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=model_learning_rate)
        tf_agent = perception_agent.PerceptionAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_network,
            model_network=model_net,
            model_optimizer=model_optimizer,
            num_images_per_summary=num_images_per_summary,
            sequence_length=sequence_length,
            gradient_clipping=gradient_clipping,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
            fps=fps)
        tf_agent.initialize()

        env_steps = tf_metrics.EnvironmentSteps()
        average_return = tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        eval_policy = tf_agent.policy
        initial_collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        train_checkpointer.initialize_or_restore()
        model_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'model'),
            model=model_net,
            max_to_keep=2)

        def _evaluate():
            # Shared by the up-front evaluation and the periodic one.
            compute_summaries(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                train_step=global_step,
                summary_writer=summary_writer,
                num_episodes=num_eval_episodes,
                num_episodes_to_render=num_images_per_summary,
                model_net=model_net,
                fps=10,
                image_keys=['camera', 'lidar', 'roadmap'],
                pixor_size=pixor_size)

        _evaluate()

        if training:
            replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                data_spec=tf_agent.collect_data_spec,
                batch_size=1,  # No parallel environments
                max_length=replay_buffer_capacity)
            replay_observer = [replay_buffer.add_batch]

            rb_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
                max_to_keep=1,
                replay_buffer=replay_buffer)
            rb_checkpointer.initialize_or_restore()

            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + train_metrics,
                num_steps=initial_collect_steps)
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)

            # Seed the buffer only on a completely fresh run.
            if (global_step.numpy() == 0 and replay_buffer.num_frames() == 0):
                logging.info(
                    'Collecting experience for %d steps '
                    'with a model-based policy.', initial_collect_steps)
                initial_collect_driver.run()
                rb_checkpointer.save(global_step=global_step.numpy())

            # Dataset generates trajectories with shape [Bxslx...]
            batches = iter(
                replay_buffer.as_dataset(
                    num_parallel_calls=3,
                    sample_batch_size=model_batch_size,
                    num_steps=sequence_length + 1).prefetch(3))

            def train_step():
                experience, _ = next(batches)
                return tf_agent.train(experience)

            train_step = common.function(train_step)

            for _ in range(num_iterations):
                loss_info = train_step()
                step_now = global_step.numpy()

                if step_now % log_interval == 0:
                    logging.info('global steps = %d, model loss = %f',
                                 step_now, loss_info.loss)

                for metric in train_metrics:
                    metric.tf_summaries(train_step=step_now)

                if step_now % eval_interval == 0:
                    _evaluate()

                if step_now % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=step_now)
                    model_checkpointer.save(global_step=step_now)
def train_eval(
        root_dir,
        env_load_fn=get_env,
        random_seed=None,
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=10,
        num_parallel_environments=10,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=10,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        policy_save_interval=10000,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Collect episodes in parallel environments, train on them and evaluate.

    Each loop iteration collects ``collect_episodes_per_iteration`` episodes
    into the replay buffer, trains the agent on everything gathered, and
    clears the buffer. The loop ends once the environment-steps metric
    reaches ``num_environment_steps``; a final evaluation follows.

    Args:
        root_dir: Base directory; ``train``, ``eval`` and
            ``policy_saved_model`` sub-directories are created beneath it.
        env_load_fn: Zero-argument callable returning a Python environment.
        random_seed: Optional seed passed to TensorFlow.
        num_environment_steps: Environment-step budget for the whole run.
        collect_episodes_per_iteration: Episodes collected per iteration.
        num_parallel_environments: Number of parallel collect environments.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        num_epochs, learning_rate: Forwarded to ``get_agent``.
        num_eval_episodes: Episodes per evaluation.
        eval_interval, train_checkpoint_interval, policy_checkpoint_interval,
            policy_save_interval, log_interval, summary_interval: Periods,
            in train steps, of the corresponding actions.
        summaries_flush_secs: Flush period of the summary writers.
        use_tf_functions: Wrap collect/train callables in ``common.function``.
        debug_summaries, summarize_grads_and_vars: Accepted but not used in
            this function.

    Raises:
        ValueError: If ``get_metrics`` supplies no ``EnvironmentSteps``
            metric, which the termination condition needs.
    """
    if random_seed is not None:
        # ``tf.set_random_seed`` is not part of the TF2 namespace that the
        # rest of this function relies on; the compat alias works in both.
        tf.compat.v1.set_random_seed(random_seed)

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    logging.info('Running %d episodes in parallel', num_parallel_environments)
    logging.info('Collecting %d episodes per step',
                 collect_episodes_per_iteration)
    logging.info('Using replay buffer capacity of %d', replay_buffer_capacity)

    train_summary_writer = tf.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn())
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn()] * num_parallel_environments))

    actor_net, value_net = get_actor_and_value_network(
        tf_env.action_spec(), tf_env.observation_spec())

    train_steps = tf.Variable(0)
    # The loop below refers to the step counter as ``global_step``; bind that
    # name to the counter the agent actually increments.
    global_step = train_steps

    with tf.summary.record_if(
            lambda: tf.math.equal(train_steps % summary_interval, 0)):
        tf_agent = get_agent(time_step_spec=tf_env.time_step_spec(),
                             action_spec=tf_env.action_spec(),
                             actor_net=actor_net,
                             value_net=value_net,
                             num_epochs=num_epochs,
                             step_counter=train_steps,
                             learning_rate=learning_rate)
        tf_agent.initialize()

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        step_metrics, train_metrics, eval_metrics = get_metrics(
            n_parallel_env=num_parallel_environments,
            num_eval_episodes=num_eval_episodes)

        # The termination test needs the environment-steps metric, which
        # get_metrics() does not return by name.
        # NOTE(review): assumes it is one of the returned metric objects.
        environment_steps_metric = next(
            (metric
             for metric in list(step_metrics) + list(train_metrics)
             if isinstance(metric, tf_metrics.EnvironmentSteps)),
            None)
        if environment_steps_metric is None:
            raise ValueError(
                'get_metrics() must provide a tf_metrics.EnvironmentSteps '
                'metric to bound training by num_environment_steps.')

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=train_steps,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=train_steps)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=train_steps)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            # Collected data is used for one training call, then discarded.
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            # NOTE(review): the checkpoint checks are nested under the
            # log-interval test, matching the upstream TF-Agents PPO example;
            # confirm this nesting is what the original file intended.
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_save_interval == 0:
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
def train_eval(
        root_dir,
        experiment_name,  # experiment name
        env_name='carla-v0',
        agent_name='sac',  # agent's name
        num_iterations=int(1e7),
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        model_network_ctor_type='non-hierarchical',  # model net
        input_names=['camera', 'lidar'],  # names for inputs
        mask_names=['birdeye'],  # names for masks
        preprocessing_combiner=tf.keras.layers.Add(),  # takes a flat list of tensors and combines them
        actor_lstm_size=(40, ),  # lstm size for actor
        critic_lstm_size=(40, ),  # lstm size for critic
        actor_output_fc_layers=(100, ),  # lstm output
        critic_output_fc_layers=(100, ),  # lstm output
        epsilon_greedy=0.1,  # exploration parameter for DQN
        q_learning_rate=1e-3,  # q learning rate for DQN
        ou_stddev=0.2,  # exploration parameter for DDPG
        ou_damping=0.15,  # exploration parameter for DDPG
        dqda_clipping=None,  # for DDPG
        exploration_noise_std=0.1,  # exploration parameter for td3
        actor_update_period=2,  # for td3
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=int(1e5),
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        initial_model_train_steps=100000,  # initial model training
        batch_size=256,
        model_batch_size=32,  # model training batch size
        sequence_length=4,  # number of timesteps to train model
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        model_learning_rate=1e-4,  # learning rate for model training
        td_errors_loss_fn=tf.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for summaries and logging
        num_images_per_summary=1,  # images for each summary
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        gpu_allow_growth=True,  # GPU memory growth
        gpu_memory_limit=None,  # GPU memory limit
        action_repeat=1):
    """Train and evaluate one of several agents on the Carla environment.

    ``agent_name`` selects the agent: ``'latent_sac'``, ``'dqn'``,
    ``'ddpg'``, ``'td3'`` or ``'sac'``. The function configures the GPUs,
    builds the environments and the agent, restores the train and
    replay-buffer checkpoints, optionally collects initial data with a
    random policy, runs one initial evaluation, and then loops for
    ``num_iterations`` iterations of collecting and training with periodic
    logging, evaluation and checkpointing.

    For ``'latent_sac'`` the first ``initial_model_train_steps`` iterations
    call ``tf_agent.train_model`` instead of collecting and training.

    Args:
        root_dir: Base directory; outputs go to
            ``<root_dir>/<env_name>/<experiment_name>``.
        experiment_name: Sub-directory name for this run.
        env_name: Used for the output path only; the environment itself is
            always loaded as ``'carla-v0'``.
        agent_name: Which agent to build (see above).
        num_iterations: Number of iterations of the main loop.
        input_names: Observation channels fed to the agent.
        mask_names: Extra channels, used only by ``'latent_sac'``.
        model_network_ctor_type: ``'hierarchical'`` or ``'non-hierarchical'``
            latent model (``'latent_sac'`` only).
        preprocessing_combiner: Combiner for the per-input preprocessing
            layers; replaced by ``None`` when fewer than two inputs are used.
        td_errors_loss_fn: Used by ``'latent_sac'`` only; the other agents
            use fixed loss functions.
        The remaining parameters are network sizes, exploration settings,
        learning rates, and collect/train/eval/logging/checkpoint intervals,
        as annotated in the signature; they are forwarded to the
        corresponding network, agent, driver or checkpointer.

    Raises:
        NotImplementedError: For an unknown ``agent_name`` or
            ``model_network_ctor_type``.
    """
    # Setup GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    # Get train and eval directories
    root_dir = os.path.expanduser(root_dir)
    root_dir = os.path.join(root_dir, env_name, experiment_name)

    # Get summary writers
    summary_writer = tf.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    # Eval metrics
    eval_metrics = [
        tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Whether to record for summary
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create Carla environment
        if agent_name == 'latent_sac':
            py_env, eval_py_env = load_carla_env(
                env_name='carla-v0',
                obs_channels=input_names + mask_names,
                action_repeat=action_repeat)
        elif agent_name == 'dqn':
            py_env, eval_py_env = load_carla_env(
                env_name='carla-v0',
                discrete=True,
                obs_channels=input_names,
                action_repeat=action_repeat)
        else:
            py_env, eval_py_env = load_carla_env(
                env_name='carla-v0',
                obs_channels=input_names,
                action_repeat=action_repeat)

        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
        fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

        # Specs
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # Make tf agent
        if agent_name == 'latent_sac':
            # Get model network for latent sac
            if model_network_ctor_type == 'hierarchical':
                model_network_ctor = (
                    sequential_latent_network.SequentialLatentModelHierarchical)
            elif model_network_ctor_type == 'non-hierarchical':
                model_network_ctor = (
                    sequential_latent_network.SequentialLatentModelNonHierarchical)
            else:
                raise NotImplementedError
            model_net = model_network_ctor(input_names,
                                           input_names + mask_names)

            # Get the latent spec
            latent_size = model_net.latent_size
            latent_observation_spec = tensor_spec.TensorSpec(
                (latent_size, ), dtype=tf.float32)
            latent_time_step_spec = ts.time_step_spec(
                observation_spec=latent_observation_spec)

            # Get actor and critic net
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                latent_observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=normal_projection_net)
            critic_net = critic_network.CriticNetwork(
                (latent_observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers)

            # Build the inner SAC agent based on latent space
            inner_agent = sac_agent.SacAgent(
                latent_time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            inner_agent.initialize()

            # Build the latent sac agent
            tf_agent = latent_sac_agent.LatentSACAgent(
                time_step_spec,
                action_spec,
                inner_agent=inner_agent,
                model_network=model_net,
                model_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=model_learning_rate),
                model_batch_size=model_batch_size,
                num_images_per_summary=num_images_per_summary,
                sequence_length=sequence_length,
                gradient_clipping=gradient_clipping,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                fps=fps)
        else:
            # Set up preprocessing layers for dictionary observation inputs
            preprocessing_layers = collections.OrderedDict()
            for name in input_names:
                preprocessing_layers[name] = Preprocessing_Layer(32, 256)
            # With a single input there is nothing to combine.
            if len(input_names) < 2:
                preprocessing_combiner = None

            if agent_name == 'dqn':
                q_rnn_net = q_rnn_network.QRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)
                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_rnn_net,
                    epsilon_greedy=epsilon_greedy,
                    n_step_update=1,
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=q_learning_rate),
                    td_errors_loss_fn=common.element_wise_squared_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            elif agent_name == 'ddpg' or agent_name == 'td3':
                actor_rnn_net = (
                    multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
                        observation_spec,
                        action_spec,
                        preprocessing_layers=preprocessing_layers,
                        preprocessing_combiner=preprocessing_combiner,
                        input_fc_layer_params=actor_fc_layers,
                        lstm_size=actor_lstm_size,
                        output_fc_layer_params=actor_output_fc_layers))
                critic_rnn_net = (
                    multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                        (observation_spec, action_spec),
                        preprocessing_layers=preprocessing_layers,
                        preprocessing_combiner=preprocessing_combiner,
                        action_fc_layer_params=critic_action_fc_layers,
                        joint_fc_layer_params=critic_joint_fc_layers,
                        lstm_size=critic_lstm_size,
                        output_fc_layer_params=critic_output_fc_layers))

                if agent_name == 'ddpg':
                    tf_agent = ddpg_agent.DdpgAgent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        ou_stddev=ou_stddev,
                        ou_damping=ou_damping,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        # Passed explicitly so every agent visibly shares
                        # the counter that drives logging/eval/checkpoints.
                        train_step_counter=global_step)
                elif agent_name == 'td3':
                    tf_agent = td3_agent.Td3Agent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        exploration_noise_std=exploration_noise_std,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        actor_update_period=actor_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        train_step_counter=global_step)

            elif agent_name == 'sac':
                actor_distribution_rnn_net = (
                    actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                        observation_spec,
                        action_spec,
                        preprocessing_layers=preprocessing_layers,
                        preprocessing_combiner=preprocessing_combiner,
                        input_fc_layer_params=actor_fc_layers,
                        lstm_size=actor_lstm_size,
                        output_fc_layer_params=actor_output_fc_layers,
                        continuous_projection_net=normal_projection_net))
                critic_rnn_net = (
                    multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                        (observation_spec, action_spec),
                        preprocessing_layers=preprocessing_layers,
                        preprocessing_combiner=preprocessing_combiner,
                        action_fc_layer_params=critic_action_fc_layers,
                        joint_fc_layer_params=critic_joint_fc_layers,
                        lstm_size=critic_lstm_size,
                        output_fc_layer_params=critic_output_fc_layers))
                tf_agent = sac_agent.SacAgent(
                    time_step_spec,
                    action_spec,
                    actor_network=actor_distribution_rnn_net,
                    critic_network=critic_rnn_net,
                    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=actor_learning_rate),
                    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=critic_learning_rate),
                    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=alpha_learning_rate),
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    # make critic loss dimension compatible
                    td_errors_loss_fn=tf.math.squared_difference,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            else:
                raise NotImplementedError

        tf_agent.initialize()

        # Get replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,  # No parallel environments
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Train metrics
        env_steps = tf_metrics.EnvironmentSteps()
        average_return = tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        # Get policies
        eval_policy = tf_agent.policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        collect_policy = tf_agent.collect_policy

        # Checkpointers
        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step,
            max_to_keep=2)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Collect driver
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        # Optimize the performance by using tf functions
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if (env_steps.result() == 0 or replay_buffer.num_frames() == 0):
            logging.info(
                'Initializing replay buffer by collecting experience for %d '
                'steps with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        def _evaluate(num_episodes, num_episodes_to_render):
            # latent_sac also renders images through compute_summaries; the
            # other agents use the plain metric evaluation.
            if agent_name == 'latent_sac':
                compute_summaries(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    train_step=global_step,
                    summary_writer=summary_writer,
                    num_episodes=num_episodes,
                    num_episodes_to_render=num_episodes_to_render,
                    model_net=model_net,
                    fps=10,
                    image_keys=input_names + mask_names)
            else:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_episodes,
                    train_step=env_steps.result(),
                    summary_writer=summary_writer,
                    summary_prefix='Eval',
                )
                metric_utils.log_metrics(eval_metrics)

        # Single-episode evaluation before training starts.
        _evaluate(num_episodes=1, num_episodes_to_render=1)

        # Dataset generates trajectories with shape [Bxslx...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        # Get train step
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        if agent_name == 'latent_sac':

            def train_model_step():
                experience, _ = next(iterator)
                return tf_agent.train_model(experience)

            train_model_step = common.function(train_model_step)

        # Training initializations
        time_step = None
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Start training
        for iteration in range(num_iterations):
            start_time = time.time()

            if (agent_name == 'latent_sac'
                    and iteration < initial_model_train_steps):
                train_model_step()
            else:
                # Run collect
                time_step, _ = collect_driver.run(time_step=time_step)

                # Train an iteration
                for _ in range(train_steps_per_iteration):
                    train_step()

            time_acc += time.time() - start_time

            # Log training information
            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.summary.scalar(name='env_steps_per_sec',
                                  data=env_steps_per_sec,
                                  step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            # Get training metrics
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            # Evaluation
            if global_step.numpy() % eval_interval == 0:
                _evaluate(num_episodes=num_eval_episodes,
                          num_episodes_to_render=num_images_per_summary)

            # Save checkpoints
            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
def __init__(
        self,
        env,
        global_step,
        root_dir,
        step_metrics,
        name='Agent',
        is_environment=False,
        use_tf_functions=True,
        max_steps=250,
        replace_reward=True,
        non_negative_regret=False,
        id_num=0,
        block_budget_weight=0.,
        # Architecture hparams
        use_rnn=True,
        learning_rate=1e-4,
        actor_fc_layers=(32, 32),
        value_fc_layers=(32, 32),
        lstm_size=(128, ),
        conv_filters=8,
        conv_kernel=3,
        scalar_fc=5,
        entropy_regularization=0.,
        xy_dim=None,
        # Training & logging settings
        num_epochs=25,
        num_eval_episodes=5,
        num_parallel_envs=5,
        replay_buffer_capacity=1001,
        debug_summaries=True,
        summarize_grads_and_vars=True,
):
    """Builds the PPO agent together with its buffer, metrics and checkpoints.

    Args:
      env: An AdversarialTfPyEnvironment exposing both the regular specs and
        the adversary specs.
      global_step: A tf variable tracking the global step; used as the PPO
        train step counter and stored in the checkpoints.
      root_dir: Directory under which metrics, checkpoints and saved models
        are written.
      step_metrics: A list of tf-agents metrics that serve as the x-axis
        during training (e.g. number of episodes or environment steps).
      name: The name of this agent, e.g. 'Adversary'. It is used as the TF
        name scope, as a metric-name prefix and as a path component.
      is_environment: If True, the adversary specs of `env` are used and the
        networks get the additional adversary inputs.
      use_tf_functions: If True, the agent's train function is wrapped with
        `common.function`.
      max_steps: The maximum number of steps the agent may interact with the
        environment in every data collection loop.
      replace_reward: If False, the reward stored in the agent's trajectories
        is left untouched, i.e. the agent trains on the default environment
        reward rather than on regret.
      non_negative_regret: If True, the regret reward is kept at or above 0.
      id_num: The ID number of this agent within the population of agents of
        the same type (e.g. "adversary agent 3").
      block_budget_weight: Weight placed on the adversary's block budget
        reward. The default of 0 disables the block budget.
      use_rnn: If True, the network architecture contains an RNN.
      learning_rate: Learning rate of this agent's Adam optimizer.
      actor_fc_layers: Number and size of the fully connected layers in the
        policy network.
      value_fc_layers: Number and size of the fully connected layers in the
        critic / value network.
      lstm_size: The number of LSTM cells in the RNN.
      conv_filters: The number of convolution filters.
      conv_kernel: The width of the convolution kernel.
      scalar_fc: The width of the fully connected layer that takes a scalar
        input.
      entropy_regularization: Entropy regularization coefficient.
      xy_dim: Maximum value of x or y, needed by adversaries that receive the
        current (x, y) position as a one-hot vector.
      num_epochs: Number of epochs for computing PPO policy updates.
      num_eval_episodes: Number of evaluation episodes per eval step; used as
        the batch size of the eval metrics.
      num_parallel_envs: Number of parallel environments used in training;
        used as the batch size of the training metrics, the replay buffer and
        the reward tensors.
      replay_buffer_capacity: Capacity of this agent's replay buffer.
      debug_summaries: Log additional summaries from the PPO agent.
      summarize_grads_and_vars: If True, the PPO agent logs gradient and
        variable summaries.
    """
    # Plain bookkeeping copied straight from the constructor arguments.
    self.name = name
    self.id = id_num
    self.max_steps = max_steps
    self.is_environment = is_environment
    self.replace_reward = replace_reward
    self.non_negative_regret = non_negative_regret
    self.block_budget_weight = block_budget_weight

    # Path components shared by every per-agent output directory.
    agent_subpath = (name, str(id_num))

    with tf.name_scope(self.name):
        self.optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        logging.info('\tCalculating specs and building networks...')
        # Arguments common to both network flavours; the adversary variant
        # adds its extra inputs below.
        network_kwargs = dict(
            use_rnns=use_rnn,
            actor_fc_layers=actor_fc_layers,
            value_fc_layers=value_fc_layers,
            lstm_size=lstm_size,
            conv_filters=conv_filters,
            conv_kernel=conv_kernel,
            scalar_fc=scalar_fc,
        )
        if is_environment:
            self.time_step_spec = env.adversary_time_step_spec
            self.action_spec = env.adversary_action_spec
            self.observation_spec = env.adversary_observation_spec
            network_kwargs.update(
                scalar_name='time_step',
                scalar_dim=self.observation_spec['time_step'].maximum + 1,
                random_z=True,
                xy_dim=xy_dim,
            )
        else:
            self.time_step_spec = env.time_step_spec()
            self.action_spec = env.action_spec()
            self.observation_spec = env.observation_spec()

        networks = multigrid_networks.construct_multigrid_networks(
            self.observation_spec, self.action_spec, **network_kwargs)
        (self.actor_net, self.value_net) = networks

        self.tf_agent = ppo_clip_agent.PPOClipAgent(
            self.time_step_spec,
            self.action_spec,
            self.optimizer,
            actor_net=self.actor_net,
            value_net=self.value_net,
            entropy_regularization=entropy_regularization,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        agent = self.tf_agent
        agent.initialize()
        self.eval_policy = agent.policy
        self.collect_policy = agent.collect_policy

        logging.info('\tAllocating replay buffer ...')
        self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            agent.collect_data_spec,
            batch_size=num_parallel_envs,
            max_length=replay_buffer_capacity)
        logging.info('\t\tRB capacity: %i', self.replay_buffer.capacity)
        self.final_reward = tf.zeros(shape=num_parallel_envs,
                                     dtype=tf.float32)
        self.enemy_max = tf.zeros(shape=num_parallel_envs, dtype=tf.float32)

        # Train metrics extend the caller's step metrics; eval metrics start
        # from an empty list.
        episode_length_name = name + '_AverageEpisodeLength'
        self.step_metrics = step_metrics
        self.train_metrics = step_metrics + [
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_envs, name=episode_length_name)
        ]
        self.eval_metrics = [
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_eval_episodes, name=episode_length_name)
        ]

        if is_environment:
            # Adversary rewards are tracked in dedicated scalar metrics that
            # are kept outside the train/eval metric lists.
            adversary_reward_name = name + '_AdversaryReward'
            self.env_train_metric = (
                adversarial_eval.AdversarialEnvironmentScalar(
                    batch_size=num_parallel_envs,
                    name=adversary_reward_name))
            self.env_eval_metric = (
                adversarial_eval.AdversarialEnvironmentScalar(
                    batch_size=num_eval_episodes,
                    name=adversary_reward_name))
        else:
            average_return_name = name + '_AverageReturn'
            for metric_list, metric_batch_size in (
                    (self.train_metrics, num_parallel_envs),
                    (self.eval_metrics, num_eval_episodes)):
                metric_list.append(
                    tf_metrics.AverageReturnMetric(
                        batch_size=metric_batch_size,
                        name=average_return_name))

        self.metrics_group = metric_utils.MetricsGroup(
            self.train_metrics, name + '_train_metrics')
        self.observers = self.train_metrics + [self.replay_buffer.add_batch]

        self.train_dir = os.path.join(root_dir, 'train', *agent_subpath)
        self.eval_dir = os.path.join(root_dir, 'eval', *agent_subpath)
        self.train_checkpointer = common.Checkpointer(
            ckpt_dir=self.train_dir,
            agent=agent,
            global_step=global_step,
            metrics=self.metrics_group,
        )
        self.policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(self.train_dir, 'policy'),
            policy=self.eval_policy,
            global_step=global_step)
        self.saved_model = policy_saver.PolicySaver(self.eval_policy,
                                                    train_step=global_step)
        self.saved_model_dir = os.path.join(root_dir, 'policy_saved_model',
                                            *agent_subpath)
        self.train_checkpointer.initialize_or_restore()

        if use_tf_functions:
            agent.train = common.function(agent.train, autograph=False)

        self.total_loss = None
        self.extra_loss = None
        self.loss_divergence_counter = 0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC, run through a TF1 graph and session.

    All ops (collection, training, summaries) are created first; they are
    then executed inside a single `tf.compat.v1.Session`. The function
    returns nothing.

    Args:
      root_dir: Output directory (`~` is expanded). Train summaries and all
        checkpoints go to `<root_dir>/train`, eval summaries to
        `<root_dir>/eval`.
      env_name: Name given to `suite_mujoco.load` for both the training and
        the evaluation environment.
      num_iterations: Number of collect-then-train iterations.
      actor_fc_layers: `fc_layer_params` of the actor network.
      critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
      critic_action_fc_layers: `action_fc_layer_params` of the critic.
      critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
      initial_collect_steps: Steps collected with a random policy; only run
        when the (possibly restored) global step is 0.
      collect_steps_per_iteration: Steps collected with the agent's collect
        policy in every iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Forwarded to `SacAgent`.
      target_update_period: Forwarded to `SacAgent`.
      train_steps_per_iteration: Number of train op executions per iteration.
      batch_size: Batch size of the training dataset.
      actor_learning_rate: Learning rate of the actor's Adam optimizer.
      critic_learning_rate: Learning rate of the critic's Adam optimizer.
      alpha_learning_rate: Learning rate of the alpha Adam optimizer.
      td_errors_loss_fn: Forwarded to `SacAgent`.
      gamma: Forwarded to `SacAgent`.
      reward_scale_factor: Forwarded to `SacAgent`.
      gradient_clipping: Forwarded to `SacAgent`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        eval metrics.
      eval_interval: Evaluate whenever the global step is a multiple of this.
      train_checkpoint_interval: Save the train checkpoint whenever the
        global step is a multiple of this.
      policy_checkpoint_interval: Same, for the policy checkpoint.
      rb_checkpoint_interval: Same, for the replay buffer checkpoint.
      log_interval: Log loss and speed whenever the global step is a multiple
        of this.
      summary_interval: Summaries are recorded only when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      debug_summaries: Forwarded to `SacAgent`.
      summarize_grads_and_vars: Forwarded to `SacAgent`.
      eval_metrics_callback: Passed as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The train writer becomes the default writer; the eval writer is only
    # used inside explicit `as_default()` scopes.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    # Graph op; it is run explicitly after every evaluation.
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_py_env = suite_mujoco.load(env_name)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # `normal_projection_net` is defined outside this function.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        # NOTE(review): this function never calls `tf_agent.initialize()`;
        # variables are only initialised through
        # `common.initialize_uninitialized_variables` below - confirm that
        # this is intended.
        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Evaluation runs in Python with a greedy version of the policy.
        eval_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        # The first two metrics double as step metrics for the summaries
        # (see `train_metrics[:2]` below).
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # `.run()` only builds the graph ops here; they execute in the
        # session further down.
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Keep a two-step sample only when its first step is not flagged
            # as a boundary.
            return ~trajectories.is_boundary()[0]

        # Over-sample (5 * batch_size), split into single samples, drop the
        # invalid ones and re-batch to `batch_size`.
        dataset = replay_buffer.as_dataset(
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(
                    batch_size * 5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval summaries are always recorded, independent of
        # `summary_interval`.
        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)

            # A global step of 0 is treated as a fresh run; otherwise the
            # initial evaluation and the random collection are skipped.
            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            # Pre-built callables for the ops executed in every iteration.
            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            # The speed is computed in Python and fed in via a placeholder.
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                # NOTE(review): `total_loss` is only bound inside this loop;
                # with `train_steps_per_iteration == 0` the log line below
                # would fail on an unbound name.
                for _ in range(train_steps_per_iteration):
                    total_loss, _ = train_step_call()
                # Only collection and training time count towards the speed.
                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_flush_op)

                # Each checkpointer has its own saving period.
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC, executed eagerly.

    Args:
      root_dir: Output directory (`~` is expanded). Train summaries and all
        checkpoints go to `<root_dir>/train`, eval summaries to
        `<root_dir>/eval`.
      env_name: Name given to `suite_mujoco.load` for both the training and
        the evaluation environment.
      num_iterations: Number of collect-then-train iterations.
      actor_fc_layers: `fc_layer_params` of the actor network.
      critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
      critic_action_fc_layers: `action_fc_layer_params` of the critic.
      critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
      initial_collect_steps: Steps collected with a random policy before
        training starts.
      collect_steps_per_iteration: Steps collected with the agent's collect
        policy in every iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Forwarded to `SacAgent`.
      target_update_period: Forwarded to `SacAgent`.
      train_steps_per_iteration: Number of `tf_agent.train` calls per
        iteration.
      batch_size: Number of two-step trajectories sampled per train call.
      actor_learning_rate: Learning rate of the actor's Adam optimizer.
      critic_learning_rate: Learning rate of the critic's Adam optimizer.
      alpha_learning_rate: Learning rate of the alpha Adam optimizer.
      td_errors_loss_fn: Forwarded to `SacAgent`.
      gamma: Forwarded to `SacAgent`.
      reward_scale_factor: Forwarded to `SacAgent`.
      gradient_clipping: Forwarded to `SacAgent`.
      use_tf_functions: If True, the driver `run` methods and
        `tf_agent.train` are wrapped with `common.function`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        eval metrics.
      eval_interval: Evaluate whenever the global step is a multiple of this.
      train_checkpoint_interval: Save the train checkpoint whenever the
        global step is a multiple of this.
      policy_checkpoint_interval: Same, for the policy checkpoint.
      rb_checkpoint_interval: Same, for the replay buffer checkpoint.
      log_interval: Log loss and speed whenever the global step is a multiple
        of this.
      summary_interval: Summaries are recorded only when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      debug_summaries: Forwarded to `SacAgent`.
      summarize_grads_and_vars: Forwarded to `SacAgent`.
      eval_metrics_callback: Optional callable invoked after every evaluation
        as `eval_metrics_callback(results, global_step_value)`.

    Returns:
      The value returned by the last `tf_agent.train` call, or None when no
      train step was executed (`num_iterations` or
      `train_steps_per_iteration` equal to 0).
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The train writer becomes the default writer; the eval writer is handed
    # explicitly to `eager_compute`.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_mujoco.load(env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # `normal_projection_net` is defined outside this function.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # The first two metrics double as step metrics for the summaries
        # (see `train_metrics[:2]` below).
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # The random initial collection only fills the replay buffer; the
        # train metrics observe the regular collect driver alone.
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        def _evaluate():
            # Runs one evaluation round, reports it to the optional callback
            # and logs the eval metrics.
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)

        # Evaluate once before any training happens.
        _evaluate()

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = iter(dataset)

        # Bound up front so that logging and the final `return` stay valid
        # even when no train step is ever executed.
        train_loss = None

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, unused_info = next(iterator)
                train_loss = tf_agent.train(experience)
            # Only collection and training time count towards the speed.
            time_acc += time.time() - start_time

            # The global step only changes inside `tf_agent.train`, so one
            # read per iteration is enough.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                if train_loss is not None:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                _evaluate()

            # Each checkpointer has its own saving period.
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        eval_env_name=None,
        env_load_fn=suite_mujoco.load,
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG, run through a TF1 graph and session.

    All ops (collection, training, summaries) are created first; they are
    then executed inside a single `tf.compat.v1.Session`. The function
    returns nothing.

    Args:
      root_dir: Output directory (`~` is expanded). Train summaries and all
        checkpoints go to `<root_dir>/train`, eval summaries to
        `<root_dir>/eval`.
      env_name: Name of the training environment, passed to `env_load_fn`.
      eval_env_name: Name of the evaluation environment; falls back to
        `env_name` when falsy.
      env_load_fn: Callable that takes an environment name and returns the
        environment.
      num_iterations: Number of collect-then-train iterations.
      actor_fc_layers: `fc_layer_params` of the actor network.
      critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
      critic_action_fc_layers: `action_fc_layer_params` of the critic.
      critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
      initial_collect_steps: Steps collected before the training loop starts.
      collect_steps_per_iteration: Steps collected in every iteration.
      num_parallel_environments: When greater than 1, the training
        environment is wrapped in a `ParallelPyEnvironment` with this many
        copies.
      replay_buffer_capacity: `max_length` of the replay buffer.
      ou_stddev: Forwarded to `DdpgAgent`.
      ou_damping: Forwarded to `DdpgAgent`.
      target_update_tau: Forwarded to `DdpgAgent`.
      target_update_period: Forwarded to `DdpgAgent`.
      train_steps_per_iteration: Number of train op executions per iteration.
      batch_size: Number of two-step trajectories sampled per train step.
      actor_learning_rate: Learning rate of the actor's Adam optimizer.
      critic_learning_rate: Learning rate of the critic's Adam optimizer.
      dqda_clipping: Forwarded to `DdpgAgent`.
      td_errors_loss_fn: Forwarded to `DdpgAgent`.
      gamma: Forwarded to `DdpgAgent`.
      reward_scale_factor: Forwarded to `DdpgAgent`.
      gradient_clipping: Forwarded to `DdpgAgent`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        eval metrics.
      eval_interval: Evaluate whenever the global step is a multiple of this.
      train_checkpoint_interval: Save the train checkpoint whenever the
        global step is a multiple of this.
      policy_checkpoint_interval: Same, for the policy checkpoint.
      rb_checkpoint_interval: Same, for the replay buffer checkpoint.
      log_interval: Log loss and speed whenever the global step is a multiple
        of this.
      summary_interval: Summaries are recorded only when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      debug_summaries: Forwarded to `DdpgAgent`.
      summarize_grads_and_vars: Forwarded to `DdpgAgent`.
      eval_metrics_callback: Passed as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The train writer becomes the default writer; the eval writer is only
    # used inside explicit `as_default()` scopes.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if num_parallel_environments > 1:
            # The same constructor lambda is repeated once per parallel copy.
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: env_load_fn(env_name)] *
                    num_parallel_environments))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_py_env = env_load_fn(eval_env_name)

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # The buffer batch size follows the (possibly parallel) environment.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        # The first two metrics double as step metrics for the summaries
        # (see `train_metrics[:2]` below).
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Both the initial and the per-iteration collection use the agent's
        # collect policy. `.run()` only builds the graph ops here.
        collect_policy = tf_agent.collect_policy
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()

        train_fn = common.function(tf_agent.train)
        train_op = train_fn(experience=trajectories)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval summaries are always recorded, independent of
        # `summary_interval`.
        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(b/126239733) Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            # NOTE(review): the initial collection and the initial evaluation
            # below are not guarded by a global-step check, so they also run
            # after the checkpointers restored an earlier run - confirm that
            # this is intended.
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            # Pre-built callables for the ops executed in every iteration.
            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = sess.run(global_step)
            time_acc = 0
            # The speed is computed in Python and fed in via a placeholder.
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                # NOTE(review): `loss_info_value` is only bound inside this
                # loop; with `train_steps_per_iteration == 0` the log line
                # below would fail on an unbound name.
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                # Only collection and training time count towards the speed.
                time_acc += time.time() - start_time

                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                # Each checkpointer has its own saving period.
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
def play(root_dir,
         env,
         algorithm,
         checkpoint_name=None,
         greedy_predict=True,
         random_seed=None,
         num_episodes=10,
         sleep_time_per_step=0.01,
         record_file=None,
         use_tf_functions=True):
    r"""Play using the latest checkpoint under `train_dir`.

    The checkpoint is looked up in `<root_dir>/train/algorithm`. If no
    checkpoint can be located, a message is logged and the episodes are played
    with `algorithm` as it was passed in.

    The following example records the play of a trained model to a mp4 video:
    ```bash
    python -m alf.bin.play \
    --root_dir=~/tmp/bullet_humanoid/ppo2/ppo2-11 \
    --num_episodes=1 \
    --record_file=ppo_bullet_humanoid.mp4
    ```

    Args:
        root_dir (str): same as the root_dir used for `train()`
        env (TFEnvironment): the environment
        algorithm (OnPolicyAlgorithm): the training algorithm
        checkpoint_name (str): name of the checkpoint (e.g. 'ckpt-12800').
            If None, the latest checkpoint under train_dir will be used.
        greedy_predict (bool): use greedy action for evaluation.
        random_seed (None|int): random seed, a random seed is used if None
        num_episodes (int): number of episodes to play
        sleep_time_per_step (float): sleep so many seconds for each step
            that is rendered on screen
        record_file (str): if provided, video will be recorded to a file
            instead of shown on the screen.
        use_tf_functions (bool): whether to use tf.function
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    # Seed every random source used here so that a run can be reproduced.
    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)
        tf.random.set_seed(random_seed)

    global_step = get_global_counter()

    driver = OnPolicyDriver(env=env,
                            algorithm=algorithm,
                            training=False,
                            greedy_predict=greedy_predict)

    ckpt_dir = os.path.join(train_dir, 'algorithm')
    checkpoint = tf.train.Checkpoint(algorithm=algorithm,
                                     metrics=metric_utils.MetricsGroup(
                                         driver.get_metrics(), 'metrics'),
                                     global_step=global_step)

    if checkpoint_name is not None:
        ckpt_path = os.path.join(ckpt_dir, checkpoint_name)
    else:
        ckpt_path = tf.train.latest_checkpoint(ckpt_dir)

    if ckpt_path is not None:
        logging.info("Restore from checkpoint %s", ckpt_path)
        checkpoint.restore(ckpt_path)
    else:
        logging.info("Checkpoint is not found at %s", ckpt_dir)

    if not use_tf_functions:
        tf.config.experimental_run_functions_eagerly(True)

    recorder = None
    if record_file is not None:
        recorder = VideoRecorder(env.pyenv.envs[0], path=record_file)
    else:
        # pybullet_envs need to render() before reset() to enable mode='human'
        env.pyenv.envs[0].render(mode='human')

    # From here on the recorder (if any) is open; the `finally` clause makes
    # sure it is closed even when stepping or rendering raises.
    try:
        env.reset()
        if recorder:
            recorder.capture_frame()
        time_step = driver.get_initial_time_step()
        policy_state = driver.get_initial_policy_state()
        episode_reward = 0.
        episode_length = 0
        episodes = 0
        while episodes < num_episodes:
            # Advance one step at a time so that every step can be
            # captured or rendered.
            time_step, policy_state = driver.run(max_num_steps=1,
                                                 time_step=time_step,
                                                 policy_state=policy_state)
            if recorder:
                recorder.capture_frame()
            else:
                env.pyenv.envs[0].render(mode='human')
                time.sleep(sleep_time_per_step)

            episode_reward += float(time_step.reward)

            if time_step.is_last():
                logging.info("episode_length=%s episode_reward=%s",
                             episode_length, episode_reward)
                episode_reward = 0.
                # Keep the counter an int so that every episode is logged in
                # the same format.
                episode_length = 0
                episodes += 1
            else:
                episode_length += 1
    finally:
        if recorder:
            recorder.close()
    env.reset()
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=50,
        rb_checkpoint_interval=200,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO.

    Builds the graph (environments, networks, PPO agent, replay buffer,
    collect and train ops, checkpointers, summaries) and then runs a
    collect/train loop in a `tf.compat.v1.Session` until the environment step
    metric reaches `num_environment_steps`. One more evaluation is run before
    returning.

    Args:
        root_dir: Base directory; `~` is expanded. Summaries are written to
            `<root_dir>/train` and `<root_dir>/eval`, checkpoints are kept
            under `<root_dir>/train`.
        tf_master: Target passed to `tf.compat.v1.Session`.
        env_name: Name handed to `env_load_fn`.
        env_load_fn: Callable invoked as `env_load_fn(env_name)` to create
            each environment of the parallel environments.
        random_seed: Seed passed to `tf.compat.v1.set_random_seed`.
        actor_fc_layers: Fully connected layer sizes of the actor network.
        value_fc_layers: Fully connected layer sizes of the value network.
        use_rnns: If True, the RNN variants of the actor and value networks
            are used and the layer sizes above are used as their input layers.
        num_environment_steps: Training continues while the environment step
            metric is below this value.
        collect_episodes_per_iteration: Number of episodes the collect driver
            runs per iteration.
        num_parallel_environments: Number of environments in each parallel
            environment; also the batch size of the replay buffer and metrics.
        replay_buffer_capacity: `max_length` of the replay buffer.
        num_epochs: Passed to `PPOAgent`.
        learning_rate: Learning rate of the Adam optimizer.
        num_eval_episodes: Number of episodes per evaluation and buffer size
            of the evaluation metrics.
        eval_interval: An evaluation runs at the start of an iteration when
            the global step is a multiple of this value.
        train_checkpoint_interval: The train checkpoint is saved when the
            global step is a multiple of this value.
        policy_checkpoint_interval: Same, for the policy checkpoint.
        rb_checkpoint_interval: Same, for the replay buffer checkpoint.
        log_interval: Loss and timing are logged when the global step is a
            multiple of this value.
        summary_interval: Summaries are recorded when the global step is a
            multiple of this value.
        summaries_flush_secs: Flush period of the summary writers, in seconds.
        debug_summaries: Passed to `PPOAgent`.
        summarize_grads_and_vars: Passed to `PPOAgent`.
        eval_metrics_callback: Passed as `callback` to
            `metric_utils.compute_summaries`.

    Raises:
        AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)
        # Separate environments are created for evaluation and collection.
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        # Chain the ops so that running `train_op` trains first and then
        # clears the replay buffer.
        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics))

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            # Read the step once up front: it is the baseline for the
            # steps/sec computation and also guarantees `global_step_val` is
            # defined for the final evaluation when the loop below does not
            # run at all (e.g. a restored run that already reached
            # `num_environment_steps`).
            global_step_val = sess.run(global_step)
            timed_at_step = global_step_val
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                # Collection and training are timed separately.
                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss, _ = sess.run([train_op, summary_ops])
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    # Restart the timing window.
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
def train_eval( ############################################## # types of params: # 0: specific to algorithm (gin file 0) # 1: specific to environment (gin file 1) # 2: specific to experiment (gin file 2 + command line) # Note: there are other important params # in eg ModelDistributionNetwork that the gin files specify # like sparse vs dense rewards, latent dimensions, etc. ############################################## # basic params for running/logging experiment root_dir, # 2 experiment_name, # 2 num_iterations=int(1e7), # 2 seed=1, # 2 gpu_allow_growth=False, # 2 gpu_memory_limit=None, # 2 verbose=True, # 2 policy_checkpoint_freq_in_iter=100, # policies needed for future eval # 2 train_checkpoint_freq_in_iter=0, #default don't save # 2 rb_checkpoint_freq_in_iter=0, #default don't save # 2 logging_freq_in_iter=10, # printing to terminal # 2 summary_freq_in_iter=10, # saving to tb # 2 num_images_per_summary=2, # 2 summaries_flush_secs=10, # 2 max_episode_len_override=None, # 2 num_trials_to_render=1, # 2 # environment, action mode, etc. 
env_name='HalfCheetah-v2', # 1 action_repeat=1, # 1 action_mode='joint_position', # joint_position or joint_delta_position # 1 double_camera=False, # camera input # 1 universe='gym', # default task_reward_dim=1, # default # dims for all networks actor_fc_layers=(256, 256), # 1 critic_obs_fc_layers=None, # 1 critic_action_fc_layers=None, # 1 critic_joint_fc_layers=(256, 256), # 1 num_repeat_when_concatenate=None, # 1 # networks critic_input='state', # 0 actor_input='state', # 0 # specifying tasks and eval episodes_per_trial=1, # 2 num_train_tasks=10, # 2 num_eval_tasks=10, # 2 num_eval_trials=10, # 2 eval_interval=10, # 2 eval_on_holdout_tasks=True, # 2 # data collection/buffer init_collect_trials_per_task=None, # 2 collect_trials_per_task=None, # 2 num_tasks_to_collect_per_iter=5, # 2 replay_buffer_capacity=int(1e5), # 2 # training init_model_train_ratio=0.8, # 2 model_train_ratio=1, # 2 model_train_freq=1, # 2 ac_train_ratio=1, # 2 ac_train_freq=1, # 2 num_tasks_per_train=5, # 2 train_trials_per_task=5, # 2 model_bs_in_steps=256, # 2 ac_bs_in_steps=128, # 2 # default AC learning rates, gamma, etc. 
target_update_tau=0.005, target_update_period=1, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, model_learning_rate=1e-4, td_errors_loss_fn=functools.partial( tf.compat.v1.losses.mean_squared_error, weights=0.5), gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, log_image_strips=False, stop_model_training=1E10, eval_only=False, # evaluate checkpoints ONLY log_image_observations=False, load_offline_data=False, # whether to use offline data offline_data_dir=None, # replay buffer's dir offline_episode_len=None, # episode len of episodes stored in rb offline_ratio=0, # ratio of data that is from offline buffer ): g = tf.Graph() # register all gym envs max_steps_dict = { "HalfCheetahVel-v0": 50, "SawyerReach-v0": 40, "SawyerReachMT-v0": 40, "SawyerPeg-v0": 40, "SawyerPegMT-v0": 40, "SawyerPegMT4box-v0": 40, "SawyerShelfMT-v0": 40, "SawyerKitchenMT-v0": 40, "SawyerShelfMT-v2": 40, "SawyerButtons-v0": 40, } if max_episode_len_override: max_steps_dict[env_name] = max_episode_len_override register_all_gym_envs(max_steps_dict) # set max_episode_len based on our env max_episode_len = max_steps_dict[env_name] ###################################################### # Calculate additional params ###################################################### # convert to number of steps env_steps_per_trial = episodes_per_trial * max_episode_len real_env_steps_per_trial = episodes_per_trial * (max_episode_len + 1) env_steps_per_iter = num_tasks_to_collect_per_iter * collect_trials_per_task * env_steps_per_trial per_task_collect_steps = collect_trials_per_task * env_steps_per_trial # initial collect + train init_collect_env_steps = num_train_tasks * init_collect_trials_per_task * env_steps_per_trial init_model_train_steps = int(init_collect_env_steps * init_model_train_ratio) # collect + train collect_env_steps_per_iter = num_tasks_to_collect_per_iter * per_task_collect_steps model_train_steps_per_iter = int(env_steps_per_iter * model_train_ratio) 
ac_train_steps_per_iter = int(env_steps_per_iter * ac_train_ratio) # other global_steps_per_iter = collect_env_steps_per_iter + model_train_steps_per_iter + ac_train_steps_per_iter sample_episodes_per_task = train_trials_per_task * episodes_per_trial # number of episodes to sample from each replay model_bs_in_trials = model_bs_in_steps // real_env_steps_per_trial # assertions that make sure parameters make sense assert model_bs_in_trials > 0, "model batch size need to be at least as big as one full real trial" assert num_tasks_to_collect_per_iter <= num_train_tasks, "when sampling replace=False" assert num_tasks_per_train * train_trials_per_task >= model_bs_in_trials, "not enough data for one batch model train" assert num_tasks_per_train * train_trials_per_task * env_steps_per_trial >= ac_bs_in_steps, "not enough data for one batch ac train" ###################################################### # Print a summary of params ###################################################### MELD_summary_string = f"""\n\n\n ============================================================== ============================================================== \n MELD algorithm summary: * each trial consists of {episodes_per_trial} episodes * episode length: {max_episode_len}, trial length: {env_steps_per_trial} * {num_train_tasks} train tasks, {num_eval_tasks} eval tasks, hold-out: {eval_on_holdout_tasks} * environment: {env_name} For each of {num_train_tasks} tasks: Do {init_collect_trials_per_task} trials of initial collect (total {init_collect_env_steps} env steps) Do {init_model_train_steps} steps of initial model training For i in range(inf): For each of {num_tasks_to_collect_per_iter} randomly selected tasks: Do {collect_trials_per_task} trials of collect (which is {collect_trials_per_task*env_steps_per_trial} env steps per task) (for a total of {num_tasks_to_collect_per_iter*collect_trials_per_task*env_steps_per_trial} env steps in the iteration) if i % 
model_train_freq(={model_train_freq}): Do {model_train_steps_per_iter} steps of model training - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials. - pick randomly {model_bs_in_trials} trials, train model on whole trials. if i % ac_train_freq(={ac_train_freq}): Do {ac_train_steps_per_iter} steps of ac training - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials. - pick randomly {ac_bs_in_steps} transitions, not including between trial transitions, to train ac. * Other important params: Evaluate policy every {eval_interval} iters, equivalent to {global_steps_per_iter*eval_interval/1000:.1f}k global steps Average evaluation across {num_eval_trials} trials Save summary to tensorboard every {summary_freq_in_iter} iters, equivalent to {global_steps_per_iter*summary_freq_in_iter/1000:.1f}k global steps Checkpoint: - training checkpoint every {train_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*train_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint - policy checkpoint every {policy_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*policy_checkpoint_freq_in_iter//1000}k global steps, keep all checkpoints - replay buffer checkpoint every {rb_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*rb_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint \n ============================================================= ============================================================= """ print(MELD_summary_string) time.sleep(1) ###################################################### # Seed + name + GPU configs + directories for saving ###################################################### np.random.seed(int(seed)) experiment_name += "_seed" + str(seed) gpus = 
tf.config.experimental.list_physical_devices('GPU') if gpu_allow_growth: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpu_memory_limit: for gpu in gpus: tf.config.experimental.set_virtual_device_configuration( gpu, [ tf.config.experimental.VirtualDeviceConfiguration( memory_limit=gpu_memory_limit) ]) train_eval_dir = get_train_eval_dir(root_dir, universe, env_name, experiment_name) train_dir = os.path.join(train_eval_dir, 'train') eval_dir = os.path.join(train_eval_dir, 'eval') eval_dir_2 = os.path.join(train_eval_dir, 'eval2') ###################################################### # Train and Eval Summary Writers ###################################################### train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_summary_flush_op = eval_summary_writer.flush() eval_logger = Logger(eval_dir_2) ###################################################### # Train and Eval metrics ###################################################### eval_buffer_size = num_eval_trials * episodes_per_trial * max_episode_len # across all eval trials in each evaluation eval_metrics = [] for position in range( episodes_per_trial ): # have metrics for each episode position, to track whether it is learning eval_metrics_pos = [ py_metrics.AverageReturnMetric(name='c_AverageReturnEval_' + str(position), buffer_size=eval_buffer_size), py_metrics.AverageEpisodeLengthMetric( name='f_AverageEpisodeLengthEval_' + str(position), buffer_size=eval_buffer_size), custom_metrics.AverageScoreMetric( name="d_AverageScoreMetricEval_" + str(position), buffer_size=eval_buffer_size), ] eval_metrics.extend(eval_metrics_pos) train_buffer_size = num_train_tasks * episodes_per_trial train_metrics = [ tf_metrics.NumberOfEpisodes(name='NumberOfEpisodes'), 
tf_metrics.EnvironmentSteps(name='EnvironmentSteps'), tf_py_metric.TFPyMetric( py_metrics.AverageReturnMetric(name="a_AverageReturnTrain", buffer_size=train_buffer_size)), tf_py_metric.TFPyMetric( py_metrics.AverageEpisodeLengthMetric( name="e_AverageEpisodeLengthTrain", buffer_size=train_buffer_size)), tf_py_metric.TFPyMetric( custom_metrics.AverageScoreMetric(name="b_AverageScoreTrain", buffer_size=train_buffer_size)), ] global_step = tf.compat.v1.train.get_or_create_global_step( ) # will be use to record number of model grad steps + ac grad steps + env_step log_cond = get_log_condition_tensor( global_step, init_collect_trials_per_task, env_steps_per_trial, num_train_tasks, init_model_train_steps, collect_trials_per_task, num_tasks_to_collect_per_iter, model_train_steps_per_iter, ac_train_steps_per_iter, summary_freq_in_iter, eval_interval) with tf.compat.v2.summary.record_if(log_cond): ###################################################### # Create env ###################################################### py_env, eval_py_env, train_tasks, eval_tasks = load_environments( universe, action_mode, env_name=env_name, observations_whitelist=['state', 'pixels', "env_info"], action_repeat=action_repeat, num_train_tasks=num_train_tasks, num_eval_tasks=num_eval_tasks, eval_on_holdout_tasks=eval_on_holdout_tasks, return_multiple_tasks=True, ) override_reward_func = None if load_offline_data: py_env.set_task_dict(train_tasks) override_reward_func = py_env.override_reward_func tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True) # Get data specs from env time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() original_control_timestep = get_control_timestep(eval_py_env) # fps control_timestep = original_control_timestep * float(action_repeat) render_fps = int(np.round(1.0 / original_control_timestep)) ###################################################### # Latent variable model 
###################################################### if verbose: print("-- start constructing model networks --") model_net = ModelDistributionNetwork( double_camera=double_camera, observation_spec=observation_spec, num_repeat_when_concatenate=num_repeat_when_concatenate, task_reward_dim=task_reward_dim, episodes_per_trial=episodes_per_trial, max_episode_len=max_episode_len ) # rest of arguments provided via gin if verbose: print("-- finish constructing AC networks --") ###################################################### # Compressor Network for Actor/Critic # The model's compressor is also used by the AC # compressor function: images --> features ###################################################### compressor_net = model_net.compressor ###################################################### # Specs for Actor and Critic ###################################################### if actor_input == 'state': actor_state_size = observation_spec['state'].shape[0] elif actor_input == 'latentSample': actor_state_size = model_net.state_size elif actor_input == "latentDistribution": actor_state_size = 2 * model_net.state_size # mean and (diagonal) variance of gaussian, of two latents else: raise NotImplementedError actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ), dtype=tf.float32) if critic_input == 'state': critic_state_size = observation_spec['state'].shape[0] elif critic_input == 'latentSample': critic_state_size = model_net.state_size elif critic_input == "latentDistribution": critic_state_size = 2 * model_net.state_size # mean and (diagonal) variance of gaussian, of two latents else: raise NotImplementedError critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ), dtype=tf.float32) ###################################################### # Actor and Critic Networks ###################################################### if verbose: print("-- start constructing Actor and Critic networks --") actor_net = 
actor_distribution_network.ActorDistributionNetwork( actor_input_spec, action_spec, fc_layer_params=actor_fc_layers, ) critic_net = critic_network.CriticNetwork( (critic_input_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) if verbose: print("-- finish constructing AC networks --") print("-- start constructing agent --") ###################################################### # Create the agent ###################################################### which_posterior_overwrite = None which_reward_overwrite = None meld_agent = MeldAgent( # specs time_step_spec=time_step_spec, action_spec=action_spec, # step counter train_step_counter= global_step, # will count number of model training steps # networks actor_network=actor_net, critic_network=critic_net, model_network=model_net, compressor_network=compressor_net, # optimizers actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), model_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=model_learning_rate), # target update target_update_tau=target_update_tau, target_update_period=target_update_period, # inputs critic_input=critic_input, actor_input=actor_input, # bs stuff model_batch_size=model_bs_in_steps, ac_batch_size=ac_bs_in_steps, # other num_tasks_per_train=num_tasks_per_train, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, control_timestep=control_timestep, num_images_per_summary=num_images_per_summary, task_reward_dim=task_reward_dim, episodes_per_trial=episodes_per_trial, # offline data override_reward_func=override_reward_func, offline_ratio=offline_ratio, ) if verbose: print("-- finish constructing agent --") 
###################################################### # Replay buffers + observers to add data to them ###################################################### replay_buffers = [] replay_observers = [] for _ in range(num_train_tasks): replay_buffer_episodic = episodic_replay_buffer.EpisodicReplayBuffer( meld_agent.collect_policy. trajectory_spec, # spec of each point stored in here (i.e. Trajectory) capacity=replay_buffer_capacity, completed_only= True, # in as_dataset, if num_steps is None, this means return full episodes # device='GPU:0', # gpu not supported for some reason begin_episode_fn=lambda traj: traj.is_first()[ 0], # first step of seq we add should be is_first end_episode_fn=lambda traj: traj.is_last()[ 0], # last step of seq we add should be is_last dataset_drop_remainder= True, #`as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items ) replay_buffer = StatefulEpisodicReplayBuffer( replay_buffer_episodic) # adding num_episodes here is bad replay_buffers.append(replay_buffer) replay_observers.append([replay_buffer.add_sequence]) if load_offline_data: # for each task, has a separate replay buffer for relabeled data replay_buffers_withRelabel = [] replay_observers_withRelabel = [] for _ in range(num_train_tasks): replay_buffer_episodic_withRelabel = episodic_replay_buffer.EpisodicReplayBuffer( meld_agent.collect_policy. trajectory_spec, # spec of each point stored in here (i.e. 
Trajectory) capacity=replay_buffer_capacity, completed_only= True, # in as_dataset, if num_steps is None, this means return full episodes # device='GPU:0', # gpu not supported for some reason begin_episode_fn=lambda traj: traj.is_first()[ 0], # first step of seq we add should be is_first end_episode_fn=lambda traj: traj.is_last()[ 0], # last step of seq we add should be is_last dataset_drop_remainder=True, # `as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items ) replay_buffer_withRelabel = StatefulEpisodicReplayBuffer( replay_buffer_episodic_withRelabel ) # adding num_episodes here is bad replay_buffers_withRelabel.append(replay_buffer_withRelabel) replay_observers_withRelabel.append( [replay_buffer_withRelabel.add_sequence]) if verbose: print("-- finish constructing replay buffers --") print("-- start constructing policies and collect ops --") ###################################################### # Policies ##################################################### # init collect policy (random) init_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) # eval eval_py_policy = py_tf_policy.PyTFPolicy(meld_agent.policy) ################################################################################ # Collect ops : use policies to get data + have the observer put data into corresponding RB ################################################################################ #init collection (with random policy) init_collect_ops = [] for task_idx in range(num_train_tasks): # put init data into the rb + track with the train metric observers = replay_observers[task_idx] + train_metrics # initial collect op init_collect_op = DynamicTrialDriver( tf_env, init_collect_policy, num_trials_to_collect=init_collect_trials_per_task, observers=observers, episodes_per_trial= episodes_per_trial, # policy state will not be reset within these episodes max_episode_len=max_episode_len, ).run() # collect one trial 
init_collect_ops.append(init_collect_op) # data collection for training (with collect policy) collect_ops = [] for task_idx in range(num_train_tasks): collect_op = DynamicTrialDriver( tf_env, meld_agent.collect_policy, num_trials_to_collect=collect_trials_per_task, observers=replay_observers[task_idx] + train_metrics, # put data into 1st RB + track with 1st pol metrics episodes_per_trial= episodes_per_trial, # policy state will not be reset within these episodes max_episode_len=max_episode_len, ).run() # collect one trial collect_ops.append(collect_op) if verbose: print("-- finish constructing policies and collect ops --") print("-- start constructing replay buffer->training pipeline --") ###################################################### # replay buffer --> dataset --> iterate to get trajecs for training ###################################################### # get some data from all task replay buffers (even though won't actually train on all of them) dataset_iterators = [] all_tasks_trajectories_fromdense = [] for task_idx in range(num_train_tasks): dataset = replay_buffers[task_idx].as_dataset( sample_batch_size= sample_episodes_per_task, # number of episodes to sample num_steps=max_episode_len + 1 ).prefetch( 3 ) # +1 to include the last state: a trajectory with n transition has n+1 states # iterator to go through the data dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) dataset_iterators.append(dataset_iterator) # get sample_episodes_per_task sequences, each of length num_steps trajectories_task_i, _ = dataset_iterator.get_next() all_tasks_trajectories_fromdense.append(trajectories_task_i) if load_offline_data: # have separate dataset for relabel data dataset_iterators_withRelabel = [] all_tasks_trajectories_fromdense_withRelabel = [] for task_idx in range(num_train_tasks): dataset = replay_buffers_withRelabel[task_idx].as_dataset( sample_batch_size= sample_episodes_per_task, # number of episodes to sample 
num_steps=offline_episode_len + 1 ).prefetch( 3 ) # +1 to include the last state: a trajectory with n transition has n+1 states # iterator to go through the data dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) dataset_iterators_withRelabel.append(dataset_iterator) # get sample_episodes_per_task sequences, each of length num_steps trajectories_task_i, _ = dataset_iterator.get_next() all_tasks_trajectories_fromdense_withRelabel.append( trajectories_task_i) if verbose: print("-- finish constructing replay buffer->training pipeline --") print("-- start constructing model and AC training ops --") ###################################### # Decoding latent samples into rewards ###################################### latent_samples_1_ph = tf.compat.v1.placeholder( dtype=tf.float32, shape=(None, None, meld_agent._model_network.latent1_size)) latent_samples_2_ph = tf.compat.v1.placeholder( dtype=tf.float32, shape=(None, None, meld_agent._model_network.latent2_size)) decode_rews_op = meld_agent._model_network.decode_latents_into_reward( latent_samples_1_ph, latent_samples_2_ph) ###################################### # Model/Actor/Critic train + summary ops ###################################### # train AC on data from replay buffer if load_offline_data: ac_train_op = meld_agent.train_ac_meld( all_tasks_trajectories_fromdense, all_tasks_trajectories_fromdense_withRelabel) else: ac_train_op = meld_agent.train_ac_meld( all_tasks_trajectories_fromdense) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) if verbose: print("-- finish constructing AC training ops --") ############################ # Model train + summary ops ############################ # train model on data from replay buffer if load_offline_data: model_train_op, check_step_types = meld_agent.train_model_meld( all_tasks_trajectories_fromdense, all_tasks_trajectories_fromdense_withRelabel) 
else: model_train_op, check_step_types = meld_agent.train_model_meld( all_tasks_trajectories_fromdense) model_summary_ops, model_summary_ops_2 = [], [] for summary_op in tf.compat.v1.summary.all_v2_summary_ops(): if summary_op not in summary_ops: model_summary_ops.append(summary_op) if verbose: print("-- finish constructing model training ops --") print("-- start constructing checkpointers --") ######################## # Eval metrics ######################## with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) ######################## # Create savers ######################## train_config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=False) eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir, summarize_config=False) ######################## # Create checkpointers ######################## train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=meld_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), max_to_keep=1) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=meld_agent.policy, global_step=global_step, max_to_keep=99999999999 ) # keep many policy checkpoints, in case of future eval rb_checkpointers = [] for buffer_idx in range(len(replay_buffers)): rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffers/', "task" + str(buffer_idx)), max_to_keep=1, replay_buffer=replay_buffers[buffer_idx]) rb_checkpointers.append(rb_checkpointer) if load_offline_data: # for LOADING data not for checkpointing. 
No new data going in anyways rb_checkpointers_withRelabel = [] for buffer_idx in range(len(replay_buffers_withRelabel)): ckpt_dir = os.path.join(offline_data_dir, "task" + str(buffer_idx)) rb_checkpointer = common.Checkpointer( ckpt_dir=ckpt_dir, max_to_keep=99999999999, replay_buffer=replay_buffers_withRelabel[buffer_idx]) rb_checkpointers_withRelabel.append(rb_checkpointer) # Notice: these replay buffers need to follow the same sequence of tasks as the current one if verbose: print("-- finish constructing checkpointers --") print("-- start main training loop --") with tf.compat.v1.Session() as sess: ######################## # Initialize ######################## if eval_only: sess.run(eval_summary_writer.init()) load_eval_log( train_eval_dir=train_eval_dir, meld_agent=meld_agent, global_step=global_step, sess=sess, eval_metrics=eval_metrics, eval_py_env=eval_py_env, eval_py_policy=eval_py_policy, num_eval_trials=num_eval_trials, max_episode_len=max_episode_len, episodes_per_trial=episodes_per_trial, log_image_strips=log_image_strips, num_trials_to_render=num_trials_to_render, train_tasks= train_tasks, # in case want to eval on a train task eval_tasks=eval_tasks, model_net=model_net, render_fps=render_fps, decode_rews_op=decode_rews_op, latent_samples_1_ph=latent_samples_1_ph, latent_samples_2_ph=latent_samples_2_ph, ) return # Initialize checkpointing train_checkpointer.initialize_or_restore(sess) for rb_checkpointer in rb_checkpointers: rb_checkpointer.initialize_or_restore(sess) if load_offline_data: for rb_checkpointer in rb_checkpointers_withRelabel: rb_checkpointer.initialize_or_restore(sess) # Initialize dataset iterators for dataset_iterator in dataset_iterators: sess.run(dataset_iterator.initializer) if load_offline_data: for dataset_iterator in dataset_iterators_withRelabel: sess.run(dataset_iterator.initializer) # Initialize variables common.initialize_uninitialized_variables(sess) # Initialize summary writers sess.run(train_summary_writer.init()) 
sess.run(eval_summary_writer.init()) # Initialize savers train_config_saver.after_create_session(sess) eval_config_saver.after_create_session(sess) # Get value of step counter global_step_val = sess.run(global_step) if verbose: print("====== finished initialization ======") ################################################################ # If this is start of new exp (i.e., 1st step) and not continuing old exp # eval rand policy + do initial data collection ################################################################ fresh_start = (global_step_val == 0) if fresh_start: ######################## # Evaluate initial policy ######################## if eval_interval: logging.info( '\n\nDoing evaluation of initial policy on %d trials with randomly sampled tasks', num_eval_trials) perform_eval_and_summaries_meld( eval_metrics, eval_py_env, eval_py_policy, num_eval_trials, max_episode_len, episodes_per_trial, log_image_strips=log_image_strips, num_trials_to_render=num_eval_tasks, eval_tasks=eval_tasks, latent1_size=model_net.latent1_size, latent2_size=model_net.latent2_size, logger=eval_logger, global_step_val=global_step_val, render_fps=render_fps, decode_rews_op=decode_rews_op, latent_samples_1_ph=latent_samples_1_ph, latent_samples_2_ph=latent_samples_2_ph, log_image_observations=log_image_observations, ) sess.run(eval_summary_flush_op) logging.info( 'Done with evaluation of initial (random) policy.\n\n') ######################## # Initial data collection ######################## logging.info( '\n\nGlobal step %d: Beginning init collect op with random policy. 
Collecting %dx {%d, %d} trials for each task', global_step_val, init_collect_trials_per_task, max_episode_len, episodes_per_trial) init_increment_global_step_op = global_step.assign_add( env_steps_per_trial * init_collect_trials_per_task) for task_idx in range(num_train_tasks): logging.info('on task %d / %d', task_idx + 1, num_train_tasks) py_env.set_task_for_env(train_tasks[task_idx]) sess.run([ init_collect_ops[task_idx], init_increment_global_step_op ]) # incremented gs in granularity of task rb_checkpointer.save(global_step=global_step_val) logging.info('Finished init collect.\n\n') else: logging.info( '\n\nGlobal step %d from loaded experiment: Skipping init collect op.\n\n', global_step_val) ######################### # Create calls ######################### # [1] calls for running the policies to collect training data collect_calls = [] increment_global_step_op = global_step.assign_add( env_steps_per_trial * collect_trials_per_task) for task_idx in range(num_train_tasks): collect_calls.append( sess.make_callable( [collect_ops[task_idx], increment_global_step_op])) # [2] call for doing a training step (A + C) ac_train_step_call = sess.make_callable([ac_train_op, summary_ops]) # [3] call for doing a training step (model) model_train_step_call = sess.make_callable( [model_train_op, check_step_types, model_summary_ops]) # [4] call for evaluating what global_step number we're on global_step_call = sess.make_callable(global_step) # reset keeping track of steps/time timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') with train_summary_writer.as_default( ), tf.compat.v2.summary.record_if(True): steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) ################################# # init model training ################################# if fresh_start: logging.info( '\n\nPerforming %d steps of init model 
training, each step on %d random tasks', init_model_train_steps, num_tasks_per_train) for i in range(init_model_train_steps): temp_start = time.time() if i % 100 == 0: print(".... init model training ", i, "/", init_model_train_steps) # init model training total_loss_value_model, check_step_types, _ = model_train_step_call( ) if PRINT_TIMING: print("single model train step: ", time.time() - temp_start) if verbose: print("\n\n\n-- start training loop --\n") ################################# # Training Loop ################################# start_time = time.time() for iteration in range(num_iterations): if iteration > 0: g.finalize() # print("\n\n\niter", iteration, sess.run(curr_iter)) print("global step", global_step_call()) logging.info("Iteration: %d, Global step: %d\n", iteration, global_step_val) #################### # collect data #################### logging.info( '\nStarting batch data collection. Collecting %d {%d, %d} trials for each of %d tasks', collect_trials_per_task, max_episode_len, episodes_per_trial, num_tasks_to_collect_per_iter) # randomly select tasks to collect this iteration list_of_collect_task_idxs = np.random.choice( len(train_tasks), num_tasks_to_collect_per_iter, replace=False) for count, task_idx in enumerate(list_of_collect_task_idxs): logging.info('on randomly selected task %d / %d', count + 1, num_tasks_to_collect_per_iter) # set task for the env py_env.set_task_for_env(train_tasks[task_idx]) # collect data with collect policy _, policy_state_val = collect_calls[task_idx]() logging.info('Finish data collection. 
Global step: %d\n', global_step_call()) #################### # train model #################### if (iteration == 0) or ((iteration % model_train_freq == 0) and (global_step_val < stop_model_training)): logging.info( '\n\nPerforming %d steps of model training, each on %d random tasks', model_train_steps_per_iter, num_tasks_per_train) for model_iter in range(model_train_steps_per_iter): temp_start_2 = time.time() # train model total_loss_value_model, _, _ = model_train_step_call() # print("is logging step", model_iter, sess.run(is_logging_step)) if PRINT_TIMING: print("2: single model train step: ", time.time() - temp_start_2) logging.info('Finish model training. Global step: %d\n', global_step_call()) else: print("SKIPPING MODEL TRAINING") #################### # train actor critic #################### if iteration % ac_train_freq == 0: logging.info( '\n\nPerforming %d steps of AC training, each on %d random tasks \n\n', ac_train_steps_per_iter, num_tasks_per_train) for ac_iter in range(ac_train_steps_per_iter): temp_start_2_ac = time.time() # train ac total_loss_value_ac, _ = ac_train_step_call() if PRINT_TIMING: print("2: single AC train step: ", time.time() - temp_start_2_ac) logging.info('Finish AC training. 
Global step: %d\n', global_step_call()) # add up time time_acc += time.time() - start_time #################### # logging/summaries #################### ### Eval if eval_interval and (iteration % eval_interval == 0): logging.info( '\n\nDoing evaluation of trained policy on %d trials with randomly sampled tasks', num_eval_trials) perform_eval_and_summaries_meld( eval_metrics, eval_py_env, eval_py_policy, num_eval_trials, max_episode_len, episodes_per_trial, log_image_strips=log_image_strips, num_trials_to_render= num_trials_to_render, # hardcoded: or gif will get too long eval_tasks=eval_tasks, latent1_size=model_net.latent1_size, latent2_size=model_net.latent2_size, logger=eval_logger, global_step_val=global_step_call(), render_fps=render_fps, decode_rews_op=decode_rews_op, latent_samples_1_ph=latent_samples_1_ph, latent_samples_2_ph=latent_samples_2_ph, log_image_observations=log_image_observations, ) ### steps_per_second_summary global_step_val = global_step_call() if logging_freq_in_iter and (iteration % logging_freq_in_iter == 0): # log step number + speed (steps/sec) logging.info( 'step = %d, loss = %f', global_step_val, total_loss_value_ac.loss + total_loss_value_model.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f env_steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) # reset keeping track of steps/time timed_at_step = global_step_val time_acc = 0 ### train_checkpoint if train_checkpoint_freq_in_iter and ( iteration % train_checkpoint_freq_in_iter == 0): train_checkpointer.save(global_step=global_step_val) ### policy_checkpointer if policy_checkpoint_freq_in_iter and ( iteration % policy_checkpoint_freq_in_iter == 0): policy_checkpointer.save(global_step=global_step_val) ### rb_checkpointer if rb_checkpoint_freq_in_iter and ( iteration % rb_checkpoint_freq_in_iter == 0): for rb_checkpointer in rb_checkpointers: rb_checkpointer.save(global_step=global_step_val)
def train_eval(
    root_dir,
    environment_name="broken_reacher",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    real_initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    real_collect_interval=10,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    classifier_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    train_on_real=False,
    delta_r_warmup=0,
    random_seed=0,
    checkpoint_dir=None,
):
    """Train and evaluate a `DarcAgent` on paired "sim" and "real" environments.

    Experience is collected into two replay buffers (one fed by the training
    environment, one by the "real" environment) and every train step consumes
    one batch from each. Evaluation runs on both the "sim" and the "real"
    variant of the environment.

    Args:
      root_dir: Base output directory. Summaries and checkpoints go to its
        `train` / `eval` sub-directories and the operative gin config is
        written to `operative.gin` inside it.
      environment_name: "broken_reacher", "half_cheetah_obstacle",
        "inverted_pendulum", or a name of the form "broken_joint_<base>" /
        "falling_<base>".
      num_iterations: Number of outer collect/train iterations.
      actor_fc_layers: `fc_layer_params` of the actor network.
      critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
      critic_action_fc_layers: `action_fc_layer_params` of the critic.
      critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
      initial_collect_steps: Random-policy steps collected in the training
        environment when the replay buffer starts out empty.
      real_initial_collect_steps: Same, for the "real" environment.
      collect_steps_per_iteration: Steps per run of either collect driver.
      real_collect_interval: The "real" collect driver only runs when the
        global step is a multiple of this value.
      replay_buffer_capacity: `max_length` of each replay buffer.
      target_update_tau: Forwarded to the agent.
      target_update_period: Forwarded to the agent.
      train_steps_per_iteration: Train steps executed per outer iteration.
      batch_size: Batch size sampled from each replay buffer.
      actor_learning_rate: Adam learning rate for the actor.
      critic_learning_rate: Adam learning rate for the critic.
      classifier_learning_rate: Adam learning rate for the classifier.
      alpha_learning_rate: Adam learning rate for alpha.
      td_errors_loss_fn: Forwarded to the agent.
      gamma: Forwarded to the agent.
      reward_scale_factor: Forwarded to the agent.
      gradient_clipping: Forwarded to the agent.
      use_tf_functions: Whether to wrap the driver `run` methods, the agent's
        `train` and the local train step with `common.function`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        averaging metrics.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Global-step period of train checkpoints.
      policy_checkpoint_interval: Global-step period of policy checkpoints.
      rb_checkpoint_interval: Global-step period of replay-buffer checkpoints.
      log_interval: Global-step period of loss / speed logging.
      summary_interval: Summaries are recorded when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      debug_summaries: Forwarded to the agent.
      summarize_grads_and_vars: Forwarded to the agent.
      train_on_real: If True the training environment is the "real" variant
        instead of "sim".
      delta_r_warmup: Minimum global step before the "real" collect driver is
        allowed to run.
      random_seed: Seed for NumPy and TensorFlow.
      checkpoint_dir: If not None, the train checkpointer is restored from the
        latest checkpoint found in this directory instead of from `train_dir`.

    Returns:
      The value returned by the last executed train step, or None when no
      train step was executed.

    Raises:
      NotImplementedError: If `environment_name` is not recognised.
      AssertionError: If `checkpoint_dir` is given but holds no checkpoint.
    """
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    # Resolve the environment factory; it is later called with "sim"/"real".
    if environment_name == "broken_reacher":
        get_env_fn = darc_envs.get_broken_reacher_env
    elif environment_name == "half_cheetah_obstacle":
        get_env_fn = darc_envs.get_half_cheetah_direction_env
    elif environment_name == "inverted_pendulum":
        get_env_fn = darc_envs.get_inverted_pendulum_env
    elif environment_name.startswith("broken_joint"):
        base_name = environment_name.split("broken_joint_")[1]
        get_env_fn = functools.partial(darc_envs.get_broken_joint_env,
                                       env_name=base_name)
    elif environment_name.startswith("falling"):
        base_name = environment_name.split("falling_")[1]
        get_env_fn = functools.partial(darc_envs.get_falling_env,
                                       env_name=base_name)
    else:
        raise NotImplementedError("Unknown environment: %s" %
                                  environment_name)

    eval_name_list = ["sim", "real"]
    eval_env_list = [get_env_fn(mode) for mode in eval_name_list]
    eval_metrics_list = [
        [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           name="AverageReturn_%s" % name),
        ]
        for name in eval_name_list
    ]

    # Stays None if the training loop never executes a train step.
    train_loss = None

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env_real = get_env_fn("real")
        if train_on_real:
            tf_env = get_env_fn("real")
        else:
            tf_env = get_env_fn("sim")

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork),
        )
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer="glorot_uniform",
            last_kernel_initializer="glorot_uniform",
        )
        classifier = classifiers.build_classifier(observation_spec,
                                                  action_spec)

        tf_agent = darc_agent.DarcAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            classifier=classifier,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=classifier_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffers: one for the training env, one for "real".
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        replay_observer = [replay_buffer.add_batch]

        real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        real_replay_observer = [real_replay_buffer.add_batch]

        sim_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnSim",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthSim",
            ),
        ]
        real_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnReal",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthReal",
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(
                sim_train_metrics + real_train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "replay_buffer"),
            max_to_keep=1,
            replay_buffer=(replay_buffer, real_replay_buffer),
        )

        if checkpoint_dir is not None:
            # Restore from an explicitly supplied directory rather than from
            # the checkpointer's own `ckpt_dir`.
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            assert checkpoint_path is not None
            train_checkpointer._load_status = train_checkpointer._checkpoint.restore(  # pylint: disable=protected-access
                checkpoint_path)
            train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
        else:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + sim_train_metrics,
            num_steps=collect_steps_per_iteration,
        )
        real_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env_real,
            collect_policy,
            observers=real_replay_observer + real_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        if use_tf_functions:
            collect_driver.run = common.function(collect_driver.run)
            real_collect_driver.run = common.function(real_collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data. The initial drivers are created,
        # wrapped and run inside this single branch so that resuming from a
        # non-empty replay buffer never touches an unbound driver.
        if replay_buffer.num_frames() == 0:
            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + sim_train_metrics,
                num_steps=initial_collect_steps,
            )
            real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env_real,
                initial_collect_policy,
                observers=real_replay_observer + real_train_metrics,
                num_steps=real_initial_collect_steps,
            )
            if use_tf_functions:
                initial_collect_driver.run = common.function(
                    initial_collect_driver.run)
                real_initial_collect_driver.run = common.function(
                    real_initial_collect_driver.run)
            logging.info(
                "Initializing replay buffer by collecting experience for %d steps with "
                "a random policy.",
                initial_collect_steps,
            )
            initial_collect_driver.run()
            real_initial_collect_driver.run()

        # Evaluate once before training starts.
        for eval_name, eval_env, eval_metrics in zip(eval_name_list,
                                                     eval_env_list,
                                                     eval_metrics_list):
            metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix="Metrics-%s" % eval_name,
            )
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        real_time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = (replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=2).unbatch().filter(_filter_invalid_transition).batch(
                batch_size).prefetch(5))
        real_dataset = (real_replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=2).unbatch().filter(_filter_invalid_transition).batch(
                batch_size).prefetch(5))

        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)
        real_iterator = iter(real_dataset)

        def train_step():
            experience, _ = next(iterator)
            real_experience, _ = next(real_iterator)
            return tf_agent.train(experience,
                                  real_experience=real_experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            assert not policy_state  # We expect policy_state == ().
            if (global_step.numpy() % real_collect_interval == 0
                    and global_step.numpy() >= delta_r_warmup):
                real_time_step, policy_state = real_collect_driver.run(
                    time_step=real_time_step,
                    policy_state=policy_state,
                )

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                # No loss is available when no train step has run yet.
                if train_loss is not None:
                    logging.info("step = %d, loss = %f", global_step_val,
                                 train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info("%.3f steps/sec", steps_per_sec)
                tf.compat.v2.summary.scalar(name="global_steps_per_sec",
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in sim_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=sim_train_metrics[:2])
            for train_metric in real_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=real_train_metrics[:2])

            if global_step_val % eval_interval == 0:
                for eval_name, eval_env, eval_metrics in zip(
                        eval_name_list, eval_env_list, eval_metrics_list):
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix="Metrics-%s" % eval_name,
                    )
                    metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

    return train_loss
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_allowlist='position',
    eval_env_name=None,
    num_iterations=1000000,
    # Params for networks.
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100, ),
    actor_lstm_size=(40, ),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300, ),
    critic_output_fc_layers=(100, ),
    critic_lstm_size=(40, ),
    num_parallel_environments=1,
    # Params for collect
    initial_collect_episodes=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    critic_learning_rate=3e-4,
    train_sequence_length=20,
    actor_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control.

    Builds RNN actor/critic networks and a `SacAgent`, collects whole
    episodes into a uniform replay buffer and trains on sampled sequences.
    Summaries are indexed by the number of environment steps collected.

    Args:
      root_dir: Base output directory for summaries; checkpoints go to its
        `train`, `policy` and `replay_buffer` sub-directories.
      env_name: DM control environment name used for training.
      task_name: DM control task name.
      observations_allowlist: If not None, the single observation key kept by
        `FlattenObservationsWrapper`.
      eval_env_name: Environment name used for evaluation; defaults to
        `env_name`.
      num_iterations: Number of outer collect/train iterations.
      actor_fc_layers: `input_fc_layer_params` of the actor network.
      actor_output_fc_layers: `output_fc_layer_params` of the actor network.
      actor_lstm_size: `lstm_size` of the actor network.
      critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
      critic_action_fc_layers: `action_fc_layer_params` of the critic.
      critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
      critic_output_fc_layers: `output_fc_layer_params` of the critic.
      critic_lstm_size: `lstm_size` of the critic.
      num_parallel_environments: With 1 a single environment is loaded;
        otherwise a `ParallelPyEnvironment` with that many copies is used.
      initial_collect_episodes: Random-policy episodes collected when no data
        is available yet.
      collect_episodes_per_iteration: Episodes collected per outer iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Forwarded to the agent.
      target_update_period: Forwarded to the agent.
      train_steps_per_iteration: Train steps executed per collected
        environment step.
      batch_size: Number of sequences per training batch.
      critic_learning_rate: Adam learning rate for the critic.
      train_sequence_length: Sampled sequences contain
        `train_sequence_length + 1` steps.
      actor_learning_rate: Adam learning rate for the actor.
      alpha_learning_rate: Adam learning rate for alpha.
      td_errors_loss_fn: Forwarded to the agent.
      gamma: Forwarded to the agent.
      reward_scale_factor: Forwarded to the agent.
      gradient_clipping: Forwarded to the agent.
      use_tf_functions: Whether to wrap the driver `run` methods, the agent's
        `train` and the local train step with `common.function`.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        averaging metrics.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Global-step period of train checkpoints.
      policy_checkpoint_interval: Global-step period of policy checkpoints.
      rb_checkpoint_interval: Global-step period of replay-buffer checkpoints.
      log_interval: Global-step period of return / speed logging.
      summary_interval: Summaries are recorded when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of the summary writer, in seconds.
      debug_summaries: Forwarded to the agent.
      summarize_grads_and_vars: Forwarded to the agent.
      eval_metrics_callback: Optional callable invoked after every evaluation
        with `(results, env_steps.result())`.
    """
    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_allowlist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_allowlist=[observations_allowlist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d episodes '
                'with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        # Evaluate once before training starts.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Reduce filter_fn over full trajectory sampled. The sequence is
            # kept only if all elements except for the last one pass the
            # filter. This is to allow training on terminal steps.
            return tf.reduce_all(~trajectories.is_boundary()[:-1])

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape
        # [B x (train_sequence_length + 1) x ...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            # TODO(b/152648849)
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                # Advance the reference timestamp after every measurement so
                # that an interval is never added to `time_acc` twice.
                now = time.time()
                time_acc += now - start_time
                start_time = now

                global_step_val = global_step.numpy()

                if global_step_val % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step_val % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        # `env_steps` is a metric object, not a tensor: read
                        # its value through `result()`, exactly as the
                        # pre-training evaluation does.
                        eval_metrics_callback(results, env_steps.result())
                    metric_utils.log_metrics(eval_metrics)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        exploration_noise_std=0.1,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        actor_update_period=2,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        eval_metrics_callback=None):
    """A simple train and eval for TD3 with RNN actor and critic networks.

    Builds the whole pipeline as a TF1-style graph (DM Control environment,
    `ActorRnnNetwork` / `CriticRnnNetwork`, `Td3Agent`, uniform replay buffer,
    episode drivers, checkpointers) and then drives it from a
    `tf.compat.v1.Session`. Each iteration collects episodes, runs
    `train_steps_per_iteration` train steps, and then checks the current
    global step value against the logging / checkpoint / eval intervals.

    Because the interval checks run once per iteration, an interval only
    fires when the global step value is an exact multiple of it at the end of
    an iteration.

    Args:
        root_dir: Output directory (`~` is expanded). Train summaries and
            checkpoints are written under `<root_dir>/train`, eval summaries
            under `<root_dir>/eval`.
        env_name: Environment (domain) name given to `suite_dm_control.load`.
        task_name: Task name given to `suite_dm_control.load`.
        observations_whitelist: A single observation key; it is wrapped in a
            list and passed to `FlattenObservationsWrapper`. `None` means no
            environment wrapper is applied.
        num_iterations: Number of collect-then-train iterations.
        actor_fc_layers: `input_fc_layer_params` of the actor network.
        actor_output_fc_layers: `output_fc_layer_params` of the actor network.
        actor_lstm_size: `lstm_size` of the actor network.
        critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
        critic_action_fc_layers: `action_fc_layer_params` of the critic.
        critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
        critic_output_fc_layers: `output_fc_layer_params` of the critic.
        critic_lstm_size: `lstm_size` of the critic network.
        initial_collect_steps: Number of *episodes* (despite the name)
            collected once before training; forwarded as `num_episodes` to
            the episode driver.
        collect_episodes_per_iteration: Episodes collected per iteration.
        replay_buffer_capacity: `max_length` of the replay buffer.
        exploration_noise_std: Forwarded to `Td3Agent`.
        target_update_tau: Forwarded to `Td3Agent`.
        target_update_period: Forwarded to `Td3Agent`.
        train_steps_per_iteration: Train-op executions per iteration.
        batch_size: `sample_batch_size` of the replay dataset.
        actor_update_period: Forwarded to `Td3Agent`.
        train_sequence_length: Trained sequence length; the dataset samples
            `train_sequence_length + 1` steps per trajectory.
        actor_learning_rate: Learning rate of the actor's Adam optimizer.
        critic_learning_rate: Learning rate of the critic's Adam optimizer.
        dqda_clipping: Forwarded to `Td3Agent`.
        gamma: Forwarded to `Td3Agent`.
        reward_scale_factor: Forwarded to `Td3Agent`.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the eval metrics.
        eval_interval: Evaluate when the global step is a multiple of this.
        train_checkpoint_interval: Save the train checkpoint when the global
            step is a multiple of this.
        policy_checkpoint_interval: Same, for the policy checkpoint.
        rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
        log_interval: Log loss and steps/sec when the global step is a
            multiple of this.
        summary_interval: Summaries are recorded only while
            `global_step % summary_interval == 0`.
        summaries_flush_secs: Flush period of both summary writers, seconds.
        debug_summaries: Forwarded to `Td3Agent`.
        eval_metrics_callback: Forwarded as `callback` to
            `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Create the global step *before* entering `record_if`: the condition
    # lambda closes over it, so it must be bound by the time any summary op
    # evaluates the condition.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []

        environment = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(environment)
        # Separate, un-wrapped-in-TF copy of the environment for evaluation.
        eval_py_env = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            exploration_noise_std=exploration_noise_std,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            actor_update_period=actor_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            debug_summaries=debug_summaries,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        # NOTE: `initial_collect_steps` is an episode count here.
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run(policy_state=policy_state)

        # Fresh initial state for the per-iteration collect op.
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run(
                policy_state=policy_state)

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()

        train_fn = common.function(tf_agent.train)
        train_op = train_fn(experience=trajectories)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        # The first two train metrics (episodes, env steps) double as the
        # step metrics against which the others are also plotted.
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval summaries are always recorded, regardless of summary_interval.
        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # Catch any variables the checkpointers above neither restored
            # nor initialized, so no later sess.run reads an uninitialized
            # variable.
            common.initialize_uninitialized_variables(sess)
            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            # Baseline evaluation before any training step of this run.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            # Pre-built callables avoid re-resolving fetches on every call.
            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    # Restart the throughput measurement window.
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        agent_class=dqn_agent.DqnAgent,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    The graph (Gym environment, Q-network, agent, replay buffer, step
    drivers, checkpointers and summaries) is assembled first and then run
    inside a `tf.compat.v1.Session`. Every iteration executes one collect op
    and `train_steps_per_iteration` train ops, after which the global step
    value is tested against the log / checkpoint / eval intervals.

    Args:
        root_dir: Output directory (`~` is expanded); `train` and `eval`
            sub-directories receive summaries, `train` also checkpoints.
        env_name: Name handed to `suite_gym.load`.
        num_iterations: Number of collect-then-train iterations.
        fc_layer_params: `fc_layer_params` of the `QNetwork`.
        initial_collect_steps: Steps collected once, before training, with a
            `RandomTFPolicy`.
        collect_steps_per_iteration: Steps collected per iteration with the
            agent's collect policy.
        epsilon_greedy: Forwarded to `agent_class`.
        replay_buffer_capacity: `max_length` of the replay buffer.
        target_update_tau: Forwarded to `agent_class`.
        target_update_period: Forwarded to `agent_class`.
        train_steps_per_iteration: Train-op executions per iteration.
        batch_size: `sample_batch_size` of the replay dataset.
        learning_rate: Learning rate of the Adam optimizer.
        gamma: Forwarded to `agent_class`.
        reward_scale_factor: Forwarded to `agent_class`.
        gradient_clipping: Forwarded to `agent_class`.
        num_eval_episodes: Episodes per evaluation; also the eval metric
            buffer size.
        eval_interval: Evaluate when the global step is a multiple of this.
        train_checkpoint_interval: Interval for the train checkpoint.
        policy_checkpoint_interval: Interval for the policy checkpoint.
        rb_checkpoint_interval: Interval for the replay-buffer checkpoint.
        log_interval: Interval for loss / throughput logging.
        summary_interval: Summaries are recorded only while
            `global_step % summary_interval == 0`.
        summaries_flush_secs: Flush period of both summary writers, seconds.
        agent_class: Agent constructor; called with the DQN-style keyword
            arguments shown in the body.
        debug_summaries: Forwarded to `agent_class`.
        summarize_grads_and_vars: Forwarded to `agent_class`.
        eval_metrics_callback: Forwarded as `callback` to
            `metric_utils.compute_summaries`.
    """
    base_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(base_dir, 'train')
    eval_dir = os.path.join(base_dir, 'eval')
    flush_millis = summaries_flush_secs * 1000

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=flush_millis)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=flush_millis)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    def _should_record():
        """True on global steps that are multiples of summary_interval."""
        return tf.math.equal(global_step % summary_interval, 0)

    with tf.compat.v2.summary.record_if(_should_record):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        time_step_spec = tf_env.time_step_spec()
        q_function_net = q_network.QNetwork(
            time_step_spec.observation,
            tf_env.action_spec(),
            fc_layer_params=fc_layer_params)

        # TODO(b/127301657): Decay epsilon based on global step, cf.
        # cl/188907839
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)
        agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_function_net,
            optimizer=optimizer,
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        buffer_observers = [replay_buffer.add_batch]

        def _step_collect_op(policy, num_steps):
            """Builds the op running `num_steps` env steps with `policy`."""
            driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                policy,
                observers=buffer_observers + train_metrics,
                num_steps=num_steps)
            return driver.run()

        # Seed the buffer with uniformly random actions.
        random_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = _step_collect_op(random_policy,
                                              initial_collect_steps)

        collect_op = _step_collect_op(agent.collect_policy,
                                      collect_steps_per_iteration)

        # Dataset generates trajectories with shape [Bx2x...]
        replay_dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2)
        replay_dataset = replay_dataset.prefetch(3)

        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            replay_dataset)
        sampled_batch, _ = dataset_iterator.get_next()
        train_fn = common.function(agent.train)
        train_op = train_fn(experience=sampled_batch)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        # Episodes and env steps (the first two metrics) act as step metrics.
        summary_ops = [
            metric.tf_summaries(train_step=global_step,
                                step_metrics=train_metrics[:2])
            for metric in train_metrics
        ]

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for metric in eval_metrics:
                metric.tf_summaries(train_step=global_step)

        init_agent_op = agent.initialize()

        def _evaluate(step_value, **extra_kwargs):
            """Runs the eval policy and writes the eval metric summaries."""
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=step_value,
                callback=eval_metrics_callback,
                **extra_kwargs)

        # Checked in this order at the end of every iteration.
        periodic_savers = (
            (train_checkpoint_interval, train_checkpointer),
            (policy_checkpoint_interval, policy_checkpointer),
            (rb_checkpoint_interval, rb_checkpointer),
        )

        with tf.compat.v1.Session() as session:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(session)
            rb_checkpointer.initialize_or_restore(session)
            session.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(session)

            session.run(init_agent_op)
            session.run(train_summary_writer.init())
            session.run(eval_summary_writer.init())
            session.run(initial_collect_op)

            global_step_val = session.run(global_step)
            _evaluate(global_step_val, log=True)

            collect_call = session.make_callable(collect_op)
            global_step_call = session.make_callable(global_step)
            train_step_call = session.make_callable([train_op, summary_ops])

            timed_at_step = global_step_call()
            collect_time = 0
            train_time = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                # Train/collect/eval.
                collect_started = time.time()
                collect_call()
                collect_time += time.time() - collect_started

                train_started = time.time()
                for _ in range(train_steps_per_iteration):
                    last_loss_info, _ = train_step_call()
                train_time += time.time() - train_started

                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 last_loss_info.loss)
                    elapsed = collect_time + train_time
                    steps_per_sec = (
                        (global_step_val - timed_at_step) / elapsed)
                    session.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info('%.3f steps/sec', steps_per_sec)
                    logging.info(
                        '%s',
                        'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    # Start a new measurement window.
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                for interval, checkpointer in periodic_savers:
                    if global_step_val % interval == 0:
                        checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    _evaluate(global_step_val)