def __init__(self,
             root_dir,
             train_step,
             agent,
             experience_dataset_fn=None,
             after_train_strategy_step_fn=None,
             triggers=None,
             checkpoint_interval=100000,
             summary_interval=1000,
             max_checkpoints_to_keep=3,
             use_kwargs_in_agent_train=False,
             strategy=None):
  """Initializes a Learner instance.

  Args:
    root_dir: Main directory path where checkpoints, saved_models, and
      summaries will be written to.
    train_step: a scalar tf.int64 `tf.Variable` which will keep track of the
      number of train steps. This is used for artifacts created like
      summaries, or outputs in the root_dir.
    agent: `tf_agent.TFAgent` instance to train with.
    experience_dataset_fn: a function that will create an instance of a
      tf.data.Dataset used to sample experience for training. Required for
      using the Learner as is. Optional for subclass learners which take a new
      iterator each time when `learner.run` is called.
    after_train_strategy_step_fn: (Optional) callable of the form
      `fn(sample, loss)` which can be used for example to update priorities in
      a replay buffer where sample is pulled from the `experience_iterator`
      and loss is a `LossInfo` named tuple returned from the agent. This is
      called after every train step. It runs using `strategy.run(...)`.
    triggers: List of callables of the form `trigger(train_step)`. After every
      `run` call every trigger is called with the current `train_step` value
      as an np scalar.
    checkpoint_interval: Number of train steps in between checkpoints. Note
      these are placed into triggers and so a check to generate a checkpoint
      only occurs after every `run` call. Set to -1 to disable (this is not
      recommended, because it means that if the pipeline gets preempted, all
      previous progress is lost). This only takes care of the checkpointing
      the training process. Policies must be explicitly exported through
      triggers.
    summary_interval: Number of train steps in between summaries. Note these
      are placed into triggers and so a check to generate a checkpoint only
      occurs after every `run` call.
    max_checkpoints_to_keep: Maximum number of checkpoints to keep around.
      These are used to recover from pre-emptions when training.
    use_kwargs_in_agent_train: If True the experience from the replay buffer
      is passed into the agent as kwargs. This requires samples from the RB to
      be of the form `dict(experience=experience, kwarg1=kwarg1, ...)`. This
      is useful if you have an agent with a custom argspec.
    strategy: (Optional) `tf.distribute.Strategy` to use during training.
  """
  if checkpoint_interval < 0:
    logging.warning(
        'Warning: checkpointing the training process is manually disabled.'
        'This means training progress will NOT be automatically restored '
        'if the job gets preempted.')

  # All training artifacts (checkpoints, summaries) live under root_dir/train.
  self._train_dir = os.path.join(root_dir, TRAIN_DIR)
  self.train_summary_writer = tf.compat.v2.summary.create_file_writer(
      self._train_dir, flush_millis=10000)

  self.train_step = train_step
  self._agent = agent
  self.use_kwargs_in_agent_train = use_kwargs_in_agent_train
  # Fall back to the strategy active in the current scope; when none was
  # entered this returns the default (no-op) strategy.
  self.strategy = strategy or tf.distribute.get_strategy()

  if experience_dataset_fn:
    with self.strategy.scope():
      # Build a per-replica distributed dataset and keep one iterator that
      # the training loop will draw samples from.
      dataset = self.strategy.experimental_distribute_datasets_from_function(
          lambda _: experience_dataset_fn())
      self._experience_iterator = iter(dataset)

  self.after_train_strategy_step_fn = after_train_strategy_step_fn
  self.triggers = triggers or []

  # Prevent autograph from going into the agent.
  self._agent.train = tf.autograph.experimental.do_not_convert(
      agent.train)

  checkpoint_dir = os.path.join(self._train_dir, POLICY_CHECKPOINT_DIR)
  with self.strategy.scope():
    # Variable creation (agent networks, checkpoint bookkeeping) must happen
    # under the strategy scope so variables are mirrored correctly.
    agent.initialize()

    self._checkpointer = common.Checkpointer(
        checkpoint_dir,
        max_to_keep=max_checkpoints_to_keep,
        agent=self._agent,
        train_step=self.train_step)
    self._checkpointer.initialize_or_restore()  # pytype: disable=attribute-error

  # Checkpointing is implemented as just another trigger on train_step, so it
  # is only evaluated after each `run` call (see checkpoint_interval docs).
  self.triggers.append(self._get_checkpoint_trigger(checkpoint_interval))
  self.summary_interval = tf.constant(summary_interval, dtype=tf.int64)
def train_eval(
    root_dir,
    env_name='MaskedCartPole-v0',
    num_iterations=100000,
    input_fc_layer_params=(50,),
    lstm_size=(20,),
    output_fc_layer_params=(20,),
    train_sequence_length=10,
    # Params for collect
    initial_collect_steps=50,
    collect_episodes_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=10,
    batch_size=128,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=100,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN with an RNN Q-network.

  Builds a graph-mode (tf.compat.v1.Session) training pipeline: collects
  episodes into a uniform replay buffer with an epsilon-greedy policy, trains
  a `DqnAgent` on sampled sub-sequences of length `train_sequence_length + 1`,
  and periodically checkpoints, logs, and evaluates.

  Args:
    root_dir: Directory where train/eval summaries and checkpoints are written.
    env_name: Gym environment id to load via `suite_gym`.
    num_iterations: Number of collect/train iterations to run.
    input_fc_layer_params/lstm_size/output_fc_layer_params: QRnnNetwork shape.
    train_sequence_length: Length of the trajectory slices used for training
      (one extra step is sampled to form transitions).
    initial_collect_steps: Number of initial episodes collected with a random
      policy before training starts (passed as `num_episodes` to the driver).
    collect_episodes_per_iteration: Episodes collected per iteration.
    epsilon_greedy: Exploration epsilon for the collect policy.
    replay_buffer_capacity: Max frames kept in the replay buffer.
    target_update_tau/target_update_period: Target network update params.
    train_steps_per_iteration: Gradient steps per iteration.
    batch_size: Training batch size.
    learning_rate: Adam learning rate.
    gamma: Discount factor.
    reward_scale_factor: Multiplier applied to rewards.
    gradient_clipping: Optional gradient-norm clip.
    num_eval_episodes/eval_interval: Evaluation schedule.
    train_checkpoint_interval/policy_checkpoint_interval/rb_checkpoint_interval:
      Checkpointing schedules (in global steps).
    log_interval/summary_interval/summaries_flush_secs: Logging schedules.
    debug_summaries/summarize_grads_and_vars: Extra summary switches.
    eval_metrics_callback: Optional callable invoked with eval results.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Only record summaries every `summary_interval` global steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    eval_py_env = suite_gym.load(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    q_net = q_rnn_network.QRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=input_fc_layer_params,
        lstm_size=lstm_size,
        output_fc_layer_params=output_fc_layer_params)

    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_steps).run()

    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    # Need extra step to generate transitions of train_sequence_length.
    # Dataset generates trajectories with shape [BxTx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    experience, _ = iterator.get_next()
    loss_info = tf_agent.train(experience=experience)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(sguada) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      logging.info('Collecting initial experience.')
      sess.run(initial_collect_op)

      # Compute evaluation metrics.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable(loss_info)
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      # FIX: was `tf.contrib.summary.scalar(...)` — tf.contrib does not exist
      # in TF 2.x. Use the tf.compat.v2 summary API, consistent with the rest
      # of this file.
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps/sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        # Train/collect/eval.
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              log=True,
              callback=eval_metrics_callback,
          )
def train():
  """Trains a DQN agent with an RNN Q-network on `TradeEnvironment`.

  Graph-mode (tf.compat.v1.Session) script: pre-fills a replay buffer with a
  dummy trade policy, then alternates episode collection (via a filtered
  epsilon-greedy policy) with training steps, periodic checkpoints, logging
  and evaluation. All outputs go under /tmp/tensorflow/logs/tfenv01.
  """
  summary_interval = 1000
  summaries_flush_secs = 10
  num_eval_episodes = 5
  # FIX: the following five names were used below but never defined,
  # causing NameErrors at runtime.
  train_steps_per_iteration = 1
  train_checkpoint_interval = 10000
  policy_checkpoint_interval = 5000
  rb_checkpoint_interval = 20000
  eval_metrics_callback = None

  root_dir = '/tmp/tensorflow/logs/tfenv01'
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs*1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs*1000)

  # maybe py_metrics?
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  environment = TradeEnvironment()
  # utils.validate_py_environment(environment, episodes=5)

  # Environments
  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    train_env = tf_py_environment.TFPyEnvironment(environment)
    eval_env = tf_py_environment.TFPyEnvironment(environment)

    # Hyperparameters.
    num_iterations = 50
    input_fc_layer_params = (50, )
    output_fc_layer_params = (20, )
    lstm_size = (30, )
    initial_collect_steps = 20
    collect_episodes_per_iteration = 1
    batch_size = 64
    replay_buffer_capacity = 10000
    train_sequence_length = 10
    gamma = 0.99  # check if 1.0 works as well
    target_update_tau = 0.05
    target_update_period = 5
    epsilon_greedy = 0.1
    gradient_clipping = None
    reward_scale_factor = 1.0
    learning_rate = 1e-2
    log_interval = 30
    eval_interval = 15

    # train_env.observation_spec(),
    q_net = q_rnn_network.QRnnNetwork(
        train_env.time_step_spec().observation,
        train_env.action_spec(),
        input_fc_layer_params=input_fc_layer_params,
        lstm_size=lstm_size,
        output_fc_layer_params=output_fc_layer_params,
    )

    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate)

    tf_agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        train_step_counter=global_step,
    )

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_capacity,
    )

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Policy which does not allow some actions in certain states
    q_policy = FilteredQPolicy(
        tf_agent._time_step_spec,
        tf_agent._action_spec,
        q_network=tf_agent._q_network,
    )

    # Valid policy to pre-fill replay buffer
    initial_collect_policy = DummyTradePolicy(
        train_env.time_step_spec(),
        train_env.action_spec(),
    )

    print('Initial collecting...')
    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        train_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_steps,
    ).run()

    # Main agent's policy; greedy one
    policy = greedy_policy.GreedyPolicy(q_policy)
    # Policy used for evaluation, the same as above
    eval_policy = greedy_policy.GreedyPolicy(q_policy)
    tf_agent._policy = policy

    collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
        q_policy, epsilon=tf_agent._epsilon_greedy)
    # Patch random policy for epsilon greedy collect policy
    filtered_random_tf_policy = FilteredRandomTFPolicy(
        time_step_spec=policy.time_step_spec,
        action_spec=policy.action_spec,
    )
    collect_policy._random_policy = filtered_random_tf_policy
    tf_agent._collect_policy = collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        train_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration,
    ).run()

    # One extra step to form transitions of train_sequence_length.
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length+1,
    ).prefetch(3)

    # NOTE(review): eager-style iteration mixed with a v1 Session below;
    # presumably runs under TF2 behavior — confirm against the DQN
    # train_eval in this file, which uses make_initializable_iterator.
    iterator = iter(dataset)
    experience, _ = next(iterator)
    loss_info = common.function(tf_agent.train)(experience=experience)

    # Checkpoints
    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
    )
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step,
    )
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer,
    )

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step,
          step_metrics=train_metrics[:2],
      ))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # sess.run(train_summary_writer.init())
      # sess.run(eval_summary_writer.init())
      # Initialize the graph
      # tfe.Saver().restore()
      # train_checkpointer.initialize_or_restore()
      # rb_checkpointer.initialize_or_restore()
      # sess.run(iterator.initializer)
      common.initialize_uninitialized_variables(sess)
      sess.run(init_agent_op)

      print('Collecting initial experience...')
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([loss_info, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step,
      )

      # Train
      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        time_acc += time.time() - start_time

        global_step_val = global_step_call()
        # FIX: was `log_inerval` (undefined — NameError).
        if global_step_val % log_interval == 0:
          # FIX: print() does not %-format its arguments the way
          # logging.info does; format explicitly.
          print('step=%d, loss=%f'
                % (global_step_val, loss_info_value.loss))
          steps_per_sec = (global_step_val-timed_at_step) / time_acc
          print('%.3f steps/sec' % steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec},
          )
          timed_at_step = global_step_val
          time_acc = 0

        # Save checkpoints
        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)
        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        # Evaluate
        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              log=True,
              callback=eval_metrics_callback,
          )

  print('Done!')
def train(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    env_name=None,
    num_parallel_environments=1,  # pylint: disable=unused-argument
    agent_class=None,
    initial_collect_random=True,  # pylint: disable=unused-argument
    initial_collect_driver_class=None,
    collect_driver_class=None,
    num_global_steps=1000000,
    train_steps_per_iteration=1,
    train_metrics=None,
    # Safety Critic training args
    train_sc_steps=10,
    train_sc_interval=300,
    online_critic=False,
    # Params for eval
    run_eval=False,
    num_eval_episodes=30,
    eval_interval=1000,
    eval_metrics_callback=None,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    env_metric_factories=None):  # pylint: disable=unused-argument
  """A simple train and eval for SC-SAC.

  Eager-mode training loop for a safety-constrained SAC agent: collects with
  either the agent's collect policy or its safe policy (when `online_critic`),
  trains the agent each iteration, periodically trains the safety critic from
  a separate online replay buffer, checkpoints, and optionally evaluates.

  Returns:
    The last mean training loss.

  Raises:
    ValueError: if the training loss diverged (NaN/Inf/> MAX_LOSS) for more
      than TERMINATE_AFTER_DIVERGED_LOSS_STEPS consecutive iterations.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  train_metrics = train_metrics or []

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes),
    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = env_load_fn(env_name)
    if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
      tf_env = tf_py_environment.TFPyEnvironment(tf_env)

    if run_eval:
      eval_py_env = env_load_fn(env_name)
      eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    print('obs spec:', observation_spec)
    print('action spec:', action_spec)

    if online_critic:
      # FIX: was `tf_py_metric.TfPyMetric` — wrong capitalization; the class
      # is `TFPyMetric` (used elsewhere in this function), so this branch
      # raised AttributeError whenever online_critic=True.
      resample_metric = tf_py_metric.TFPyMetric(
          py_metrics.CounterMetric('unsafe_ac_samples'))
      tf_agent = agent_class(time_step_spec,
                             action_spec,
                             train_step_counter=global_step,
                             resample_metric=resample_metric)
    else:
      tf_agent = agent_class(time_step_spec,
                             action_spec,
                             train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    collect_data_spec = tf_agent.collect_data_spec

    logging.info('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, max_length=1000000)
    logging.info('RB capacity: %i', replay_buffer.capacity)
    logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

    agent_observers = [replay_buffer.add_batch]
    if online_critic:
      # Separate, small buffer used only for safety-critic training; it is
      # cleared after each critic training round.
      online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          collect_data_spec, max_length=10000)
      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
          ckpt_dir=online_rb_ckpt_dir,
          max_to_keep=1,
          replay_buffer=online_replay_buffer)
      clear_rb = common.function(online_replay_buffer.clear)
      agent_observers.append(online_replay_buffer.add_batch)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                       batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if not online_critic:
      eval_policy = tf_agent.policy
    else:
      eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        time_step_spec, action_spec)
    if not online_critic:
      collect_policy = tf_agent.collect_policy
    else:
      collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=eval_policy,
                                              global_step=global_step)
    safety_critic_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'safety_critic'),
        safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
        global_step=global_step)
    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                          max_to_keep=1,
                                          replay_buffer=replay_buffer)

    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

    if load_root_dir is None:
      train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()
    safety_critic_checkpointer.initialize_or_restore()

    collect_driver = collect_driver_class(tf_env,
                                          collect_policy,
                                          observers=agent_observers +
                                          train_metrics)

    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

    if not rb_checkpointer.checkpoint_exists:
      logging.info('Performing initial collection ...')
      common.function(
          initial_collect_driver_class(tf_env,
                                       initial_collect_policy,
                                       observers=agent_observers +
                                       train_metrics).run)()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      tf.print(
          replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
          output_stream=logging.info)

    if run_eval:
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='Metrics',
      )
      if eval_metrics_callback is not None:
        eval_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)
      if FLAGS.viz_pm:
        eval_fig_dir = osp.join(eval_dir, 'figs')
        if not tf.io.gfile.isdir(eval_fig_dir):
          tf.io.gfile.makedirs(eval_fig_dir)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       num_steps=2).prefetch(3)
    iterator = iter(dataset)
    if online_critic:
      online_dataset = online_replay_buffer.as_dataset(
          num_parallel_calls=3, num_steps=2).prefetch(3)
      online_iterator = iter(online_dataset)

      @common.function
      def critic_train_step():
        """Builds critic training step."""
        experience, buf_info = next(online_iterator)
        if env_name in [
            'IndianWell', 'IndianWell2', 'IndianWell3', 'DrunkSpider',
            'DrunkSpiderShort'
        ]:
          safe_rew = experience.observation['task_agn_rew']
        else:
          safe_rew = agents.process_replay_buffer(online_replay_buffer,
                                                  as_tensor=True)
          safe_rew = tf.gather(safe_rew, tf.squeeze(buf_info.ids), axis=1)
        ret = tf_agent.train_sc(experience, safe_rew)
        clear_rb()
        return ret

    @common.function
    def train_step():
      experience, _ = next(iterator)
      ret = tf_agent.train(experience)
      return ret

    if not early_termination_fn:
      early_termination_fn = lambda: False

    loss_diverged = False
    # How many consecutive steps was loss diverged for.
    loss_divergence_counter = 0
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
    if online_critic:
      mean_resample_ac = tf.keras.metrics.Mean(
          name='mean_unsafe_ac_samples')
      resample_metric.reset()

    while (global_step.numpy() <= num_global_steps
           and not early_termination_fn()):
      # Collect and train.
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )

      if online_critic:
        mean_resample_ac(resample_metric.result())
        resample_metric.reset()
        if time_step.is_last():
          resample_ac_freq = mean_resample_ac.result()
          mean_resample_ac.reset_states()
          tf.compat.v2.summary.scalar(name='unsafe_ac_samples',
                                      data=resample_ac_freq,
                                      step=global_step)

      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)

      if online_critic:
        if global_step.numpy() % train_sc_interval == 0:
          for _ in range(train_sc_steps):
            sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()

      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss)
          or total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     total_loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                    data=steps_per_sec,
                                    step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(train_step=global_step,
                                  step_metrics=train_metrics[:2])

      global_step_val = global_step.numpy()

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
        safety_critic_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        if online_critic:
          online_rb_checkpointer.save(global_step=global_step_val)
        rb_checkpointer.save(global_step=global_step_val)

      if run_eval and global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)
        if FLAGS.viz_pm:
          savepath = 'step{}.png'.format(global_step_val)
          savepath = osp.join(eval_fig_dir, savepath)
          misc.record_episode_vis_summary(eval_tf_env, eval_policy,
                                          savepath)

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end after the cleanup.
    raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
        total_loss, global_step.numpy()))

  return total_loss
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    eval_env_name=None,
    env_load_fn=suite_mujoco.load,
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400, ),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300, ),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_parallel_environments=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DDPG.

  Builds a TF1 graph-mode pipeline: constructs the environment(s), a
  `DdpgAgent`, a uniform replay buffer and collect/train ops, then runs
  `num_iterations` of collect + train inside a single `tf.compat.v1.Session`,
  periodically logging, checkpointing, and evaluating.

  Args:
    root_dir: Directory under which 'train' and 'eval' summaries and
      checkpoints are written.
    env_name: Name of the training environment passed to `env_load_fn`.
    eval_env_name: Optional separate eval environment name; defaults to
      `env_name`.
    env_load_fn: Callable loading a py environment by name.
    num_iterations: Number of outer collect/train iterations.
    eval_metrics_callback: Optional callable invoked with (results, step)
      after each evaluation.
    (Remaining args are standard DDPG / pipeline hyperparameters, grouped by
    the inline comments above.)
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  # Eval uses py_metrics (python-side) because evaluation runs through a
  # PyTFPolicy on the python environment below.
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Only record summaries every `summary_interval` steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if num_parallel_environments > 1:
      tf_env = tf_py_environment.TFPyEnvironment(
          parallel_py_environment.ParallelPyEnvironment(
              [lambda: env_load_fn(env_name)] * num_parallel_environments))
    else:
      tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    eval_env_name = eval_env_name or env_name
    eval_py_env = env_load_fn(eval_env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy

    # Graph ops for the initial (seed) collection and the per-iteration
    # collection; both feed the replay buffer and the train metrics.
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = iterator.get_next()

    train_fn = common.function(tf_agent.train)
    train_op = train_fn(experience=trajectories)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(train_step=global_step,
                                    step_metrics=train_metrics[:2]))

    # Eval summaries are always recorded (record_if(True)), independent of
    # the summary_interval gating above.
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(b/126239733) Remove once Periodically can be saved.
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      # Baseline evaluation before any training.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      # make_callable avoids per-call session.run dispatch overhead in the
      # hot loop below.
      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        # NOTE(review): if train_steps_per_iteration were 0, the logging
        # branch below would read an unbound `loss_info_value`; the default
        # (1) never hits this.
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
def train_eval(
    root_dir,
    env_name='MinitaurGoalVelocityEnv-v0',
    eval_env_name=None,
    env_load_fn=suite_pybullet.load,
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    ensemble=True,
    n_critics=10,
    run_eval=False,
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.keras.losses.mse,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=1.,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC.

  Eager-mode pipeline: builds either an `EnsembleSacAgent` over `n_critics`
  critic networks (when `ensemble=True`) or a single-critic `SacAgent`,
  fills a uniform replay buffer with a random policy, then alternates
  environment collection and gradient steps until `global_step` reaches
  `num_iterations`, checkpointing and (optionally) evaluating along the way.

  Args:
    root_dir: Directory under which 'train' (and, if `run_eval`, 'eval')
      summaries and checkpoints are written.
    env_name: Training environment name passed to `env_load_fn`.
    eval_env_name: Optional separate eval environment name; defaults to
      `env_name`.
    env_load_fn: Callable loading a py environment by name.
    ensemble: If True, train an ensemble of `n_critics` critics.
    run_eval: If True, build eval writers/metrics and evaluate every
      `eval_interval` steps.
    eval_metrics_callback: Optional callable invoked with (results, step)
      after each evaluation.
    (Remaining args are standard SAC / pipeline hyperparameters, grouped by
    the inline comments above.)

  Returns:
    The `LossInfo` from the final train step.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  if run_eval:
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes)
    ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Only record summaries every `summary_interval` steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(eval_env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec, action_spec, fc_layer_params=actor_fc_layers)
    if ensemble:
      # One critic network + one optimizer per ensemble member.
      critic_nets, critic_optimizers = [], []
      for _ in range(n_critics):
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)
        critic_optimizers.append(
            tf.keras.optimizers.Adam(learning_rate=critic_learning_rate))
        critic_nets.append(critic_net)
      tf_agent = ensemble_sac_agent.EnsembleSacAgent(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_networks=critic_nets,
          actor_optimizer=tf.keras.optimizers.Adam(
              learning_rate=actor_learning_rate),
          critic_optimizers=critic_optimizers,
          alpha_optimizer=tf.keras.optimizers.Adam(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step)
    else:
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers)
      tf_agent = sac_agent.SacAgent(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.keras.optimizers.Adam(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.keras.optimizers.Adam(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.keras.optimizers.Adam(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                       batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_py_metric.TFPyMetric(
            metrics.AverageEarlyFailureMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size))
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # Restore before building drivers so collection resumes from the saved
    # state on preemption.
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    # Snapshot the operative gin config into the train dir / summaries.
    config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    if use_tf_functions:
      initial_collect_driver.run = common.function(
          initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Fix: removed stray `tf.estimator.SessionRunHook()` that constructed a
    # base-class hook and immediately discarded it — a no-op in this eager
    # pipeline (nothing ever registered or ran the hook).

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    # Skip seeding if a replay-buffer checkpoint was restored above.
    if not rb_checkpointer.checkpoint_exists:
      initial_collect_driver.run()

    if run_eval:
      # Baseline evaluation before training starts.
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='Metrics',
      )
      if eval_metrics_callback is not None:
        eval_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    while global_step.numpy() < num_iterations:
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                    data=steps_per_sec,
                                    step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(train_step=global_step,
                                  step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0 and run_eval:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
# Sample a batch of data from the buffer and update the agent's network.
iterator = iter(dataset)
experience, unused_info = next(iterator)
train_loss = agent.train(experience).loss

log_interval = 5
eval_interval = 5

if step % log_interval == 0:
  print('step = {0}: loss = {1}'.format(step, train_loss))

# Periodically export the policy and checkpoint training state once past
# the warm-up threshold.
if step % 150 == 0 and step >= 9000:
  policy_dir = os.path.join('waypoints', 'DOUBLE REWARDS_POLICY')
  tf_policy_saver = policy_saver.PolicySaver(agent.policy)
  tf_policy_saver.save(policy_dir)
  checkpoint_dir = os.path.join('waypoints', 'Double rewards_CP')
  train_checkpointer = common.Checkpointer(
      ckpt_dir=checkpoint_dir,
      max_to_keep=1,
      agent=agent,
      policy=agent.policy,
      replay_buffer=replay_buffer,
  )
  # Fix: the Checkpointer was constructed but `save()` was never called, so
  # no training checkpoint was ever written (constructing a Checkpointer
  # only restores/initializes; it does not persist anything).
  train_checkpointer.save(global_step=step)

plt.plot(returns)
plt.grid()
plt.show()
def train_eval(
    root_dir,
    offline_dir=None,
    random_seed=None,
    env_name='sawyer_push',
    eval_env_name=None,
    env_load_fn=get_env,
    max_episode_steps=1000,
    eval_episode_steps=1000,
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    reset_goal_frequency=1000,  # virtual episode size for reset-free training
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    # reset-free parameters
    use_minimum=True,
    reset_lagrange_learning_rate=3e-4,
    value_threshold=None,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    # Td3 parameters
    actor_update_period=1,
    exploration_noise_std=0.1,
    target_policy_noise=0.1,
    target_policy_noise_clip=0.1,
    dqda_clipping=None,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    # video recording for the environment
    video_record_interval=10000,
    num_videos=0,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """Train and eval for reset-free RL with a SAC or TD3 agent.

  Eager-mode pipeline. The environment wrapper (and hence reset semantics),
  agent type, goal relabeling, and reset-goal generation are all selected by
  module-level `FLAGS` (`use_reset_goals`, `agent_type`, `relabel_goals`,
  etc.) in addition to the arguments below. Optionally warm-starts the
  replay buffer from an offline-data checkpoint in `offline_dir`. Evaluation
  curves are additionally saved as .npy files in the eval dir for plotting.

  Returns:
    The `LossInfo` from the final train step.
  """
  start_time = time.time()
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  video_dir = os.path.join(eval_dir, 'videos')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Only record summaries every `summary_interval` steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)

    # Choose the reset/goal wrapper for the training env from the
    # FLAGS.use_reset_goals mode switch (each mode maps to a different
    # reset_free_wrapper class; see reset_free_wrapper for semantics).
    if FLAGS.use_reset_goals in [-1]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.GoalTerminalResetWrapper,
          num_success_states=FLAGS.num_success_states,
          full_reset_frequency=max_episode_steps),)
    elif FLAGS.use_reset_goals in [0, 1]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.ResetFreeWrapper,
          reset_goal_frequency=reset_goal_frequency,
          variable_horizon_for_reset=FLAGS.variable_reset_horizon,
          num_success_states=FLAGS.num_success_states,
          full_reset_frequency=max_episode_steps),)
    elif FLAGS.use_reset_goals in [2]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.CustomOracleResetWrapper,
          partial_reset_frequency=reset_goal_frequency,
          episodes_before_full_reset=max_episode_steps //
          reset_goal_frequency),)
    elif FLAGS.use_reset_goals in [3, 4]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.GoalTerminalResetFreeWrapper,
          reset_goal_frequency=reset_goal_frequency,
          num_success_states=FLAGS.num_success_states,
          full_reset_frequency=max_episode_steps),)
    elif FLAGS.use_reset_goals in [5, 7]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.CustomOracleResetGoalTerminalWrapper,
          partial_reset_frequency=reset_goal_frequency,
          episodes_before_full_reset=max_episode_steps //
          reset_goal_frequency),)
    elif FLAGS.use_reset_goals in [6]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.VariableGoalTerminalResetWrapper,
          full_reset_frequency=max_episode_steps),)

    if env_name == 'playpen_reduced':
      train_env_load_fn = functools.partial(
          env_load_fn, reset_at_goal=FLAGS.reset_at_goal)
    else:
      train_env_load_fn = env_load_fn
    # Train env has no step limit (reset-free); the wrapper controls resets.
    env, env_train_metrics, env_eval_metrics, aux_info = train_env_load_fn(
        name=env_name,
        max_episode_steps=None,
        gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(name=eval_env_name,
                    max_episode_steps=eval_episode_steps)[0])
    eval_metrics += env_eval_metrics

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    if FLAGS.agent_type == 'sac':
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=functools.partial(
              tanh_normal_projection_network.TanhNormalProjectionNetwork,
              std_transform=std_clip_transform))
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform',
      )
      # Optional second critic trained without the entropy term.
      critic_net_no_entropy = None
      critic_no_entropy_optimizer = None
      if FLAGS.use_no_entropy_q:
        critic_net_no_entropy = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform',
            name='CriticNetworkNoEntropy1')
        critic_no_entropy_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate)

      tf_agent = SacAgent(
          time_step_spec,
          action_spec,
          num_action_samples=FLAGS.num_action_samples,
          actor_network=actor_net,
          critic_network=critic_net,
          critic_network_no_entropy=critic_net_no_entropy,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          critic_no_entropy_optimizer=critic_no_entropy_optimizer,
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step)
    elif FLAGS.agent_type == 'td3':
      actor_net = actor_network.ActorNetwork(
          tf_env.time_step_spec().observation,
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
      )
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform')

      tf_agent = Td3Agent(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          exploration_noise_std=exploration_noise_std,
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          actor_update_period=actor_update_period,
          dqda_clipping=dqda_clipping,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          target_policy_noise=target_policy_noise,
          target_policy_noise_clip=target_policy_noise_clip,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
      )

    tf_agent.initialize()

    # Build the reset-goal proposer used by the reset-free wrappers.
    if FLAGS.use_reset_goals > 0:
      if FLAGS.use_reset_goals in [4, 5, 6]:
        reset_goal_generator = ScheduledResetGoal(
            goal_dim=aux_info['reset_state_shape'][0],
            num_success_for_switch=FLAGS.num_success_for_switch,
            num_chunks=FLAGS.num_chunks,
            name='ScheduledResetGoalGenerator')
      else:
        # distance to initial state distribution
        initial_state_distance = state_distribution_distance.L2Distance(
            initial_state_shape=aux_info['reset_state_shape'])
        initial_state_distance.update(
            tf.constant(aux_info['reset_states'], dtype=tf.float32),
            update_type='complete')

        if use_tf_functions:
          initial_state_distance.distance = common.function(
              initial_state_distance.distance)
          tf_agent.compute_value = common.function(tf_agent.compute_value)

        # initialize reset / practice goal proposer
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator = ResetGoalGenerator(
              goal_dim=aux_info['reset_state_shape'][0],
              compute_value_fn=tf_agent.compute_value,
              distance_fn=initial_state_distance,
              use_minimum=use_minimum,
              value_threshold=value_threshold,
              lagrange_variable_max=FLAGS.lagrange_max,
              optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=reset_lagrange_learning_rate),
              name='reset_goal_generator')
        else:
          reset_goal_generator = FixedResetGoal(
              distance_fn=initial_state_distance)

      # if use_tf_functions:
      #   reset_goal_generator.get_reset_goal = common.function(
      #       reset_goal_generator.get_reset_goal)

      # modify the reset-free wrapper to use the reset goal generator
      tf_env.pyenv.envs[0].set_reset_goal_fn(
          reset_goal_generator.get_reset_goal)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    if FLAGS.relabel_goals:
      # Small episode-local buffer; relabel_function moves its contents into
      # the main replay buffer at episode end (see the loops below).
      cur_episode_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          data_spec=tf_agent.collect_data_spec,
          batch_size=1,
          scope='CurEpisodeReplayBuffer',
          max_length=int(2 * min(reset_goal_frequency, max_episode_steps)))
      # NOTE: the buffer is replaced because cannot have two buffers.add_batch
      replay_observer = [cur_episode_buffer.add_batch]

    # initialize metrics and observers
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]
    train_metrics += env_train_metrics
    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    eval_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_agent.policy, use_tf_function=True)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)
    if use_tf_functions:
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    if offline_dir is not None:
      # Load the pre-collected offline data from its own checkpoint.
      offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          data_spec=tf_agent.collect_data_spec,
          batch_size=1,
          max_length=int(1e5))  # this has to be 100_000
      offline_checkpointer = common.Checkpointer(
          ckpt_dir=offline_dir, max_to_keep=1, replay_buffer=offline_data)
      offline_checkpointer.initialize_or_restore()

      # set the reset candidates to be all the data in offline buffer
      if (FLAGS.use_reset_goals > 0 and
          reset_lagrange_learning_rate > 0) or FLAGS.use_reset_goals in [
              4, 5, 6, 7
          ]:
        tf_env.pyenv.envs[0].set_reset_candidates(
            nest_utils.unbatch_nested_tensors(offline_data.gather_all()))

    # Seed the replay buffer only on a fresh run (empty buffer).
    if replay_buffer.num_frames() == 0:
      if offline_dir is not None:
        copy_replay_buffer(offline_data, replay_buffer)
        print(replay_buffer.num_frames())

        # multiply offline data
        if FLAGS.relabel_offline_data:
          data_multiplier(replay_buffer,
                          tf_env.pyenv.envs[0].env.compute_reward)
          print('after data multiplication:', replay_buffer.num_frames())

      # Single-step driver so the loop below can inspect every time step
      # (for relabeling and reset-candidate refreshes).
      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + train_metrics,
          num_steps=1)
      if use_tf_functions:
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)

      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps with '
          'a random policy.', initial_collect_steps)
      time_step = None
      policy_state = collect_policy.get_initial_state(tf_env.batch_size)
      for iter_idx in range(initial_collect_steps):
        time_step, policy_state = initial_collect_driver.run(
            time_step=time_step, policy_state=policy_state)
        if time_step.is_last() and FLAGS.relabel_goals:
          reward_fn = tf_env.pyenv.envs[0].env.compute_reward
          relabel_function(cur_episode_buffer, time_step, reward_fn,
                           replay_buffer)
          cur_episode_buffer.clear()

        if FLAGS.use_reset_goals > 0 and time_step.is_last(
        ) and FLAGS.num_reset_candidates > 0:
          tf_env.pyenv.envs[0].set_reset_candidates(
              replay_buffer.get_next(
                  sample_batch_size=FLAGS.num_reset_candidates)[0])
    else:
      time_step = None
      policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    # Baseline evaluation before training starts.
    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      # Drop 2-step samples that straddle an episode boundary.
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=2).unbatch().filter(_filter_invalid_transition).batch(
            batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    # manual data save for plotting utils
    np_custom_save(os.path.join(eval_dir, 'eval_interval.npy'),
                   eval_interval)
    # Resume previously saved eval curves if present; start fresh otherwise.
    # NOTE(review): bare except intentionally treats any load failure as
    # "no prior data" (already pylint-suppressed).
    try:
      average_eval_return = np_custom_load(
          os.path.join(eval_dir, 'average_eval_return.npy')).tolist()
      average_eval_success = np_custom_load(
          os.path.join(eval_dir, 'average_eval_success.npy')).tolist()
      average_eval_final_success = np_custom_load(
          os.path.join(eval_dir, 'average_eval_final_success.npy')).tolist()
    except:  # pylint: disable=bare-except
      average_eval_return = []
      average_eval_success = []
      average_eval_final_success = []

    print('initialization_time:', time.time() - start_time)

    for iter_idx in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      if time_step.is_last() and FLAGS.relabel_goals:
        reward_fn = tf_env.pyenv.envs[0].env.compute_reward
        relabel_function(cur_episode_buffer, time_step, reward_fn,
                         replay_buffer)
        cur_episode_buffer.clear()

      # reset goal generator updates
      if FLAGS.use_reset_goals > 0 and iter_idx % (
          FLAGS.reset_goal_frequency * collect_steps_per_iteration) == 0:
        if FLAGS.num_reset_candidates > 0:
          tf_env.pyenv.envs[0].set_reset_candidates(
              replay_buffer.get_next(
                  sample_batch_size=FLAGS.num_reset_candidates)[0])
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator.update_lagrange_multipliers()

      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val,
                     train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec,
            step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      # Heatmap metrics are expensive; only summarize them on the summary
      # interval, other metrics every iteration.
      for train_metric in train_metrics:
        if 'Heatmap' in train_metric.name:
          if global_step_val % summary_interval == 0:
            train_metric.tf_summaries(
                train_step=global_step, step_metrics=train_metrics[:2])
        else:
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % summary_interval == 0 and FLAGS.use_reset_goals > 0 and reset_lagrange_learning_rate > 0:
        reset_states, values, initial_state_distance_vals, lagrangian = reset_goal_generator.update_summaries(
            step_counter=global_step)
        for vf_viz_metric in aux_info['value_fn_viz_metrics']:
          vf_viz_metric.tf_summaries(
              reset_states,
              values,
              train_step=global_step,
              step_metrics=train_metrics[:2])

        if FLAGS.debug_value_fn_for_reset:
          # Print the reset goal that would be chosen under a sweep of
          # hypothetical lagrange multipliers.
          # NOTE(review): the [3:5] slice presumably extracts a 2-D
          # "door position" from the reset state — confirm against the env.
          num_test_lagrange = 20
          hyp_lagranges = [
              1.0 * increment / num_test_lagrange
              for increment in range(num_test_lagrange + 1)
          ]
          door_pos = reset_states[
              np.argmin(initial_state_distance_vals.numpy() -
                        lagrangian.numpy() * values.numpy())][3:5]
          print('cur lagrange: %.2f, cur reset goal: (%.2f, %.2f)' %
                (lagrangian.numpy(), door_pos[0], door_pos[1]))
          for lagrange in hyp_lagranges:
            door_pos = reset_states[
                np.argmin(initial_state_distance_vals.numpy() -
                          lagrange * values.numpy())][3:5]
            print('test lagrange: %.2f, cur reset goal: (%.2f, %.2f)' %
                  (lagrange, door_pos[0], door_pos[1]))
          print('\n')

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

        # numpy saves for plotting
        if 'AverageReturn' in results.keys():
          average_eval_return.append(results['AverageReturn'].numpy())
        if 'EvalSuccessfulAtAnyStep' in results.keys():
          average_eval_success.append(
              results['EvalSuccessfulAtAnyStep'].numpy())
        if 'EvalSuccessfulEpisodes' in results.keys():
          average_eval_final_success.append(
              results['EvalSuccessfulEpisodes'].numpy())
        elif 'EvalSuccessfulAtLastStep' in results.keys():
          average_eval_final_success.append(
              results['EvalSuccessfulAtLastStep'].numpy())

        if average_eval_return:
          np_custom_save(
              os.path.join(eval_dir, 'average_eval_return.npy'),
              average_eval_return)
        if average_eval_success:
          np_custom_save(
              os.path.join(eval_dir, 'average_eval_success.npy'),
              average_eval_success)
        if average_eval_final_success:
          np_custom_save(
              os.path.join(eval_dir, 'average_eval_final_success.npy'),
              average_eval_final_success)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)

      if global_step_val % video_record_interval == 0:
        for video_idx in range(num_videos):
          video_name = os.path.join(video_dir, str(global_step_val),
                                    'video_' + str(video_idx) + '.mp4')
          record_video(
              lambda: env_load_fn(  # pylint: disable=g-long-lambda
                  name=env_name,
                  max_episode_steps=max_episode_steps)[0],
              video_name,
              eval_py_policy,
              max_episode_length=eval_episode_steps)

  return train_loss
def main(cfg):
    """Trains one DQN/DDQN agent per robot group and periodically checkpoints.

    Builds a TF-Py environment from `cfg`, creates one agent per entry in
    `cfg.robot_config` (each with its own train-step counter), runs a single
    epsilon-annealed collect/train loop, and finally exports the trained
    policies as SavedModels.

    Args:
        cfg: configuration object (attribute access) providing log/checkpoint
            dirs, robot_config, learning rates, replay/训练 intervals, etc.
            NOTE(review): exact schema is defined by the caller's config
            loader — confirm field names against `utils.get_env_from_cfg`.
    """
    # Set up logging and checkpointing
    log_dir = Path(cfg.log_dir)
    checkpoint_dir = Path(cfg.checkpoint_dir)
    print('log_dir: {}'.format(log_dir))
    print('checkpoint_dir: {}'.format(checkpoint_dir))

    # Create env
    env = utils.get_env_from_cfg(cfg)
    tf_env = components.get_tf_py_env(env, cfg.num_input_channels)

    # Agents.  `epsilon` is a shared tf.Variable so assigning it below
    # changes the exploration rate of every agent's collect policy at once.
    epsilon = tf.Variable(1.0)
    agents = []
    for i, g in enumerate(cfg.robot_config):
        # Each robot_config entry is a single-key mapping; the key names the
        # robot type (presumably — verify against the config format).
        robot_type = next(iter(g))
        q_net = components.QNetwork(
            tf_env.observation_spec(),
            num_output_channels=VectorEnv.get_num_output_channels(robot_type))
        optimizer = keras.optimizers.SGD(
            learning_rate=cfg.learning_rate, momentum=0.9)
        # cfg.weight_decay is currently ignored
        agent_cls = dqn_agent.DdqnAgent if cfg.use_double_dqn else dqn_agent.DqnAgent
        agent = agent_cls(
            time_step_spec=tf_env.time_step_spec(),
            action_spec=components.get_action_spec(robot_type),
            q_network=q_net,
            optimizer=optimizer,
            epsilon_greedy=epsilon,
            target_update_period=(cfg.target_update_freq // cfg.train_freq),
            td_errors_loss_fn=common.element_wise_huber_loss,
            gamma=cfg.discount_factors[i],
            gradient_clipping=cfg.grad_norm_clipping,
            train_step_counter=tf.Variable(
                0, dtype=tf.int64),  # Separate counter for each agent
        )
        agent.initialize()
        agent.train = common.function(agent.train)
        agents.append(agent)
    # Summaries below are keyed to the first agent's counter.
    global_step = agents[0].train_step_counter

    # Replay buffers
    replay_buffers = [ReplayBuffer(cfg.replay_buffer_size) for _ in agents]

    # Checkpointing.  `timestep_var` records how far training had progressed
    # so the loop can resume from the same timestep after a restart.
    timestep_var = tf.Variable(0, dtype=tf.int64)
    agent_checkpointer = common.Checkpointer(ckpt_dir=str(checkpoint_dir / 'agents'),
                                             max_to_keep=5,
                                             agents=agents,
                                             timestep_var=timestep_var)
    agent_checkpointer.initialize_or_restore()
    if timestep_var.numpy() > 0:
        # Resuming: restore the pickled replay buffers matching the restored
        # timestep.  NOTE(review): pickle.load on a checkpoint file is only
        # safe if the checkpoint dir is trusted.
        checkpoint_path = checkpoint_dir / 'checkpoint_{:08d}.pkl'.format(
            timestep_var.numpy())
        with open(checkpoint_path, 'rb') as f:
            replay_buffers = pickle.load(f)

    # Logging
    train_summary_writer = tf.summary.create_file_writer(str(log_dir / 'train'))
    train_summary_writer.set_as_default()

    time_step = tf_env.reset()
    learning_starts = round(cfg.learning_starts_frac * cfg.total_timesteps)
    total_timesteps_with_warm_up = learning_starts + cfg.total_timesteps
    start_timestep = timestep_var.numpy()
    for timestep in tqdm(range(start_timestep, total_timesteps_with_warm_up),
                         initial=start_timestep,
                         total=total_timesteps_with_warm_up,
                         file=sys.stdout):
        # Set exploration epsilon: linear anneal from 1.0 to
        # cfg.final_exploration over cfg.exploration_frac of the run,
        # starting only after the warm-up (learning_starts) steps.
        exploration_eps = 1 - (1 - cfg.final_exploration) * min(
            1, max(0, timestep - learning_starts) /
            (cfg.exploration_frac * cfg.total_timesteps))
        epsilon.assign(exploration_eps)

        # Run one collect step
        transitions_per_buffer = tf_env.pyenv.envs[0].store_time_step(time_step)
        robot_group_index = tf_env.pyenv.envs[0].current_robot_group_index()
        action_step = agents[robot_group_index].collect_policy.action(time_step)
        time_step = tf_env.step(action_step.action)

        # Store experience in buffers
        for i, transitions in enumerate(transitions_per_buffer):
            for transition in transitions:
                replay_buffers[i].push(*transition)

        # Train policies
        if timestep >= learning_starts and (timestep + 1) % cfg.train_freq == 0:
            for i, agent in enumerate(agents):
                experience = replay_buffers[i].sample(cfg.batch_size)
                agent.train(experience)

        # Logging
        if tf_env.pyenv.envs[0].done():
            info = tf_env.pyenv.envs[0].get_info()
            tf.summary.scalar('timesteps', timestep + 1, global_step)
            tf.summary.scalar('steps', info['steps'], global_step)
            tf.summary.scalar('total_cubes', info['total_cubes'], global_step)

        # Checkpointing
        if (
                timestep + 1
        ) % cfg.checkpoint_freq == 0 or timestep + 1 == total_timesteps_with_warm_up:
            # Save agents
            timestep_var.assign(timestep + 1)
            agent_checkpointer.save(timestep + 1)

            # Save replay buffers
            checkpoint_path = checkpoint_dir / 'checkpoint_{:08d}.pkl'.format(
                timestep + 1)
            with open(checkpoint_path, 'wb') as f:
                pickle.dump(replay_buffers, f)
            cfg.checkpoint_path = str(checkpoint_path)
            utils.save_config(log_dir / 'config.yml', cfg)

            # Remove old checkpoints (keep only the one just written)
            checkpoint_paths = list(checkpoint_dir.glob('checkpoint_*.pkl'))
            checkpoint_paths.remove(checkpoint_path)
            for old_checkpoint_path in checkpoint_paths:
                old_checkpoint_path.unlink()

    # Export trained policies
    policy_dir = checkpoint_dir / 'policies'
    for i, agent in enumerate(agents):
        policy_saver.PolicySaver(agent.policy).save(
            str(policy_dir / 'robot_group_{:02}'.format(i + 1)))
    cfg.policy_path = str(policy_dir)
    utils.save_config(log_dir / 'config.yml', cfg)

    env.close()
def load_agents_and_create_videos(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=NUM_ITERATIONS,
    max_ep_steps=1000,
    train_sequence_length=1,
    # Params for QNetwork
    fc_layer_params=((100, )),  # NOTE(review): redundant parens — value is (100,)
    # Params for QRnnNetwork
    input_fc_layer_params=(50, ),
    lstm_size=(20, ),
    output_fc_layer_params=(20, ),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    num_atoms=51,
    min_q_value=-20,
    max_q_value=20,
    n_step_update=1,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    num_random_episodes=1,
    eval_interval=1000,
    # Params for checkpoints
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None,
    random_metrics_callback=None):
  """Rebuilds the C51 training setup, restores checkpoints, and records videos.

  Re-creates the environments, CategoricalDqnAgent, replay buffer and
  checkpointers exactly as they were configured during training, restores the
  latest train/replay-buffer checkpoints from `root_dir`, then records one
  video of the restored (trained) policy and one of a random policy for
  comparison.

  NOTE(review): many parameters (num_iterations, train_sequence_length,
  epsilon_greedy, the QRnnNetwork params, the *_interval params, the
  callbacks, ...) are accepted only so the signature mirrors the training
  function; they are never read in this body.

  Args:
    root_dir: directory containing the 'train', 'eval' and 'random'
      subdirectories written during training.
    env_name: gym environment id used for both train and eval envs.
    (remaining args mirror the training function; see their inline groups.)
  """

  # Define the directories to read from
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  random_dir = os.path.join(root_dir, 'random')

  # Match the writers and metrics used in training
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  random_summary_writer = tf.compat.v2.summary.create_file_writer(
      random_dir, flush_millis=summaries_flush_secs * 1000)
  random_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  # The agent's train_step_counter below is this same global step, so the
  # restored checkpoint also restores the step count.
  global_step = tf.compat.v1.train.get_or_create_global_step()

  # Match the environments used in training
  tf_env = tf_py_environment.TFPyEnvironment(
      suite_gym.load(env_name, max_episode_steps=max_ep_steps))
  eval_py_env = suite_gym.load(env_name, max_episode_steps=max_ep_steps)
  eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

  # Match the agents used in training
  categorical_q_net = categorical_q_network.CategoricalQNetwork(
      tf_env.observation_spec(),
      tf_env.action_spec(),
      num_atoms=num_atoms,
      fc_layer_params=fc_layer_params)

  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

  tf_agent = categorical_dqn_agent.CategoricalDqnAgent(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      categorical_q_network=categorical_q_net,
      optimizer=optimizer,
      min_q_value=min_q_value,
      max_q_value=max_q_value,
      n_step_update=n_step_update,
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      train_step_counter=global_step)
  tf_agent.initialize()

  train_metrics = [
      # tf_metrics.NumberOfEpisodes(),
      # tf_metrics.EnvironmentSteps(),
      tf_metrics.AverageReturnMetric(),
      tf_metrics.AverageEpisodeLengthMetric(),
  ]

  eval_policy = tf_agent.policy
  collect_policy = tf_agent.collect_policy

  # Replay buffer must match training so its checkpoint can be restored.
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=tf_agent.collect_data_spec,
      batch_size=tf_env.batch_size,
      max_length=replay_buffer_capacity)

  collect_driver = dynamic_step_driver.DynamicStepDriver(
      tf_env,
      collect_policy,
      observers=[replay_buffer.add_batch] + train_metrics,
      num_steps=collect_steps_per_iteration)

  train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                           agent=tf_agent,
                                           global_step=global_step,
                                           metrics=metric_utils.MetricsGroup(
                                               train_metrics, 'train_metrics'))
  policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
      train_dir, 'policy'),
                                            policy=eval_policy,
                                            global_step=global_step)
  rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
      train_dir, 'replay_buffer'),
                                        max_to_keep=1,
                                        replay_buffer=replay_buffer)

  # Restore the trained weights / replay buffer from root_dir.
  train_checkpointer.initialize_or_restore()
  rb_checkpointer.initialize_or_restore()

  if use_tf_functions:
    # To speed up collect use common.function.
    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

  random_policy = random_tf_policy.RandomTFPolicy(eval_tf_env.time_step_spec(),
                                                  eval_tf_env.action_spec())

  # Make movies of the trained agent and a random agent
  date_string = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')

  # Finally, used the saved policy to generate the video
  trained_filename = "trainedC51_" + date_string
  create_policy_eval_video(eval_tf_env, eval_py_env, tf_agent.policy,
                           trained_filename)

  # And, create one with a random agent for comparison
  random_filename = 'random_' + date_string
  create_policy_eval_video(eval_tf_env, eval_py_env, random_policy,
                           random_filename)
def learn(self, num_iterations=100000):
    """Runs the main collect/train loop for `num_iterations` steps.

    Each iteration runs the collect driver, samples one batch of 2-step
    trajectories from the replay buffer, and trains the agent. Checkpoints
    are written every 100000 train steps; at every `summary_interval` the
    average return is evaluated and the best-so-far policy is exported to
    `best_policy`.

    Args:
      num_iterations: number of collect+train iterations to run.

    Side effects: writes summaries, checkpoints and saved policies under
    `self.log_dir`; mutates `self.global_step`, `self.return_avg` and wraps
    `self.tf_agent.train` in a `tf.function`.
    """
    dataset = self.replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=self.batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        self.train_env,
        self.tf_agent.collect_policy,
        observers=self.replay_observer + self.train_metrics,
        num_steps=self.collect_steps_per_iteration)

    root_dir = os.path.expanduser(self.log_dir)
    train_dir = os.path.join(root_dir, 'train')
    checkpoint_dir = os.path.join(root_dir, 'checkpoint')
    policy_dir = os.path.join(root_dir, 'policy')
    best_policy_dir = os.path.join(root_dir, 'best_policy')

    saver_policy = policy_saver.PolicySaver(self.tf_agent.policy)
    train_checkpointer = common.Checkpointer(
        ckpt_dir=checkpoint_dir,
        max_to_keep=2,
        agent=self.tf_agent,
        policy=self.tf_agent.policy,
        replay_buffer=self.replay_buffer,
        global_step=self.global_step)

    if not self.resume_training:
        train_summary_writer = tf.summary.create_file_writer(
            train_dir, flush_millis=self.summaries_flush_secs * 1000)
        train_summary_writer.set_as_default()
        # Reset the train step
        self.tf_agent.train_step_counter.assign(0)
    else:
        # BUG FIX: was `initialze_or_restore()` (typo), which raised
        # AttributeError on every resume instead of restoring the checkpoint.
        train_checkpointer.initialize_or_restore()
        self.global_step = tf.compat.v1.train.get_global_step()
        print('Resume global step: ', self.global_step)

    # (Optional) Optimize by wrapping some of the code in a graph using TF
    # function.
    self.tf_agent.train = common.function(self.tf_agent.train)
    collect_driver.run = common.function(collect_driver.run)

    with tf.summary.record_if(lambda: tf.math.equal(
            self.global_step % self.summary_interval, 0)):
        for _ in tqdm(range(num_iterations)):
            collect_driver.run()
            experience, unused_info = next(iterator)
            train_loss = self.tf_agent.train(experience)

            for train_metric in self.train_metrics:
                train_metric.tf_summaries(train_step=self.global_step,
                                          step_metrics=self.step_metrics)

            step = self.tf_agent.train_step_counter.numpy()
            if step % 100000 == 0:
                train_checkpointer.save(self.global_step)
                saver_policy.save(policy_dir)
            if step % self.summary_interval == 0:
                avg = self.compute_avg_return(self.num_eval_episodes)
                tf.summary.scalar('Average Reward', avg, step=self.global_step)
                # Export the policy whenever it beats the best seen so far.
                if avg > self.return_avg:
                    saver_policy.save(best_policy_dir)
                    self.return_avg = avg
def train_eval(
    root_dir,
    environment_name="broken_reacher",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    real_initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    real_collect_interval=10,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    classifier_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    train_on_real=False,
    delta_r_warmup=0,
    random_seed=0,
    checkpoint_dir=None,
):
  """A simple train and eval for SAC.

  Trains a DarcAgent (SAC plus a sim-vs-real classifier) on a "sim" and a
  "real" variant of `environment_name`, keeping separate replay buffers and
  metric sets for the two domains, and evaluates a greedy policy on both.
  `real_collect_interval` controls how often real-environment steps are
  collected relative to sim steps; `delta_r_warmup` delays real collection
  until the global step reaches it.

  Returns:
    The last `LossInfo` returned by the agent's train step.
  """
  np.random.seed(random_seed)
  tf.random.set_seed(random_seed)
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, "train")
  eval_dir = os.path.join(root_dir, "eval")

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)

  # Resolve the environment constructor; it is later called with mode
  # "sim" or "real" to build the two domain variants.
  if environment_name == "broken_reacher":
    get_env_fn = darc_envs.get_broken_reacher_env
  elif environment_name == "half_cheetah_obstacle":
    get_env_fn = darc_envs.get_half_cheetah_direction_env
  elif environment_name.startswith("broken_joint"):
    base_name = environment_name.split("broken_joint_")[1]
    get_env_fn = functools.partial(
        darc_envs.get_broken_joint_env, env_name=base_name)
  elif environment_name.startswith("falling"):
    base_name = environment_name.split("falling_")[1]
    get_env_fn = functools.partial(
        darc_envs.get_falling_env, env_name=base_name)
  else:
    raise NotImplementedError("Unknown environment: %s" % environment_name)

  eval_name_list = ["sim", "real"]
  eval_env_list = [get_env_fn(mode) for mode in eval_name_list]

  # One AverageReturn metric per evaluation domain.
  eval_metrics_list = []
  for name in eval_name_list:
    eval_metrics_list.append([
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, name="AverageReturn_%s" % name),
    ])

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Summaries are only recorded every `summary_interval` global steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env_real = get_env_fn("real")
    if train_on_real:
      tf_env = get_env_fn("real")
    else:
      tf_env = get_env_fn("sim")

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=(
            tanh_normal_projection_network.TanhNormalProjectionNetwork),
    )
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer="glorot_uniform",
        last_kernel_initializer="glorot_uniform",
    )

    # Classifier that distinguishes sim from real transitions (used by the
    # DARC reward correction).
    classifier = classifiers.build_classifier(observation_spec, action_spec)

    tf_agent = darc_agent.DarcAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        classifier=classifier,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=classifier_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
    )
    tf_agent.initialize()

    # Make the replay buffer.  Separate buffers for sim and real experience.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity,
    )
    replay_observer = [replay_buffer.add_batch]

    real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity,
    )
    real_replay_observer = [real_replay_buffer.add_batch]

    # Parallel metric sets for the two domains; the first two entries of
    # each list are used as step_metrics for summaries below.
    sim_train_metrics = [
        tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
        tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            name="AverageReturnSim",
        ),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            name="AverageEpisodeLengthSim",
        ),
    ]
    real_train_metrics = [
        tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
        tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            name="AverageReturnReal",
        ),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            name="AverageEpisodeLengthReal",
        ),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(
            sim_train_metrics + real_train_metrics, "train_metrics"),
    )
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, "policy"),
        policy=eval_policy,
        global_step=global_step,
    )
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, "replay_buffer"),
        max_to_keep=1,
        replay_buffer=(replay_buffer, real_replay_buffer),
    )

    if checkpoint_dir is not None:
      # Warm-start from an external checkpoint dir via the checkpointer's
      # internals (bypasses its own ckpt_dir).
      checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
      assert checkpoint_path is not None
      train_checkpointer._load_status = train_checkpointer._checkpoint.restore(  # pylint: disable=protected-access
          checkpoint_path)
      train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
    else:
      train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    # Only build initial-collect drivers when the (possibly restored) buffer
    # is empty; otherwise they are never referenced at run time.
    if replay_buffer.num_frames() == 0:
      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + sim_train_metrics,
          num_steps=initial_collect_steps,
      )
      real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env_real,
          initial_collect_policy,
          observers=real_replay_observer + real_train_metrics,
          num_steps=real_initial_collect_steps,
      )

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + sim_train_metrics,
        num_steps=collect_steps_per_iteration,
    )
    real_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env_real,
        collect_policy,
        observers=real_replay_observer + real_train_metrics,
        num_steps=collect_steps_per_iteration,
    )

    # Record the effective gin configuration alongside the run.
    config_str = gin.operative_config_str()
    logging.info(config_str)
    with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                 "w") as f:
      f.write(config_str)

    if use_tf_functions:
      # NOTE(review): if the buffer was restored non-empty, the
      # initial-collect drivers were never created and these two lines would
      # raise NameError — presumably unreachable in practice; confirm.
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      real_initial_collect_driver.run = common.function(
          real_initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      real_collect_driver.run = common.function(real_collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    if replay_buffer.num_frames() == 0:
      logging.info(
          "Initializing replay buffer by collecting experience for %d steps with "
          "a random policy.",
          initial_collect_steps,
      )
      initial_collect_driver.run()
      real_initial_collect_driver.run()

    # Initial evaluation on both domains before training starts.
    for eval_name, eval_env, eval_metrics in zip(eval_name_list, eval_env_list,
                                                 eval_metrics_list):
      metric_utils.eager_compute(
          eval_metrics,
          eval_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix="Metrics-%s" % eval_name,
      )
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    real_time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      # Drop 2-step windows that start on an episode boundary.
      return ~trajectories.is_boundary()[0]

    dataset = (
        replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
    real_dataset = (
        real_replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))

    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)
    real_iterator = iter(real_dataset)

    def train_step():
      # One agent update from a sim batch plus a real batch.
      experience, _ = next(iterator)
      real_experience, _ = next(real_iterator)
      return tf_agent.train(experience, real_experience=real_experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      assert not policy_state  # We expect policy_state == ().
      # Interleave real-environment collection every `real_collect_interval`
      # steps, once past the warm-up threshold.
      if (global_step.numpy() % real_collect_interval == 0 and
          global_step.numpy() >= delta_r_warmup):
        real_time_step, policy_state = real_collect_driver.run(
            time_step=real_time_step,
            policy_state=policy_state,
        )

      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info("step = %d, loss = %f", global_step_val, train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info("%.3f steps/sec", steps_per_sec)
        tf.compat.v2.summary.scalar(
            name="global_steps_per_sec", data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in sim_train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=sim_train_metrics[:2])
      for train_metric in real_train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=real_train_metrics[:2])

      if global_step_val % eval_interval == 0:
        for eval_name, eval_env, eval_metrics in zip(eval_name_list,
                                                     eval_env_list,
                                                     eval_metrics_list):
          metric_utils.eager_compute(
              eval_metrics,
              eval_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=global_step,
              summary_writer=eval_summary_writer,
              summary_prefix="Metrics-%s" % eval_name,
          )
          metric_utils.log_metrics(eval_metrics)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
print(tf_agent.collect_data_spec) print('Replay Buffer Created, start warming-up ...') _startTime = dt.datetime.now() # driver for warm-up # https://www.tensorflow.org/agents/api_docs/python/tf_agents/drivers/dynamic_episode_driver/DynamicEpisodeDriver initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( env, collect_policy, observers=[replay_buffer.add_batch], num_episodes=warmupEpisodes) # run restore process if shouldContinueFromLastCheckpoint: train_checkpointer = common.Checkpointer(ckpt_dir=checkpointDir, max_to_keep=1, agent=tf_agent, policy=tf_agent.policy, replay_buffer=replay_buffer, global_step=global_step) train_checkpointer.initialize_or_restore() else: initial_collect_driver.run() _timeCost = (dt.datetime.now() - _startTime).total_seconds() print('Replay Buffer Warm-up Done. (cost {:.3g} hours)'.format(_timeCost / 3600.0)) _startTime = dt.datetime.now() # Training print('Prepare for training ...') collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( env,
def train_eval( root_dir, env_name='HalfCheetah-v2', eval_env_name=None, env_load_fn=suite_mujoco.load, # The SAC paper reported: # Hopper and Cartpole results up to 1000000 iters, # Humanoid results up to 10000000 iters, # Other mujoco tasks up to 3000000 iters. num_iterations=3000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py # HalfCheetah and Ant take 10000 initial collection steps. # Other mujoco tasks take 1000. # Different choices roughly keep the initial episodes about the same. initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=0.1, gradient_clipping=None, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), 
py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] eval_summary_flush_op = eval_summary_writer.flush() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create the environment. tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name)) eval_env_name = eval_env_name or env_name eval_py_env = env_load_fn(eval_env_name) # Get the data specs from the environment time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=tanh_normal_projection_network. TanhNormalProjectionNetwork) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] eval_py_policy = py_tf_policy.PyTFPolicy( greedy_policy.GreedyPolicy(tf_agent.policy)) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] collect_policy = tf_agent.collect_policy initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run() # Prepare replay buffer as dataset with invalid transitions filtered. 
def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=2).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5) dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) trajectories, unused_info = dataset_iterator.get_next() train_op = tf_agent.train(trajectories) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) with tf.compat.v1.Session() as sess: # Initialize graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # Initialize training. sess.run(dataset_iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) global_step_val = sess.run(global_step) if global_step_val == 0: # Initial eval of randomly initialized policy metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) # Run initial collect. 
logging.info('Global step %d: Running initial collect op.', global_step_val) sess.run(initial_collect_op) # Checkpoint the initial replay buffer contents. rb_checkpointer.save(global_step=global_step_val) logging.info('Finished initial collect.') else: logging.info('Global step %d: Skipping initial collect op.', global_step_val) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_ops]) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): total_loss, _ = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_whitelist='position',
    num_iterations=100000,
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100,),
    actor_lstm_size=(40,),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    critic_output_fc_layers=(100,),
    critic_lstm_size=(40,),
    # Params for collect
    initial_collect_steps=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=200,
    batch_size=64,
    train_sequence_length=10,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    gamma=0.995,
    reward_scale_factor=1.0,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=10000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    eval_metrics_callback=None):
  """A simple train and eval for TD3 with RNN networks on dm_control.

  NOTE(review): the docstring previously said "DDPG", but the agent
  constructed below is `td3_agent.Td3Agent`; corrected.

  Builds a TD3 agent with LSTM actor/critic networks, collects episodes with
  a TF1-style `DynamicEpisodeDriver`, trains from a uniform replay buffer of
  `train_sequence_length + 1`-step trajectories, and periodically evaluates,
  checkpoints, and logs. Runs entirely inside a `tf.compat.v1.Session`.

  Args:
    root_dir: Directory where train/eval summaries and checkpoints go.
    env_name, task_name: dm_control domain and task to load.
    observations_whitelist: If not None, only this observation key is kept
      (wrapped in `FlattenObservationsWrapper`).
    Remaining args: network layer sizes, collection / target-update / training
      / evaluation / checkpointing-and-logging hyperparameters, grouped by the
      inline comments in the signature.
    eval_metrics_callback: Optional callable invoked with eval results.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Only record train summaries every `summary_interval` steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if observations_whitelist is not None:
      env_wrappers = [
          functools.partial(
              wrappers.FlattenObservationsWrapper,
              observations_whitelist=[observations_whitelist])
      ]
    else:
      env_wrappers = []
    environment = suite_dm_control.load(
        env_name, task_name, env_wrappers=env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(environment)
    # Separate python environment for evaluation.
    eval_py_env = suite_dm_control.load(
        env_name, task_name, env_wrappers=env_wrappers)

    actor_net = actor_rnn_network.ActorRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers)

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())
    critic_net = critic_rnn_network.CriticRnnNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
    )

    tf_agent = td3_agent.Td3Agent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        debug_summaries=debug_summaries,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # TODO(oars): Refactor drivers to better handle policy states. Remove the
    # policy reset and passing down an empty policy state to the driver.
    collect_policy = tf_agent.collect_policy
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_steps).run(policy_state=policy_state)

    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run(
            policy_state=policy_state)

    # Need extra step to generate transitions of train_sequence_length.
    # Dataset generates trajectories with shape [BxTx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = iterator.get_next()
    train_op = tf_agent.train(experience=trajectories)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])

    # Eval metric summaries are always recorded (not gated on the interval).
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(sguada) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      # Initial eval before training starts.
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable(train_op)
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      # Fixed: was `tf.contrib.summary.scalar(name='global_steps/sec',
      # tensor=...)`. `tf.contrib` is removed in TF2; use the compat.v2
      # summary API with the same tag as the other train_eval variants here.
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    log_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN.

  Collection is done in python (via `collect_step` and a python replay
  buffer) while training runs as a TF graph inside a
  `tf.compat.v1.Session`. Periodically evaluates with a greedy python
  policy, checkpoints, and logs timing/loss.

  Args:
    root_dir: Directory where train/eval summaries and checkpoints go.
    env_name: Gym environment id to load.
    Remaining args: network / collection / target-update / training /
      evaluation / checkpointing-and-logging hyperparameters, grouped by the
      inline comments in the signature.
    eval_metrics_callback: Optional callable invoked with eval results.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # Note this is a python environment.
  env = batched_py_environment.BatchedPyEnvironment(
      [suite_gym.load(env_name)])
  eval_py_env = suite_gym.load(env_name)

  # Convert specs to BoundedTensorSpec.
  action_spec = tensor_spec.from_spec(env.action_spec())
  observation_spec = tensor_spec.from_spec(env.observation_spec())
  time_step_spec = ts.time_step_spec(observation_spec)

  # Reuse the specs converted above instead of re-running
  # `tensor_spec.from_spec` on the same env specs (was converted twice).
  q_net = q_network.QNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=fc_layer_params)

  # The agent must be in graph.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  agent = dqn_agent.DqnAgent(
      time_step_spec,
      action_spec,
      q_network=q_net,
      epsilon_greedy=epsilon_greedy,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      td_errors_loss_fn=common.element_wise_squared_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=global_step)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
  # Renamed from `greedy_policy` to avoid shadowing the `greedy_policy`
  # module imported at the top of this file.
  greedy_py_policy = py_tf_policy.PyTFPolicy(agent.policy)
  random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                 env.action_spec())

  # Python replay buffer.
  replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
      capacity=replay_buffer_capacity,
      data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

  time_step = env.reset()

  # Initialize the replay buffer with some transitions. We use the random
  # policy to initialize the replay buffer to make sure we get a good
  # distribution of actions.
  for _ in range(initial_collect_steps):
    time_step = collect_step(env, time_step, random_policy, replay_buffer)

  # TODO(b/112041045) Use global_step as counter.
  train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir, agent=agent, global_step=global_step)
  policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=agent.policy,
      global_step=global_step)

  # n_step_update + 1 steps are needed to form an n-step transition.
  ds = replay_buffer.as_dataset(
      sample_batch_size=batch_size, num_steps=n_step_update + 1)
  ds = ds.prefetch(4)
  itr = tf.compat.v1.data.make_initializable_iterator(ds)

  experience = itr.get_next()

  train_op = common.function(agent.train)(experience)

  with eval_summary_writer.as_default(), \
       tf.compat.v2.summary.record_if(True):
    for eval_metric in eval_metrics:
      eval_metric.tf_summaries(train_step=global_step)

  with tf.compat.v1.Session() as session:
    train_checkpointer.initialize_or_restore(session)
    common.initialize_uninitialized_variables(session)
    session.run(itr.initializer)
    # Copy critic network values to the target critic network.
    session.run(agent.initialize())
    train = session.make_callable(train_op)
    global_step_call = session.make_callable(global_step)
    session.run(train_summary_writer.init())
    session.run(eval_summary_writer.init())

    # Compute initial evaluation metrics.
    global_step_val = global_step_call()
    metric_utils.compute_summaries(
        eval_metrics,
        eval_py_env,
        greedy_py_policy,
        num_episodes=num_eval_episodes,
        global_step=global_step_val,
        log=True,
        callback=eval_metrics_callback,
    )

    timed_at_step = global_step_val
    collect_time = 0
    train_time = 0
    steps_per_second_ph = tf.compat.v1.placeholder(
        tf.float32, shape=(), name='steps_per_sec_ph')
    steps_per_second_summary = tf.compat.v2.summary.scalar(
        name='global_steps_per_sec',
        data=steps_per_second_ph,
        step=global_step)

    for _ in range(num_iterations):
      start_time = time.time()
      for _ in range(collect_steps_per_iteration):
        time_step = collect_step(env, time_step, collect_policy,
                                 replay_buffer)
      collect_time += time.time() - start_time
      start_time = time.time()
      for _ in range(train_steps_per_iteration):
        loss = train()
      train_time += time.time() - start_time
      global_step_val = global_step_call()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, loss.loss)
        steps_per_sec = ((global_step_val - timed_at_step) /
                         (collect_time + train_time))
        session.run(
            steps_per_second_summary,
            feed_dict={steps_per_second_ph: steps_per_sec})
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info(
            '%s', 'collect_time = {}, train_time = {}'.format(
                collect_time, train_time))
        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % eval_interval == 0:
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_py_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )
        # Reset timing to avoid counting eval time.
        timed_at_step = global_step_val
        start_time = time.time()
def train_eval(
    root_dir,
    tf_master='',
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=None,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=25000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-3,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=500,
    policy_checkpoint_interval=500,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for PPO.

  Collects `collect_episodes_per_iteration` episodes from
  `num_parallel_environments` parallel environments into a replay buffer,
  trains PPO on `gather_all()` of the buffer, then clears the buffer (via
  control dependencies so ordering is guaranteed in graph mode). Loops until
  `num_environment_steps` environment steps have been taken. Runs in a
  `tf.compat.v1.Session`.
  """
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  # Batched eval metrics: one underlying py metric per parallel environment.
  eval_metrics = [
      batched_py_metric.BatchedPyMetric(
          AverageReturnMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
      batched_py_metric.BatchedPyMetric(
          AverageEpisodeLengthMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
  ]
  eval_summary_writer_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Only record train summaries every `summary_interval` steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)
    # Separate parallel environments for eval and for collection.
    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
          activation_fn=tf.keras.activations.tanh)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(),
          fc_layer_params=value_fc_layers,
          activation_fn=tf.keras.activations.tanh)

    tf_agent = ppo_agent.PPOAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        entropy_regularization=0.0,
        importance_ratio_clipping=0.2,
        normalize_observations=False,
        normalize_rewards=False,
        use_gae=True,
        # kl_cutoff_factor=0.0 and initial_adaptive_kl_beta=0.0 disable the
        # KL penalty terms; clipping alone constrains the update.
        kl_cutoff_factor=0.0,
        initial_adaptive_kl_beta=0.0,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    # Result tensor used as the training-loop termination condition below.
    environment_steps_count = environment_steps_metric.result()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(
            batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    # Add to replay buffer and other agent specific observers.
    replay_buffer_observer = [replay_buffer.add_batch]
    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=replay_buffer_observer + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    # PPO is on-policy: train on everything collected, then clear the buffer.
    trajectories = replay_buffer.gather_all()
    train_op, _ = tf_agent.train(experience=trajectories)

    # Control dependencies enforce train -> clear ordering in graph mode.
    with tf.control_dependencies([train_op]):
      clear_replay_op = replay_buffer.clear()
    with tf.control_dependencies([clear_replay_op]):
      train_op = tf.identity(train_op)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=step_metrics))

    # Eval metric summaries are always recorded (not gated on the interval).
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session(tf_master) as sess:
      # Initialize graph.
      train_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      collect_time = 0
      train_time = 0
      timed_at_step = sess.run(global_step)
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      # Loop until enough environment steps (not train steps) are taken.
      while sess.run(environment_steps_count) < num_environment_steps:
        global_step_val = sess.run(global_step)
        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
          sess.run(eval_summary_writer_flush_op)

        start_time = time.time()
        sess.run(collect_op)
        collect_time += time.time() - start_time
        start_time = time.time()
        total_loss, _ = sess.run([train_op, summary_ops])
        train_time += time.time() - start_time

        global_step_val = sess.run(global_step)
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = ((global_step_val - timed_at_step) /
                           (collect_time + train_time))
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info(
              '%s', 'collect_time = {}, train_time = {}'.format(
                  collect_time, train_time))
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

      # One final eval before exiting.
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )
      sess.run(eval_summary_writer_flush_op)
  def testLossLearnerDifferentDistStrat(self, create_agent_fn):
    """Checks Learner behavior is equivalent across distribution strategies.

    Builds one learner per strategy, syncs initial variables to the default
    strategy's values, runs one training iteration each, and verifies:
    checkpoints were written and can seed a new learner, datasets match,
    losses and optimizer variables are close, and variables actually changed.
    """
    # Create the strategies used in the test. The second value is the per-core
    # batch size.
    bs_multiplier = 4
    strategies = {
        'default': (tf.distribute.get_strategy(), 4 * bs_multiplier),
        'one_device':
            (tf.distribute.OneDeviceStrategy('/cpu:0'), 4 * bs_multiplier),
        'mirrored': (tf.distribute.MirroredStrategy(), 1 * bs_multiplier),
    }
    if tf.config.list_logical_devices('TPU'):
      strategies['TPU'] = (_get_tpu_strategy(), 2 * bs_multiplier)
    else:
      logging.info('TPU hardware is not available, TPU strategy test skipped.')

    # NOTE(review): each learners[...] entry appears to be a 5-tuple of
    # (trainer, dataset, variables, train_step, ...) based on the unpacking
    # below — confirm against _build_learner_with_strategy.
    learners = {
        name: self._build_learner_with_strategy(create_agent_fn, strategy,
                                                per_core_batch_size)
        for name, (strategy, per_core_batch_size) in strategies.items()
    }

    # Verify that the initial variable values in the learners are the same.
    default_strat_trainer, _, default_vars, _, _ = learners['default']
    for name, (trainer, _, variables, _, _) in learners.items():
      if name != 'default':
        # Copy the default strategy's variable values into this learner so
        # all learners start from identical weights.
        self._assign_variables(default_strat_trainer, trainer)
      self.assertLen(variables, len(default_vars))
      for default_variable, variable in zip(default_vars, variables):
        self.assertAllEqual(default_variable, variable)

    # Calculate losses.
    losses = {}
    checkpoint_path = {}
    iterations = 1
    optimizer_variables = {}
    for name, (trainer, _, variables, train_step, _) in learners.items():
      # Snapshot variables before and after one training run so we can later
      # assert that training changed them.
      old_vars = self.evaluate(variables)
      loss = trainer.run(iterations=iterations).loss
      logging.info('Using strategy: %s, the loss is: %s at train step: %s',
                   name, loss, train_step)
      new_vars = self.evaluate(variables)
      losses[name] = old_vars, loss, new_vars
      self.assertNotEmpty(trainer._agent._optimizer.variables())
      optimizer_variables[name] = trainer._agent._optimizer.variables()
      checkpoint_path[name] = trainer._checkpointer.manager.directory

    for name, path in checkpoint_path.items():
      logging.info('Checkpoint dir for learner %s: %s. Content: %s', name,
                   path, tf.io.gfile.listdir(path))
      checkpointer = common.Checkpointer(path)
      # Make sure that the checkpoint file exists, so the learner initialized
      # using the corresponding root directory will pick up the values in the
      # checkpoint file.
      self.assertTrue(checkpointer.checkpoint_exists)

      # Create a learner using an existing root directory containing the
      # checkpoint files.
      strategy, per_core_batch_size = strategies[name]
      # The checkpoint dir is two levels below the learner root dir.
      learner_from_checkpoint = self._build_learner_with_strategy(
          create_agent_fn,
          strategy,
          per_core_batch_size,
          root_dir=os.path.join(path, '..', '..'))[0]

      # Check if the learner was in fact created based on an existing
      # checkpoint.
      self.assertTrue(
          learner_from_checkpoint._checkpointer.checkpoint_exists)

      # Check that the values of the variables of the learner initialized
      # from checkpoint are the same as the values used to write the
      # checkpoint.
      original_learner = learners[name][0]
      self.assertAllClose(
          learner_from_checkpoint._agent.collect_policy.variables(),
          original_learner._agent.collect_policy.variables())
      self.assertAllClose(
          learner_from_checkpoint._agent._optimizer.variables(),
          original_learner._agent._optimizer.variables())

    # Verify same dataset across learner calls.
    for item in tf.data.Dataset.zip(tuple([v[1] for v in learners.values()])):
      for i in range(1, len(item)):
        # Compare default strategy observation to the other datasets, second
        # index is getting the trajectory from (trajectory, sample_info) tuple.
        self.assertAllEqual(item[0][0].observation, item[i][0].observation)

    # Check that the losses are close to each other.
    _, default_loss, _ = losses['default']
    for name, (_, loss, _) in losses.items():
      self._compare_losses(loss, default_loss, delta=1.e-2)

    # Check that the optimizer variables are close to each other.
    default_optimizer_vars = optimizer_variables['default']
    for name, optimizer_vars in optimizer_variables.items():
      self.assertAllClose(
          optimizer_vars,
          default_optimizer_vars,
          atol=1.e-2,
          rtol=1.e-2,
          msg=('The initial values of the optimizer variables for the strategy '
               '{} are significantly different from the initial values of the '
               'default strategy.').format(name))

    # Check that the variables changed after calling `learner.run`.
    for old_vars, _, new_vars in losses.values():
      dist_test_utils.check_variables_different(self, old_vars, new_vars)
def train_eval(
    root_dir,
    env_name='MultiGrid-Empty-5x5-v0',
    env_load_fn=multiagent_gym_suite.load,
    random_seed=0,
    # Architecture params
    agent_class=multiagent_ppo.MultiagentPPO,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    lstm_size=(64,),
    conv_filters=64,
    conv_kernel=3,
    direction_fc=5,
    entropy_regularization=0.,
    use_attention_networks=False,
    # Specialized agents
    inactive_agent_ids=tuple(),
    # Params for collect
    num_environment_steps=25000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=5,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=2,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=2,
    eval_interval=5,
    # Params for summaries and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=100,
    log_interval=10,
    summary_interval=10,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=True,
    summarize_grads_and_vars=True,
    eval_metrics_callback=None,
    reinit_checkpoint_dir=None,
    debug=True):
  """A simple train and eval for multi-agent PPO.

  Builds `num_parallel_environments` training environments plus one eval
  environment, constructs the agent via `agent_class`, then alternates
  collect (`collect_episodes_per_iteration` episodes per iteration) and
  train phases until the environment-steps metric reaches
  `num_environment_steps`. Evaluation, logging, summaries, and
  checkpointing all happen at their configured step intervals. Training
  aborts early if the total loss stays NaN/inf/above the module-level
  `MAX_LOSS` for more than `TERMINATE_AFTER_DIVERGED_LOSS_STEPS`
  consecutive iterations.
  """
  tf.compat.v1.enable_v2_behavior()

  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')
  if debug:
    # tf.function tracing is disabled in debug mode to keep stack traces
    # readable; this overrides the caller-supplied use_tf_functions.
    logging.info('In debug mode, turning tf_functions off')
    use_tf_functions = False

  for a in inactive_agent_ids:
    logging.info('Fixing and not training agent %d', a)

  # Load multiagent gym environment and determine number of agents.
  gym_env = env_load_fn(env_name)
  n_agents = gym_env.n_agents

  # Set up logging directories: train/, eval/, and policy_saved_model/
  # all live under root_dir.
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      multiagent_metrics.AverageReturnMetric(
          n_agents, buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Everything below runs with summaries gated to every summary_interval
  # global steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)

    logging.info('Creating %d environments...', num_parallel_environments)
    wrappers = []
    if use_attention_networks:
      # Attention-based networks need the LSTM state exposed in the
      # observation; the wrapper injects it.
      wrappers = [lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size)]

    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(
            env_name,
            gym_kwargs=dict(seed=random_seed),
            gym_env_wrappers=wrappers))
    # pylint: disable=g-complex-comprehension
    # Each parallel environment gets a distinct derived seed.
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment([
            functools.partial(env_load_fn, environment_name=env_name,
                              gym_env_wrappers=wrappers,
                              gym_kwargs=dict(seed=random_seed * 1234 + i))
            for i in range(num_parallel_environments)
        ]))

    logging.info('Preparing to train...')
    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    bonus_metrics = [
        multiagent_metrics.MultiagentScalar(
            n_agents, name='UnscaledMultiagentBonus', buffer_size=1000),
    ]
    train_metrics = step_metrics + [
        multiagent_metrics.AverageReturnMetric(
            n_agents, batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    logging.info('Creating agent...')
    tf_agent = agent_class(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        n_agents=n_agents,
        learning_rate=learning_rate,
        actor_fc_layers=actor_fc_layers,
        value_fc_layers=value_fc_layers,
        lstm_size=lstm_size,
        conv_filters=conv_filters,
        conv_kernel=conv_kernel,
        direction_fc=direction_fc,
        entropy_regularization=entropy_regularization,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
        inactive_agent_ids=inactive_agent_ids)
    tf_agent.initialize()
    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    logging.info('Allocating replay buffer ...')
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)
    logging.info('RB capacity: %i', replay_buffer.capacity)

    # If reinit_checkpoint_dir is provided, the last agent in the checkpoint is
    # reinitialized. The other agents are novices.
    # Otherwise, all agents are reinitialized from train_dir.
    if reinit_checkpoint_dir:
      # Restore all agents from the reinit checkpoint, save all but the
      # last one to a temp dir, rebuild the agent (marking all but the last
      # as non-learning), restore those saved sub-agents, then discard the
      # temp dir. Net effect: last agent is fresh, the rest keep their
      # checkpointed weights but do not train.
      reinit_checkpointer = common.Checkpointer(
          ckpt_dir=reinit_checkpoint_dir,
          agent=tf_agent,
      )
      reinit_checkpointer.initialize_or_restore()
      temp_dir = os.path.join(train_dir, 'tmp')
      agent_checkpointer = common.Checkpointer(
          ckpt_dir=temp_dir,
          agent=tf_agent.agents[:-1],
      )
      agent_checkpointer.save(global_step=0)
      tf_agent = agent_class(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          n_agents=n_agents,
          learning_rate=learning_rate,
          actor_fc_layers=actor_fc_layers,
          value_fc_layers=value_fc_layers,
          lstm_size=lstm_size,
          conv_filters=conv_filters,
          conv_kernel=conv_kernel,
          direction_fc=direction_fc,
          entropy_regularization=entropy_regularization,
          num_epochs=num_epochs,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          inactive_agent_ids=inactive_agent_ids,
          non_learning_agents=list(range(n_agents - 1)))
      agent_checkpointer = common.Checkpointer(
          ckpt_dir=temp_dir, agent=tf_agent.agents[:-1])
      agent_checkpointer.initialize_or_restore()
      tf.io.gfile.rmtree(temp_dir)
    # Re-bind policies: tf_agent may have been rebuilt above.
    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=multiagent_metrics.MultiagentMetricsGroup(
            train_metrics + bonus_metrics, 'train_metrics'))
    if not reinit_checkpoint_dir:
      train_checkpointer.initialize_or_restore()
      logging.info('Successfully initialized train checkpointer')

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step)

    # NOTE(review): this checkpointer shares ckpt_dir 'policy' with
    # policy_checkpointer above, so the two write into the same directory
    # and can clobber each other's checkpoints. Presumably this was meant
    # to be a separate 'collect_policy' directory — confirm before relying
    # on restored collect-policy checkpoints.
    collect_policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=collect_policy,
        global_step=global_step)
    collect_saved_model = policy_saver.PolicySaver(
        collect_policy, train_step=global_step)
    logging.info('Successfully initialized policy saver.')

    print('Using TFDriver')
    if use_attention_networks:
      collect_driver = utils.StateTFDriver(
          tf_env,
          collect_policy,
          observers=[replay_buffer.add_batch] + train_metrics,
          max_episodes=collect_episodes_per_iteration,
          disable_tf_function=not use_tf_functions)
    else:
      collect_driver = tf_driver.TFDriver(
          tf_env,
          collect_policy,
          observers=[replay_buffer.add_batch] + train_metrics,
          max_episodes=collect_episodes_per_iteration,
          disable_tf_function=not use_tf_functions)

    def train_step():
      # Train on everything collected this iteration; the buffer is
      # cleared by the caller after each train phase.
      trajectories = replay_buffer.gather_all()
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    # How many consecutive steps was loss diverged for.
    loss_divergence_counter = 0

    # Save operative config as late as possible to include used configurables.
    if global_step.numpy() == 0:
      config_filename = os.path.join(
          train_dir, 'operative_config-{}.gin'.format(global_step.numpy()))
      with tf.io.gfile.GFile(config_filename, 'wb') as f:
        f.write(gin.operative_config_str())

    total_episodes = 0
    logging.info('Commencing train loop!')
    while environment_steps_metric.result() < num_environment_steps:
      global_step_val = global_step.numpy()

      # Evaluation phase.
      if global_step_val % eval_interval == 0:
        if debug:
          logging.info('Performing evaluation at step %d', global_step_val)
        results = multiagent_metrics.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
            use_function=use_tf_functions,
            use_attention_networks=use_attention_networks)
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        multiagent_metrics.log_metrics(eval_metrics)

      # Collect phase.
      if debug:
        logging.info('Collecting at step %d', global_step_val)
      start_time = time.time()
      time_step = tf_env.reset()
      policy_state = collect_policy.get_initial_state(tf_env.batch_size)
      if use_attention_networks:
        # Attention networks require previous policy state to compute
        # attention weights.
        time_step.observation['policy_state'] = (
            policy_state['actor_network_state'][0],
            policy_state['actor_network_state'][1])
      collect_driver.run(time_step, policy_state)
      collect_time += time.time() - start_time
      total_episodes += collect_episodes_per_iteration
      if debug:
        logging.info('Have collected a total of %d episodes', total_episodes)

      # Train phase: consume the whole buffer, then clear it.
      if debug:
        logging.info('Training at step %d', global_step_val)
      start_time = time.time()
      total_loss, extra_loss = train_step()
      replay_buffer.clear()
      train_time += time.time() - start_time

      # Check for exploding losses; bail out only after the divergence
      # persists for TERMINATE_AFTER_DIVERGED_LOSS_STEPS iterations.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
          total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          logging.info('Loss diverged for too many timesteps, breaking...')
          break
      else:
        loss_divergence_counter = 0

      for train_metric in train_metrics + bonus_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      # Periodic console logging plus a steps/sec summary; the timing
      # accumulators reset each time we log.
      if global_step_val % log_interval == 0:
        logging.info('step = %d, total loss = %f', global_step_val, total_loss)
        for a in range(n_agents):
          if not inactive_agent_ids or a not in inactive_agent_ids:
            logging.info('Loss for agent %d = %f', a, extra_loss[a].loss)
        steps_per_sec = ((global_step_val - timed_at_step) /
                         (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                     train_time)
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec,
              step=global_step)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
        saved_model_path = os.path.join(
            saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
        saved_model.save(saved_model_path)
        collect_policy_checkpointer.save(global_step=global_step_val)
        collect_saved_model_path = os.path.join(
            saved_model_dir,
            'collect_policy_' + ('%d' % global_step_val).zfill(9))
        collect_saved_model.save(collect_saved_model_path)

    # One final eval before exiting.
    results = multiagent_metrics.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
        use_function=use_tf_functions,
        use_attention_networks=use_attention_networks)
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    multiagent_metrics.log_metrics(eval_metrics)
observers=[replay_buffer.add_batch], num_steps=STEPS_PER_ITER) # Wrap the run function in a TF graph driver.run = common.function(driver.run) # Create driver for the random policy random_driver = DynamicStepDriver(env=train_env, policy=random_policy, observers=[replay_buffer.add_batch], num_steps=STEPS_PER_ITER) # Wrap the run function in a TF graph random_driver.run = common.function(random_driver.run) # Create a checkpointer checkpointer = common.Checkpointer(ckpt_dir=os.path.relpath('checkpoint'), max_to_keep=1, agent=agent, policy=agent.policy, replay_buffer=replay_buffer, global_step=global_step) checkpointer.initialize_or_restore() global_step = tf.compat.v1.train.get_global_step() # Create a policy saver policy_saver = PolicySaver(agent.policy) # Main training loop time_step, policy_state = None, None for it in range(N_ITERATIONS): if COLLECT_RANDOM: print('Running random driver...') time_step, policy_state = random_driver.run(time_step, policy_state) print('Running agent driver...')
def __init__(
    self,
    root_dir,
    env_name,
    num_iterations=200,
    max_episode_frames=108000,  # ALE frames
    terminal_on_life_loss=False,
    conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)),
    fc_layer_params=(512,),
    # Params for collect
    initial_collect_steps=80000,  # ALE frames
    epsilon_greedy=0.01,
    epsilon_decay_period=1000000,  # ALE frames
    replay_buffer_capacity=1000000,
    # Params for train
    train_steps_per_iteration=1000000,  # ALE frames
    update_period=16,  # ALE frames
    target_update_tau=1.0,
    target_update_period=32000,  # ALE frames
    batch_size=32,
    learning_rate=2.5e-4,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    do_eval=True,
    eval_steps_per_iteration=500000,  # ALE frames
    eval_epsilon_greedy=0.001,
    # Params for checkpoints, summaries, and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple Atari train and eval for DQN.

  Builds the full TF1-style training graph: environment, Q-network, DQN
  agent, replay buffer, dataset iterator, train op, metrics, summaries,
  and checkpointers. All frame-denominated parameters are converted to
  agent steps by dividing by the module-level ATARI_FRAME_SKIP.

  Args:
    root_dir: Directory to write log files to.
    env_name: Fully-qualified name of the Atari environment (i.e. Pong-v0).
    num_iterations: Number of train/eval iterations to run.
    max_episode_frames: Maximum length of a single episode, in ALE frames.
    terminal_on_life_loss: Whether to simulate an episode termination when a
      life is lost.
    conv_layer_params: Params for convolutional layers of QNetwork.
    fc_layer_params: Params for fully connected layers of QNetwork.
    initial_collect_steps: Number of frames to ALE frames to process before
      beginning to train. Since this is in ALE frames, there will be
      initial_collect_steps/4 items in the RB when training starts.
    epsilon_greedy: Final epsilon value to decay to for training.
    epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
      epsilon_greedy (defined above).
    replay_buffer_capacity: Maximum number of items to store in the RB.
    train_steps_per_iteration: Number of ALE frames to run through for each
      iteration of training.
    update_period: Run a train operation every update_period ALE frames.
    target_update_tau: Coeffecient for soft target network updates (1.0 ==
      hard updates).
    target_update_period: Period, in ALE frames, to copy the live network to
      the target network.
    batch_size: Number of frames to include in each training batch.
    learning_rate: RMS optimizer learning rate.
    gamma: Discount for future rewards.
    reward_scale_factor: Scaling factor for rewards.
    gradient_clipping: Norm length to clip gradients.
    do_eval: If True, run an eval every iteration. If False, skip eval.
    eval_steps_per_iteration: Number of ALE frames to run through for each
      iteration of training.
    eval_epsilon_greedy: Epsilon value to use for the evaluation policy (0 ==
      totally greedy policy).
    log_interval: Log stats to the terminal every log_interval training
      steps.
    summary_interval: Write TF summaries every summary_interval training
      steps.
    summaries_flush_secs: Flush summaries to disk every summaries_flush_secs
      seconds.
    debug_summaries: If True, write additional summaries for debugging (see
      dqn_agent for which summaries are written).
    summarize_grads_and_vars: Include gradients in summaries.
    eval_metrics_callback: A callback function that takes (metric_dict,
      global_step) as parameters. Called after every eval with the results of
      the evaluation.
  """
  # Convert ALE-frame-denominated intervals into agent steps.
  self._update_period = update_period / ATARI_FRAME_SKIP
  self._train_steps_per_iteration = (train_steps_per_iteration
                                     / ATARI_FRAME_SKIP)
  self._do_eval = do_eval
  self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP
  self._eval_epsilon_greedy = eval_epsilon_greedy
  self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP
  self._summary_interval = summary_interval
  self._num_iterations = num_iterations
  self._log_interval = log_interval
  self._eval_metrics_callback = eval_metrics_callback

  with gin.unlock_config():
    gin.bind_parameter('AtariPreprocessing.terminal_on_life_loss',
                       terminal_on_life_loss)

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()
  self._train_summary_writer = train_summary_writer

  self._eval_summary_writer = None
  if self._do_eval:
    self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    # Phase metrics average over an entire eval phase (unbounded buffer).
    self._eval_metrics = [
        py_metrics.AverageReturnMetric(
            name='PhaseAverageReturn', buffer_size=np.inf),
        py_metrics.AverageEpisodeLengthMetric(
            name='PhaseAverageEpisodeLength', buffer_size=np.inf),
    ]

  self._global_step = tf.compat.v1.train.get_or_create_global_step()
  # Gate all graph-side summaries to every summary_interval global steps.
  with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
      self._global_step % self._summary_interval, 0)):
    self._env = suite_atari.load(
        env_name,
        max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP,
        gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
    self._env = batched_py_environment.BatchedPyEnvironment([self._env])

    observation_spec = tensor_spec.from_spec(self._env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)
    action_spec = tensor_spec.from_spec(self._env.action_spec())

    # Epsilon decays linearly from 1.0 to epsilon_greedy over the decay
    # period (expressed in train ops, i.e. frames / frame-skip / update
    # period). Keep the schedule on CPU.
    with tf.device('/cpu:0'):
      epsilon = tf.compat.v1.train.polynomial_decay(
          1.0,
          self._global_step,
          epsilon_decay_period / ATARI_FRAME_SKIP / self._update_period,
          end_learning_rate=epsilon_greedy)

    with tf.device('/gpu:0'):
      optimizer = tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          decay=0.95,
          momentum=0.0,
          epsilon=0.00001,
          centered=True)
      q_net = AtariQNetwork(
          observation_spec,
          action_spec,
          conv_layer_params=conv_layer_params,
          fc_layer_params=fc_layer_params)
      tf_agent = dqn_agent.DqnAgent(
          time_step_spec,
          action_spec,
          q_network=q_net,
          optimizer=optimizer,
          epsilon_greedy=epsilon,
          target_update_tau=target_update_tau,
          target_update_period=(target_update_period / ATARI_FRAME_SKIP /
                                self._update_period),
          td_errors_loss_fn=dqn_agent.element_wise_huber_loss,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=self._global_step)

      self._collect_policy = py_tf_policy.PyTFPolicy(tf_agent.collect_policy)

      if self._do_eval:
        # Eval policy is the greedy policy wrapped with a small fixed
        # epsilon rather than the training decay schedule.
        self._eval_policy = py_tf_policy.PyTFPolicy(
            epsilon_greedy_policy.EpsilonGreedyPolicy(
                policy=tf_agent.policy,
                epsilon=self._eval_epsilon_greedy))

      # The replay buffer stores python (numpy) trajectories, so its data
      # spec is built from the py environment specs.
      py_observation_spec = self._env.observation_spec()
      py_time_step_spec = ts.time_step_spec(py_observation_spec)
      py_action_spec = policy_step.PolicyStep(self._env.action_spec())
      data_spec = trajectory.from_transition(py_time_step_spec,
                                             py_action_spec,
                                             py_time_step_spec)
      self._replay_buffer = (
          py_hashed_replay_buffer.PyHashedReplayBuffer(
              data_spec=data_spec, capacity=replay_buffer_capacity))

    # Build the input pipeline on CPU and prefetch batches onto the GPU.
    with tf.device('/cpu:0'):
      ds = self._replay_buffer.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(4)
      ds = ds.apply(
          tf.data.experimental.prefetch_to_device('/gpu:0'))

    with tf.device('/gpu:0'):
      self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
      experience = self._ds_itr.get_next()
      self._train_op = tf_agent.train(experience)

      self._env_steps_metric = py_metrics.EnvironmentSteps()
      self._step_metrics = [
          py_metrics.NumberOfEpisodes(),
          self._env_steps_metric,
      ]
      self._train_metrics = self._step_metrics + [
          py_metrics.AverageReturnMetric(buffer_size=10),
          py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
      ]
      # The _train_phase_metrics average over an entire train iteration,
      # rather than the rolling average of the last 10 episodes.
      self._train_phase_metrics = [
          py_metrics.AverageReturnMetric(
              name='PhaseAverageReturn', buffer_size=np.inf),
          py_metrics.AverageEpisodeLengthMetric(
              name='PhaseAverageEpisodeLength', buffer_size=np.inf),
      ]
      self._iteration_metric = py_metrics.CounterMetric(name='Iteration')

      # Summaries written from python should run every time they are
      # generated.
      with tf.compat.v2.summary.record_if(True):
        self._steps_per_second_ph = tf.compat.v1.placeholder(
            tf.float32, shape=(), name='steps_per_sec_ph')
        self._steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=self._steps_per_second_ph,
            step=self._global_step)

        for metric in self._train_metrics:
          metric.tf_summaries(train_step=self._global_step,
                              step_metrics=self._step_metrics)

        for metric in self._train_phase_metrics:
          metric.tf_summaries(
              train_step=self._global_step,
              step_metrics=(self._iteration_metric, ))
        self._iteration_metric.tf_summaries(
            train_step=self._global_step)

        if self._do_eval:
          with self._eval_summary_writer.as_default():
            for metric in self._eval_metrics:
              metric.tf_summaries(
                  train_step=self._global_step,
                  step_metrics=(self._iteration_metric, ))

      # Separate checkpointers for agent+metrics, the eval policy, and the
      # replay buffer (RB keeps only the latest checkpoint to bound disk).
      self._train_checkpointer = common.Checkpointer(
          ckpt_dir=train_dir,
          agent=tf_agent,
          global_step=self._global_step,
          optimizer=optimizer,
          metrics=metric_utils.MetricsGroup(
              self._train_metrics + self._train_phase_metrics +
              [self._iteration_metric], 'train_metrics'))
      self._policy_checkpointer = common.Checkpointer(
          ckpt_dir=os.path.join(train_dir, 'policy'),
          policy=tf_agent.policy,
          global_step=self._global_step)
      self._rb_checkpointer = common.Checkpointer(
          ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
          max_to_keep=1,
          replay_buffer=self._replay_buffer)

      self._init_agent_op = tf_agent.initialize()
def main(_):
  """DADS entry point: builds the agent graph and runs the eval loop.

  Sets up log/save directories under FLAGS.logdir, restores resumable
  counters from .npy files if present, constructs the DADS agent (SAC +
  skill dynamics) and its replay buffers/checkpointers, then opens a
  session, restores the saved policy, and runs `do.eval_loop` to produce
  plots.
  """
  # setting up
  start_time = time.time()
  tf.compat.v1.enable_resource_variables()
  tf.compat.v1.disable_eager_execution()
  logging.set_verbosity(logging.INFO)
  # These module-level globals are read/written by other functions in this
  # file as well; main() (re)initializes them here.
  global observation_omit_size, goal_coord, sample_count, iter_count, episode_size_buffer, episode_return_buffer

  root_dir = os.path.abspath(os.path.expanduser(FLAGS.logdir))
  if not tf.io.gfile.exists(root_dir):
    tf.io.gfile.makedirs(root_dir)
  log_dir = os.path.join(root_dir, FLAGS.environment)
  if not tf.io.gfile.exists(log_dir):
    tf.io.gfile.makedirs(log_dir)
  save_dir = os.path.join(log_dir, 'models')
  if not tf.io.gfile.exists(save_dir):
    tf.io.gfile.makedirs(save_dir)

  print('directory for recording experiment data:', log_dir)

  # in case training is paused and resumed, so can be restored
  # NOTE(review): the bare `except:` below swallows *every* exception
  # (including KeyboardInterrupt) and silently starts counters from
  # scratch; presumably only a missing/corrupt .npy file was intended —
  # consider narrowing to (IOError, OSError, ValueError).
  try:
    sample_count = np.load(os.path.join(log_dir, 'sample_count.npy')).tolist()
    iter_count = np.load(os.path.join(log_dir, 'iter_count.npy')).tolist()
    episode_size_buffer = np.load(
        os.path.join(log_dir, 'episode_size_buffer.npy')).tolist()
    episode_return_buffer = np.load(
        os.path.join(log_dir, 'episode_return_buffer.npy')).tolist()
  except:
    sample_count = 0
    iter_count = 0
    episode_size_buffer = []
    episode_return_buffer = []

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      os.path.join(log_dir, 'train', 'in_graph_data'),
      flush_millis=10 * 1000)
  train_summary_writer.set_as_default()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(True):
    # environment related stuff
    env = do.get_environment(env_name=FLAGS.environment)
    py_env = wrap_env(
        skill_wrapper.SkillWrapper(
            env,
            num_latent_skills=FLAGS.num_skills,
            skill_type=FLAGS.skill_type,
            preset_skill=None,
            min_steps_before_resample=FLAGS.min_steps_before_resample,
            resample_prob=FLAGS.resample_prob),
        max_episode_steps=FLAGS.max_env_steps)

    # all specifications required for all networks and agents
    py_action_spec = py_env.action_spec()
    tf_action_spec = tensor_spec.from_spec(
        py_action_spec)  # policy, critic action spec
    env_obs_spec = py_env.observation_spec()
    py_env_time_step_spec = ts.time_step_spec(
        env_obs_spec)  # replay buffer time_step spec
    if observation_omit_size > 0:
      # The agent sees the observation with the first observation_omit_size
      # coordinates hidden (e.g. x-y position), while the replay buffer
      # keeps the full observation.
      agent_obs_spec = array_spec.BoundedArraySpec(
          (env_obs_spec.shape[0] - observation_omit_size, ),
          env_obs_spec.dtype,
          minimum=env_obs_spec.minimum,
          maximum=env_obs_spec.maximum,
          name=env_obs_spec.name)  # policy, critic observation spec
    else:
      agent_obs_spec = env_obs_spec
    py_agent_time_step_spec = ts.time_step_spec(
        agent_obs_spec)  # policy, critic time_step spec
    tf_agent_time_step_spec = tensor_spec.from_spec(py_agent_time_step_spec)

    if not FLAGS.reduced_observation:
      skill_dynamics_observation_size = (
          py_env_time_step_spec.observation.shape[0] - FLAGS.num_skills)
    else:
      skill_dynamics_observation_size = FLAGS.reduced_observation

    # TODO(architsh): Shift co-ordinate hiding to actor_net and critic_net
    # (good for further image based processing as well)
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_agent_time_step_spec.observation,
        tf_action_spec,
        fc_layer_params=(FLAGS.hidden_layer_size, ) * 2,
        continuous_projection_net=do._normal_projection_net)

    critic_net = critic_network.CriticNetwork(
        (tf_agent_time_step_spec.observation, tf_action_spec),
        observation_fc_layer_params=None,
        action_fc_layer_params=None,
        joint_fc_layer_params=(FLAGS.hidden_layer_size, ) * 2)

    # Importance-sampling relabelling with clipping needs per-sample batch
    # reweighing in the skill-dynamics model.
    if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type and FLAGS.is_clip_eps > 1.0:
      reweigh_batches_flag = True
    else:
      reweigh_batches_flag = False

    agent = dads_agent.DADSAgent(
        # DADS parameters
        save_dir,
        skill_dynamics_observation_size,
        observation_modify_fn=do.process_observation,
        restrict_input_size=observation_omit_size,
        latent_size=FLAGS.num_skills,
        latent_prior=FLAGS.skill_type,
        prior_samples=FLAGS.random_skills,
        fc_layer_params=(FLAGS.hidden_layer_size, ) * 2,
        normalize_observations=FLAGS.normalize_data,
        network_type=FLAGS.graph_type,
        num_mixture_components=FLAGS.num_components,
        fix_variance=FLAGS.fix_variance,
        reweigh_batches=reweigh_batches_flag,
        skill_dynamics_learning_rate=FLAGS.skill_dynamics_lr,
        # SAC parameters
        time_step_spec=tf_agent_time_step_spec,
        action_spec=tf_action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        target_update_tau=0.005,
        target_update_period=1,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.agent_lr),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.agent_lr),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=FLAGS.agent_lr),
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=FLAGS.agent_gamma,
        reward_scale_factor=1. / (FLAGS.agent_entropy + 1e-12),
        gradient_clipping=None,
        debug_summaries=FLAGS.debug,
        train_step_counter=global_step)

    # evaluation policy
    eval_policy = py_tf_policy.PyTFPolicy(agent.policy)

    # collection policy
    if FLAGS.collect_policy == 'default':
      collect_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)
    elif FLAGS.collect_policy == 'ou_noise':
      collect_policy = py_tf_policy.PyTFPolicy(
          ou_noise_policy.OUNoisePolicy(
              agent.collect_policy, ou_stddev=0.2, ou_damping=0.15))

    # relabelling policy deals with batches of data, unlike collect and eval
    relabel_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)

    # constructing a replay buffer, need a python spec
    policy_step_spec = policy_step.PolicyStep(
        action=py_action_spec, state=(), info=())

    if FLAGS.skill_dynamics_relabel_type is not None and 'importance_sampling' in FLAGS.skill_dynamics_relabel_type and FLAGS.is_clip_eps > 1.0:
      # Importance sampling needs the behavior log-probability stored
      # alongside each step.
      policy_step_spec = policy_step_spec._replace(
          info=policy_step.set_log_probability(
              policy_step_spec.info,
              array_spec.ArraySpec(
                  shape=(), dtype=np.float32, name='action_log_prob')))

    trajectory_spec = from_transition(py_env_time_step_spec, policy_step_spec,
                                      py_env_time_step_spec)
    capacity = FLAGS.replay_buffer_capacity
    # for all the data collected
    rbuffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=capacity, data_spec=trajectory_spec)

    if FLAGS.train_skill_dynamics_on_policy:
      # for on-policy data (if something special is required)
      on_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
          capacity=FLAGS.initial_collect_steps + FLAGS.collect_steps + 10,
          data_spec=trajectory_spec)

    # insert experience manually with relabelled rewards and skills
    agent.build_agent_graph()
    agent.build_skill_dynamics_graph()
    agent.create_savers()

    # saving this way requires the saver to be out the object
    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(save_dir, 'agent'),
        agent=agent,
        global_step=global_step)
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(save_dir, 'policy'),
        policy=agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(save_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=rbuffer)

    setup_time = time.time() - start_time
    print('Setup time:', setup_time)

    # Run evaluation only: restore the saved policy and produce plots.
    with tf.compat.v1.Session().as_default() as sess:
      eval_policy.session = sess
      eval_policy.initialize(None)
      eval_policy.restore(os.path.join(FLAGS.logdir, 'models', 'policy'))

      plotdir = os.path.join(FLAGS.logdir, "plots")
      if not os.path.exists(plotdir):
        os.mkdir(plotdir)
      do.FLAGS = FLAGS
      do.eval_loop(eval_dir=plotdir, eval_policy=eval_policy,
                   plot_name="plot")
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=None,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=25000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-3,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=500,
    policy_checkpoint_interval=500,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """A simple train and eval for PPO.

  Builds a PPO-Clip agent on a batch of parallel environments, collects
  whole episodes per iteration into a replay buffer, trains on the full
  buffer each iteration (on-policy), and periodically evaluates, logs and
  checkpoints until ``num_environment_steps`` environment steps are reached.

  Args:
    root_dir: Directory under which 'train', 'eval' and 'policy_saved_model'
      subdirectories are created for summaries, checkpoints and SavedModels.
    env_name: Gym environment id passed to `env_load_fn`.
    env_load_fn: Callable that loads a py-environment given `env_name`.
    random_seed: If not None, seeds TF's random number generator.
    actor_fc_layers: FC layer sizes for the actor network.
    value_fc_layers: FC layer sizes for the value network.
    use_rnns: If True, use RNN actor/value networks instead of feedforward.
    num_environment_steps: Total environment steps to train for (stopping
      condition, measured by the EnvironmentSteps metric).
    collect_episodes_per_iteration: Episodes gathered per training iteration.
    num_parallel_environments: Number of environments run in parallel; also
      the replay-buffer batch dimension.
    replay_buffer_capacity: Max trajectory length stored per environment.
    num_epochs: PPO epochs over the collected batch per `train` call.
    learning_rate: Adam learning rate.
    num_eval_episodes: Episodes per evaluation pass.
    eval_interval: Train steps between evaluations.
    train_checkpoint_interval: Train steps between training checkpoints.
      NOTE: checkpoint/save checks are nested inside the `log_interval`
      branch below, so intervals should be multiples of `log_interval`.
    policy_checkpoint_interval: Train steps between policy checkpoints and
      SavedModel exports (same nesting caveat as above).
    log_interval: Train steps between log lines / throughput summaries.
    summary_interval: Train steps between TF summaries (via record_if).
    summaries_flush_secs: Summary writer flush period, in seconds.
    use_tf_functions: If True, wrap collect/train in `common.function`.
    debug_summaries: Forwarded to the agent.
    summarize_grads_and_vars: Forwarded to the agent.

  Raises:
    AttributeError: If `root_dir` is None.
  """
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # All summaries below are gated on the global step so they are only
  # recorded every `summary_interval` train steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)
    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate)

    # Actor/value networks: recurrent or feedforward variants share specs.
    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
          activation_fn=tf.keras.activations.tanh)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(),
          fc_layer_params=value_fc_layers,
          activation_fn=tf.keras.activations.tanh)

    tf_agent = ppo_clip_agent.PPOClipAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        entropy_regularization=0.0,
        importance_ratio_clipping=0.2,
        normalize_observations=False,
        normalize_rewards=False,
        use_gae=True,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    # `step_metrics` provide the x-axis for the other metrics' summaries.
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]

    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(
            batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(
        eval_policy, train_step=global_step)

    # Restore training state (agent variables, metrics, global step) if a
    # previous checkpoint exists.
    train_checkpointer.initialize_or_restore()

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    def train_step():
      # PPO is on-policy: train on everything collected this iteration.
      trajectories = replay_buffer.gather_all()
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      # TODO(b/123828980): Enable once the cause for slowdown was identified.
      collect_driver.run = common.function(
          collect_driver.run, autograph=False)
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    while environment_steps_metric.result() < num_environment_steps:
      global_step_val = global_step.numpy()
      if global_step_val % eval_interval == 0:
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )

      start_time = time.time()
      collect_driver.run()
      collect_time += time.time() - start_time

      start_time = time.time()
      total_loss, _ = train_step()
      # The buffer is cleared after every train step; PPO never reuses
      # stale, off-policy experience.
      replay_buffer.clear()
      train_time += time.time() - start_time

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss)
        steps_per_sec = (
            (global_step_val - timed_at_step) / (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                     train_time)
        # Throughput is always recorded, bypassing the summary_interval gate.
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec,
              step=global_step)

        # NOTE(review): these checkpoint checks only run on log-interval
        # steps, so the checkpoint intervals should be multiples of
        # `log_interval` to fire as expected.
        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
          saved_model_path = os.path.join(
              saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
          saved_model.save(saved_model_path)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

    # One final eval before exiting.
    metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
def train_eval(
    root_dir,
    env_name='gym_solventx-v0',
    num_iterations=100000,
    train_sequence_length=1,
    # Params for QNetwork
    fc_layer_params=(100, ),
    # Params for QRnnNetwork
    input_fc_layer_params=(50, ),
    lstm_size=(20, ),
    output_fc_layer_params=(20, ),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN.

  Builds a DQN agent (feedforward QNetwork, or QRnnNetwork when
  ``train_sequence_length > 1``) on a gym_solventx environment, seeds a
  uniform replay buffer with a random policy, then alternates single-step
  collection with mini-batch training, with periodic eval, logging and
  checkpointing. Returns the last `LossInfo` from training.

  Args:
    root_dir: Directory for 'train'/'eval' summaries, checkpoints and the
      'policy_saved_model' exports.
    env_name: Gym environment id (registered gym_solventx env).
    num_iterations: Number of collect+train iterations to run.
    train_sequence_length: Sequence length sampled from the buffer; > 1
      selects the recurrent Q-network. Incompatible with n_step_update > 1.
    fc_layer_params: FC layer sizes for the feedforward QNetwork.
    input_fc_layer_params: Input FC layers for the QRnnNetwork.
    lstm_size: LSTM cell sizes for the QRnnNetwork.
    output_fc_layer_params: Output FC layers for the QRnnNetwork.
    initial_collect_steps: Random-policy steps used to seed the buffer.
    collect_steps_per_iteration: Env steps collected per iteration.
    epsilon_greedy: Exploration epsilon for the collect policy.
    replay_buffer_capacity: Max frames kept in the replay buffer.
    target_update_tau: Soft target-network update coefficient.
    target_update_period: Steps between target-network updates.
    train_steps_per_iteration: Gradient steps per iteration.
    batch_size: Mini-batch size sampled from the replay buffer.
    learning_rate: Adam learning rate.
    n_step_update: Number of steps for n-step TD updates.
    gamma: Discount factor.
    reward_scale_factor: Multiplier applied to rewards.
    gradient_clipping: Optional gradient-norm clip value.
    use_tf_functions: If True, wrap collect/train in `common.function`.
    num_eval_episodes: Episodes per evaluation pass.
    eval_interval: Train steps between evaluations.
    train_checkpoint_interval: Train steps between training checkpoints.
    policy_checkpoint_interval: Train steps between policy checkpoints and
      SavedModel exports.
    rb_checkpoint_interval: Train steps between replay-buffer checkpoints.
    log_interval: Train steps between log lines.
    summary_interval: Train steps between TF summaries.
    summaries_flush_secs: Summary writer flush period, in seconds.
    debug_summaries: Forwarded to the agent.
    summarize_grads_and_vars: Forwarded to the agent.
    eval_metrics_callback: Optional callable `fn(results, global_step)`
      invoked after each evaluation.

  Returns:
    The final `LossInfo` returned by the last train step.

  Raises:
    NotImplementedError: If both train_sequence_length and n_step_update
      differ from 1 (n-step updates with RNNs are unsupported).
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # NOTE(review): `config_file` is not a parameter of this function and is
    # not defined locally — presumably a module-level global/FLAG; verify
    # against the rest of the file.
    gym_env = gym.make(env_name, config_file=config_file)
    py_env = suite_gym.wrap_env(gym_env, max_episode_steps=100)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    eval_gym_env = gym.make(env_name, config_file=config_file)
    eval_py_env = suite_gym.wrap_env(eval_gym_env, max_episode_steps=100)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    #tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    #eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name), config_file=config_file)

    if train_sequence_length != 1 and n_step_update != 1:
      raise NotImplementedError(
          'train_eval does not currently support n-step updates with stateful '
          'networks (i.e., RNNs)')

    if train_sequence_length > 1:
      q_net = q_rnn_network.QRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=input_fc_layer_params,
          lstm_size=lstm_size,
          output_fc_layer_params=output_fc_layer_params)
    else:
      q_net = q_network.QNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=fc_layer_params)
      # With a feedforward net, sample sequences just long enough for the
      # n-step return computation.
      train_sequence_length = n_step_update

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(
        eval_policy, train_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # Restore agent/metrics and replay buffer from any prior run.
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    if use_tf_functions:
      # To speed up collect use common.function.
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    # Baseline evaluation before any training.
    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    def train_step():
      # Sample one mini-batch and take a gradient step.
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    for _ in range(num_iterations):
      start_time = time.time()
      # Driver carries time_step/policy_state across iterations so episodes
      # continue rather than restarting each loop.
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec,
            step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step.numpy())
        saved_model_path = os.path.join(
            saved_model_dir,
            'policy_' + ('%d' % global_step.numpy()).zfill(9))
        saved_model.save(saved_model_path)

      if global_step.numpy() % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step.numpy())

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)
    return train_loss
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    eval_env_name=None,
    env_load_fn=suite_mujoco.load,
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=50000,
    policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC.

  Builds a SAC agent on a MuJoCo environment, fills a uniform replay buffer
  with random-policy experience (unless restored from a checkpoint), and then
  alternates single-step collection with mini-batch training over transitions
  (num_steps=2 samples with boundary steps filtered out), evaluating with a
  greedy policy and checkpointing periodically.

  Args:
    root_dir: Directory for 'train'/'eval' summaries and checkpoints.
    env_name: Environment name for training, passed to `env_load_fn`.
    eval_env_name: Optional distinct environment name for evaluation;
      defaults to `env_name`.
    env_load_fn: Callable that loads a py-environment given a name.
    num_iterations: Number of collect+train iterations to run.
    actor_fc_layers: FC layer sizes for the actor network.
    critic_obs_fc_layers: Observation-branch FC layers of the critic.
    critic_action_fc_layers: Action-branch FC layers of the critic.
    critic_joint_fc_layers: Joint FC layers of the critic.
    initial_collect_steps: Random-policy steps used to seed the buffer.
    collect_steps_per_iteration: Env steps collected per iteration.
    replay_buffer_capacity: Max frames kept in the replay buffer.
    target_update_tau: Soft target-network update coefficient.
    target_update_period: Steps between target-network updates.
    train_steps_per_iteration: Gradient steps per iteration.
    batch_size: Mini-batch size sampled from the replay buffer.
    actor_learning_rate: Adam learning rate for the actor.
    critic_learning_rate: Adam learning rate for the critic.
    alpha_learning_rate: Adam learning rate for the entropy temperature.
    td_errors_loss_fn: Elementwise loss on TD errors.
    gamma: Discount factor.
    reward_scale_factor: Multiplier applied to rewards.
    gradient_clipping: Optional gradient-norm clip value.
    use_tf_functions: If True, wrap drivers and train in `common.function`.
    num_eval_episodes: Episodes per evaluation pass.
    eval_interval: Train steps between evaluations.
    train_checkpoint_interval: Train steps between training checkpoints.
    policy_checkpoint_interval: Train steps between policy checkpoints.
    rb_checkpoint_interval: Train steps between replay-buffer checkpoints.
    log_interval: Train steps between log lines.
    summary_interval: Train steps between TF summaries.
    summaries_flush_secs: Summary writer flush period, in seconds.
    debug_summaries: Forwarded to the agent.
    summarize_grads_and_vars: Forwarded to the agent.
    eval_metrics_callback: Optional callable `fn(results, global_step)`
      invoked after each evaluation.

  Returns:
    The final `LossInfo` returned by the last train step.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    # Evaluation may use a different environment than training.
    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(eval_env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]

    # Evaluate deterministically; collect stochastically.
    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(
          initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Only seed the buffer when starting fresh (i.e. not restored above).
    if replay_buffer.num_frames() == 0:
      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    # Baseline evaluation before any training.
    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      # Drop samples whose first step is an episode boundary — they do not
      # form a valid (s, a, s') transition.
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    global_step_val = global_step.numpy()
    while global_step_val < num_iterations:
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val,
                     train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec,
            step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
def train(
    root_dir,
    agent,
    environment,
    training_loops,
    steps_per_loop=1,
    additional_metrics=(),
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10,
    policy_checkpoint_interval=10,
    log_interval=10,
    summary_interval=10):
  """A training driver.

  TF1 graph-mode training loop: builds collect/train/clear ops once, then
  repeatedly runs them in a `tf.compat.v1.Session`. Each loop collects
  `steps_per_loop * environment.batch_size` env steps into a replay buffer,
  trains on a single deterministic pass over that data, and clears the
  buffer, with periodic logging and checkpointing.

  Args:
    root_dir: Directory under which the 'train' subdirectory is created for
      summaries and checkpoints.
    agent: A TF-Agents agent providing `policy`, `train` and `initialize`.
    environment: A batched TF environment to collect from.
    training_loops: Number of collect+train loops to run.
    steps_per_loop: Per-environment steps collected (and trained on) per
      loop.
    additional_metrics: Extra observers/metrics appended to the defaults.
    train_checkpoint_interval: Global steps between training checkpoints.
    policy_checkpoint_interval: Global steps between policy checkpoints.
    log_interval: Global steps between log lines / throughput summaries.
    summary_interval: Global steps between TF summaries (via record_if).

  Raises:
    RuntimeError: If TF resource variables are not enabled.
  """
  if not common.resource_variables_enabled():
    raise RuntimeError(common.MISSING_RESOURCE_VARIABLES_ERROR)
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(train_dir)
  train_summary_writer.set_as_default()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(batch_size=environment.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=environment.batch_size),
    ] + list(additional_metrics)

    # Add to replay buffer and other agent specific observers.
    # NOTE(review): `build_replay_buffer` and `utils` below are module-level
    # helpers not visible in this chunk.
    replay_buffer = build_replay_buffer(agent, environment.batch_size,
                                        steps_per_loop)
    agent_observers = [replay_buffer.add_batch] + train_metrics

    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=agent.policy,
        num_steps=steps_per_loop * environment.batch_size,
        observers=agent_observers)

    # Graph mode: driver.run() returns ops that are executed later in the
    # session loop, not run eagerly here.
    collect_op, _ = driver.run()
    batch_size = driver.env.batch_size
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=steps_per_loop,
        single_deterministic_pass=True)
    trajectories, unused_info = tf.data.experimental.get_single_element(
        dataset)
    train_op = agent.train(experience=trajectories)
    clear_replay_op = replay_buffer.clear()

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        max_to_keep=1,
        agent=agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        max_to_keep=None,
        policy=agent.policy,
        global_step=global_step)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    init_agent_op = agent.initialize()

    config_saver = utils.GinConfigSaverHook(train_dir,
                                            summarize_config=True)
    config_saver.begin()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      config_saver.after_create_session(sess)

      global_step_call = sess.make_callable(global_step)
      global_step_val = global_step_call()

      sess.run(train_summary_writer.init())
      sess.run(collect_op)

      if global_step_val == 0:
        # Save an initial checkpoint so the evaluator runs for global_step=0.
        policy_checkpointer.save(global_step=global_step_val)
        sess.run(init_agent_op)

      # Pre-bind the session callables used in the hot loop.
      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      clear_replay_call = sess.make_callable(clear_replay_op)

      timed_at_step = global_step_val
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(training_loops):
        # Collect and train.
        start_time = time.time()
        collect_call()
        total_loss, _ = train_step_call()
        # Buffer is cleared every loop; training is on fresh data only.
        clear_replay_call()
        global_step_val = global_step_call()

        time_acc += time.time() - start_time
        total_loss = total_loss.loss

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
driver = dynamic_step_driver.DynamicStepDriver( train_env, agent.collect_policy, observers=[replay_buffer.add_batch, metric], num_steps=collect_steps_per_iteration) # Initial data collection driver.run() # Dataset generates trajectories with shape [BxTx...] where # T = n_step_update + 1. dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2, single_deterministic_pass=False).prefetch(3) iterator = iter(dataset) train_checkpointer = common.Checkpointer(ckpt_dir=CHECKPOINT_DIR, max_to_keep=1, agent=agent, policy=agent.policy, replay_buffer=replay_buffer, global_step=global_step) # train the agent # (Optional) Optimize by wrapping some of the code in a graph using TF function. agent.train = common.function(agent.train) def train_one_iteration(): # Collect a few steps using collect_policy and save to the replay buffer. driver.run() # Sample a batch of data from the buffer and update the agent's network. experience, unused_info = next(iterator) # print('#' * 80)
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=1000,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(100, ),
    value_net_fc_layers=(100, ),
    use_value_network=False,
    # Params for collect
    collect_episodes_per_iteration=2,
    replay_buffer_capacity=2000,
    # Params for train
    learning_rate=1e-3,
    gamma=0.9,
    gradient_clipping=None,
    normalize_returns=True,
    value_estimation_loss_coef=0.2,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=100,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=100,
    rb_checkpoint_interval=200,
    log_interval=100,
    summary_interval=100,
    summaries_flush_secs=1,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for Reinforce.

  TF1 graph-mode loop: builds the collect/train/clear ops once, then runs
  them in a session. Each iteration collects whole episodes, trains on the
  entire replay buffer (on-policy), clears it, and periodically logs,
  checkpoints, and evaluates with a py-policy over the py-environment.

  Args:
    root_dir: Directory for 'train'/'eval' summaries and checkpoints.
    env_name: Gym environment id loaded via `suite_gym.load`.
    num_iterations: Number of collect+train iterations to run.
    actor_fc_layers: FC layer sizes for the actor network.
    value_net_fc_layers: FC layer sizes for the optional value network.
    use_value_network: If True, add a value-baseline network.
    collect_episodes_per_iteration: Episodes gathered per iteration.
    replay_buffer_capacity: Max frames kept in the replay buffer.
    learning_rate: Adam learning rate.
    gamma: Discount factor.
    gradient_clipping: Optional gradient-norm clip value.
    normalize_returns: If True, normalize returns in the loss.
    value_estimation_loss_coef: Weight of the value-estimation loss.
    num_eval_episodes: Episodes per evaluation pass.
    eval_interval: Global steps between evaluations.
    train_checkpoint_interval: Global steps between training checkpoints.
    policy_checkpoint_interval: Global steps between policy checkpoints.
    rb_checkpoint_interval: Global steps between replay-buffer checkpoints.
    log_interval: Global steps between log lines / throughput summaries.
    summary_interval: Global steps between TF summaries (via record_if).
    summaries_flush_secs: Summary writer flush period, in seconds.
    debug_summaries: Forwarded to the agent.
    summarize_grads_and_vars: Forwarded to the agent.
    eval_metrics_callback: Optional callable passed to compute_summaries,
      invoked with eval results.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  # Evaluation runs through py_tf_policy on the py-env, so py_metrics are
  # used here (not tf_metrics).
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    eval_py_env = suite_gym.load(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    # TODO(b/127870767): Handle distributions without gin.
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers)

    if use_value_network:
      value_net = value_network.ValueNetwork(
          tf_env.time_step_spec().observation,
          fc_layer_params=value_net_fc_layers)

    tf_agent = reinforce_agent.ReinforceAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        value_network=value_net if use_value_network else None,
        value_estimation_loss_coef=value_estimation_loss_coef,
        gamma=gamma,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        normalize_returns=normalize_returns,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy
    # Graph mode: these calls build ops that are executed later in the
    # session loop.
    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    # REINFORCE is on-policy: train on everything collected, then clear.
    experience = replay_buffer.gather_all()
    train_op = tf_agent.train(experience)
    clear_rb_op = replay_buffer.clear()

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      # TODO(b/126239733): Remove once Periodically can be saved.
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      # Compute evaluation metrics.
      global_step_call = sess.make_callable(global_step)
      global_step_val = global_step_call()
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      # Pre-bind the session callables used in the hot loop.
      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      clear_rb_call = sess.make_callable(clear_rb_op)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        total_loss, _ = train_step_call()
        clear_rb_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       total_loss.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
dataset = _replay_buffer.as_dataset(num_parallel_calls=30, sample_batch_size=_batch_size, num_steps=2).prefetch(30) _agent.train = common.function(_agent.train) _agent.train_step_counter.assign(0) print('initial collect...') avg_return = compute_avg_return(_eval_env, _agent.policy, _num_eval_episodes) returns = [avg_return] iterator = iter(dataset) train_checkpointer = common.Checkpointer(ckpt_dir=_checkpoint_policy_dir, max_to_keep=1, agent=_agent, policy=_agent.policy, replay_buffer=_replay_buffer, global_step=_train_step_counter) tf_policy_saver = policy_saver.PolicySaver(_agent.policy) restore_network = True if restore_network: train_checkpointer.initialize_or_restore() #_train_env.pyenv._envs[0].set_rendering(enabled=False) while True: print('Collecting...') for _ in tqdm(range(_num_train_episodes)):
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=3000000,
    actor_fc_layers=(),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    dual_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=50000,
    policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None,
    latent_dim=10,
    log_prob_reward_scale=0.0,
    predictor_updates_encoder=False,
    predict_prior=True,
    use_recurrent_actor=False,
    rnn_sequence_length=20,
    clip_max_stddev=10.0,
    clip_min_stddev=0.1,
    clip_mean=30.0,
    predictor_num_layers=2,
    use_identity_encoder=False,
    identity_encoder_single_stddev=False,
    kl_constraint=1.0,
    eval_dropout=(),
    use_residual_predictor=True,
    gym_kwargs=None,
    predict_prior_std=True,
    random_seed=0,
):
  """A simple train and eval for SAC.

  Builds a SAC-style agent (`rpc_agent.RpAgent`) augmented with a
  stochastic state encoder and a latent-state predictor, then alternates
  environment collection, gradient steps, periodic evaluation, and
  checkpointing until `global_step` reaches `num_iterations`.

  Per train step (see the inner `train_step`):
    * Computes the KL divergence between the encoder's distribution over
      the next observation and the predictor's prior computed from the
      current (observation, action) pair.
    * Updates a dual variable (`actor_net._log_kl_coefficient`) via
      `dual_optimizer` so that the mean predictor KL is driven toward
      `kl_constraint` (a Lagrangian-style constraint update).
    * Optionally subtracts `coef * predictor_kl` from the reward before
      calling `tf_agent.train` (coef is `log_prob_reward_scale`, or
      `exp(log_kl_coefficient)` when set to 'auto').

  Side effects: writes train/eval summaries, checkpoints (train, policy,
  replay buffer), and the operative gin config under `root_dir`.

  Args: see the keyword defaults above; notable ones:
    latent_dim: Size of the latent state, or the string 'obs' to use the
      observation dimensionality.
    eval_metrics_callback: Optional callable invoked with
      (results, global_step) after each evaluation.
  """
  np.random.seed(random_seed)
  tf.random.set_seed(random_seed)
  if use_recurrent_actor:
    # RNN training samples sequences of length rnn_sequence_length, so shrink
    # the batch to keep the per-step sample count roughly constant.
    batch_size = batch_size // rnn_sequence_length
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # All summaries created below are only recorded every summary_interval steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    _build_env = functools.partial(  # pylint: disable=invalid-name
        suite_gym.load,
        environment_name=env_name,
        gym_env_wrappers=(),
        gym_kwargs=gym_kwargs)
    tf_env = tf_py_environment.TFPyEnvironment(_build_env())

    eval_vec = []  # (name, env, metrics)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]
    eval_tf_env = tf_py_environment.TFPyEnvironment(_build_env())
    name = ''
    eval_vec.append((name, eval_tf_env, eval_metrics))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()
    if latent_dim == 'obs':
      # Special value: latent state has the same dimensionality as the
      # (assumed rank-1) observation.
      latent_dim = observation_spec.shape[0]

    def _activation(t):
      """Squashes the concatenated [mean, pre-stddev] output of a Dense layer.

      The first half (mean) is clipped to [-clip_mean, clip_mean]; the second
      half is clipped so that, after the IndependentNormal layer's softplus,
      the stddev lies in [clip_min_stddev, clip_max_stddev]. `None` for any
      clip value disables that bound.
      """
      t1, t2 = tf.split(t, 2, axis=1)
      low = -np.inf if clip_mean is None else -clip_mean
      high = np.inf if clip_mean is None else clip_mean
      t1 = rpc_utils.squash_to_range(t1, low, high)

      if clip_min_stddev is None:
        low = -np.inf
      else:
        # Inverse-softplus of the stddev bound: softplus(log(exp(s) - 1)) = s.
        low = tf.math.log(tf.exp(clip_min_stddev) - 1.0)
      if clip_max_stddev is None:
        high = np.inf
      else:
        high = tf.math.log(tf.exp(clip_max_stddev) - 1.0)
      t2 = rpc_utils.squash_to_range(t2, low, high)
      return tf.concat([t1, t2], axis=1)

    if use_identity_encoder:
      # Encoder passes the observation through as the mean and only learns a
      # (state-independent) stddev via a Dense layer applied to zeros.
      assert latent_dim == observation_spec.shape[0]
      obs_input = tf.keras.layers.Input(observation_spec.shape)
      zeros = 0.0 * obs_input[:, :1]
      stddev_dim = 1 if identity_encoder_single_stddev else latent_dim
      pre_stddev = tf.keras.layers.Dense(stddev_dim, activation=None)(zeros)
      ones = zeros + tf.ones((1, latent_dim))
      pre_stddev = pre_stddev * ones  # Multiply to broadcast to latent_dim.
      pre_mean_stddev = tf.concat([obs_input, pre_stddev], axis=1)
      output = tfp.layers.IndependentNormal(latent_dim)(pre_mean_stddev)
      encoder_net = tf.keras.Model(inputs=obs_input, outputs=output)
    else:
      # Learned stochastic encoder: MLP -> diagonal Gaussian over latents.
      encoder_net = tf.keras.Sequential([
          tf.keras.layers.Dense(256, activation='relu'),
          tf.keras.layers.Dense(256, activation='relu'),
          tf.keras.layers.Dense(
              tfp.layers.IndependentNormal.params_size(latent_dim),
              activation=_activation,
              kernel_initializer='glorot_uniform'),
          tfp.layers.IndependentNormal(latent_dim),
      ])

    # Build the predictor net: maps (observation, action) to a distribution
    # over the next latent state.
    obs_input = tf.keras.layers.Input(observation_spec.shape)
    action_input = tf.keras.layers.Input(action_spec.shape)

    class ConstantIndependentNormal(tfp.layers.IndependentNormal):
      """A keras layer that always returns N(0, 1) distribution."""

      def call(self, inputs):
        # loc = 0; scale param = log(e - 1), which softplus maps to 1.
        loc_scale = tf.concat([
            tf.zeros((latent_dim,)),
            tf.fill((latent_dim,), tf.math.log(tf.exp(1.0) - 1))
        ],
                              axis=0)
        # Multiply by [B x 1] tensor to broadcast batch dimension.
        loc_scale = loc_scale * tf.ones_like(inputs[:, :1])
        return super(ConstantIndependentNormal, self).call(loc_scale)

    if predict_prior:
      z = encoder_net(obs_input)
      if not predictor_updates_encoder:
        # Gradients from the predictor loss do not flow into the encoder.
        z = tf.stop_gradient(z)
      za = tf.concat([z, action_input], axis=1)
      if use_residual_predictor:
        # Predictor outputs a delta on the latent mean (residual connection
        # from the encoder's z), plus a stddev head.
        za_input = tf.keras.layers.Input(za.shape[1])
        loc_scale = tf.keras.Sequential(
            predictor_num_layers *
            [tf.keras.layers.Dense(256, activation='relu')] + [  # pylint: disable=line-too-long
                tf.keras.layers.Dense(
                    tfp.layers.IndependentNormal.params_size(latent_dim),
                    activation=_activation,
                    kernel_initializer='zeros'),
            ])(za_input)
        if predict_prior_std:
          combined_loc_scale = tf.concat([
              loc_scale[:, :latent_dim] + za_input[:, :latent_dim],
              loc_scale[:, latent_dim:]
          ],
                                         axis=1)
        else:
          # Note that softplus(log(e - 1)) = 1, i.e. stddev is fixed to 1.
          combined_loc_scale = tf.concat([
              loc_scale[:, :latent_dim] + za_input[:, :latent_dim],
              tf.math.log(np.e - 1) * tf.ones_like(loc_scale[:, latent_dim:])
          ],
                                         axis=1)
        dist = tfp.layers.IndependentNormal(latent_dim)(combined_loc_scale)
        output = tf.keras.Model(inputs=za_input, outputs=dist)(za)
      else:
        assert predict_prior_std
        output = tf.keras.Sequential(
            predictor_num_layers *
            [tf.keras.layers.Dense(256, activation='relu')] +  # pylint: disable=line-too-long
            [
                tf.keras.layers.Dense(
                    tfp.layers.IndependentNormal.params_size(latent_dim),
                    activation=_activation,
                    kernel_initializer='zeros'),
                tfp.layers.IndependentNormal(latent_dim),
            ])(za)
    else:
      # No learned prior: predictor always emits N(0, 1).
      # scale is chosen by inverting the softplus function to equal 1.
      if len(obs_input.shape) > 2:
        input_reshaped = tf.reshape(
            obs_input, [-1, tf.math.reduce_prod(obs_input.shape[1:])])
        # Multiply by [B x 1] tensor to broadcast batch dimension.
        za = tf.zeros(latent_dim + action_spec.shape[0],) * tf.ones_like(input_reshaped[:, :1])  # pylint: disable=line-too-long
      else:
        # Multiply by [B x 1] tensor to broadcast batch dimension.
        za = tf.zeros(latent_dim + action_spec.shape[0],) * tf.ones_like(obs_input[:, :1])  # pylint: disable=line-too-long
      output = tf.keras.Sequential([
          ConstantIndependentNormal(latent_dim),
      ])(za)

    predictor_net = tf.keras.Model(
        inputs=(obs_input, action_input), outputs=output)

    if use_recurrent_actor:
      ActorClass = rpc_utils.RecurrentActorNet  # pylint: disable=invalid-name
    else:
      ActorClass = rpc_utils.ActorNet  # pylint: disable=invalid-name
    actor_net = ActorClass(
        input_tensor_spec=observation_spec,
        output_tensor_spec=action_spec,
        encoder=encoder_net,
        predictor=predictor_net,
        fc_layers=actor_fc_layers)

    critic_net = rpc_utils.CriticNet(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')
    # Second/target critics are left to the agent's own defaults.
    critic_net_2 = None
    target_critic_net_1 = None
    target_critic_net_2 = None

    tf_agent = rpc_agent.RpAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        critic_network_2=critic_net_2,
        target_critic_network=target_critic_net_1,
        target_critic_network_2=target_critic_net_2,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    # Separate optimizer for the dual (KL-constraint) variable.
    dual_optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=dual_learning_rate)

    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]
    # Tracks the average encoder-vs-predictor KL on collected transitions.
    kl_metric = rpc_utils.AverageKLMetric(
        encoder=encoder_net,
        predictor=predictor_net,
        batch_size=tf_env.batch_size)
    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    checkpoint_items = {
        'ckpt_dir': train_dir,
        'agent': tf_agent,
        'global_step': global_step,
        'metrics': metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        'dual_optimizer': dual_optimizer,
    }
    train_checkpointer = common.Checkpointer(**checkpoint_items)

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # Restore prior training state (if any) before collecting.
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps,
        transition_observers=[kl_metric])
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration,
        transition_observers=[kl_metric])
    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    if replay_buffer.num_frames() == 0:
      # Collect initial replay data. Skipped when a replay-buffer checkpoint
      # was restored above (num_frames > 0).
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    # One evaluation pass before training starts.
    for name, eval_tf_env, eval_metrics in eval_vec:
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='Metrics-%s' % name,
      )
      if eval_metrics_callback is not None:
        eval_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics, prefix=name)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    # Wall-clock accumulators for the steps/sec logging below.
    timed_at_step = global_step.numpy()
    time_acc = 0
    train_time_acc = 0
    env_time_acc = 0

    if use_recurrent_actor:  # default from sac/train_eval_rnn.py
      num_steps = rnn_sequence_length + 1

      def _filter_invalid_transition(trajectories, unused_arg1):
        # Keep only sequences with no episode boundary inside them.
        return tf.reduce_all(~trajectories.is_boundary()[:-1])

      tf_agent._as_transition = data_converter.AsTransition(  # pylint: disable=protected-access
          tf_agent.data_context, squeeze_time_dim=False)
    else:
      num_steps = 2

      def _filter_invalid_transition(trajectories, unused_arg1):
        # Drop (s, s') pairs that straddle an episode boundary.
        return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=num_steps).unbatch().filter(_filter_invalid_transition)
    dataset = dataset.batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    @tf.function
    def train_step():
      """One gradient step: dual update, KL summaries, reward shaping, train."""
      experience, _ = next(iterator)

      # Prior over the next latent, predicted from (obs_t, action_t);
      # posterior from encoding obs_{t+1}.
      prior = predictor_net(
          (experience.observation[:, 0], experience.action[:, 0]),
          training=False)
      z_next = encoder_net(experience.observation[:, 1], training=False)
      # predictor_kl is a vector of size batch_size.
      predictor_kl = tfp.distributions.kl_divergence(z_next, prior)

      # Dual (Lagrangian) update: ascend/descend log_kl_coefficient so that
      # mean KL is pushed toward kl_constraint. stop_gradient keeps this
      # update from touching the encoder/predictor weights.
      with tf.GradientTape() as tape:
        tape.watch(actor_net._log_kl_coefficient)  # pylint: disable=protected-access
        dual_loss = -1.0 * actor_net._log_kl_coefficient * (  # pylint: disable=protected-access
            tf.stop_gradient(tf.reduce_mean(predictor_kl)) - kl_constraint)
      dual_grads = tape.gradient(dual_loss, [actor_net._log_kl_coefficient])  # pylint: disable=protected-access
      grads_and_vars = list(zip(dual_grads, [actor_net._log_kl_coefficient]))  # pylint: disable=protected-access
      dual_optimizer.apply_gradients(grads_and_vars)

      # Clip the dual variable so exp(log_kl_coef) <= 1e6.
      log_kl_coef = tf.clip_by_value(
          actor_net._log_kl_coefficient,  # pylint: disable=protected-access
          -1.0 * np.log(1e6), np.log(1e6))
      actor_net._log_kl_coefficient.assign(log_kl_coef)  # pylint: disable=protected-access

      with tf.name_scope('dual_loss'):
        tf.compat.v2.summary.scalar(
            name='dual_loss', data=tf.reduce_mean(dual_loss), step=global_step)
        tf.compat.v2.summary.scalar(
            name='log_kl_coefficient',
            data=actor_net._log_kl_coefficient,  # pylint: disable=protected-access
            step=global_step)

      z_entropy = z_next.entropy()
      log_prob = prior.log_prob(z_next.sample())
      with tf.name_scope('rp-metrics'):
        common.generate_tensor_summaries('predictor_kl', predictor_kl,
                                         global_step)
        common.generate_tensor_summaries('z_entropy', z_entropy, global_step)
        common.generate_tensor_summaries('log_prob', log_prob, global_step)
        common.generate_tensor_summaries('z_mean', z_next.mean(), global_step)
        common.generate_tensor_summaries('z_stddev', z_next.stddev(),
                                         global_step)
        common.generate_tensor_summaries('prior_mean', prior.mean(),
                                         global_step)
        common.generate_tensor_summaries('prior_stddev', prior.stddev(),
                                         global_step)

      if log_prob_reward_scale == 'auto':
        # Use the (stop-gradiented) dual variable as the KL penalty weight.
        coef = tf.stop_gradient(tf.exp(actor_net._log_kl_coefficient))  # pylint: disable=protected-access
      else:
        coef = log_prob_reward_scale
      tf.debugging.check_numerics(
          tf.reduce_mean(predictor_kl), 'predictor_kl is inf or nan.')
      tf.debugging.check_numerics(coef, 'coef is inf or nan.')
      # Shape rewards with the information penalty: r' = r - coef * KL.
      new_reward = experience.reward - coef * predictor_kl[:, None]
      experience = experience._replace(reward=new_reward)

      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    # Save the hyperparameters
    operative_filename = os.path.join(root_dir, 'operative.gin')
    with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
      f.write(gin.operative_config_str())
      print(gin.operative_config_str())

    global_step_val = global_step.numpy()
    while global_step_val < num_iterations:
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      env_time_acc += time.time() - start_time
      train_start_time = time.time()
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      train_time_acc += time.time() - train_start_time
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, train_loss.loss)
        # Overall, train-only, and env-only throughput since last log.
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)

        train_steps_per_sec = (global_step_val -
                               timed_at_step) / train_time_acc
        logging.info('Train: %.3f steps/sec', train_steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='train_steps_per_sec',
            data=train_steps_per_sec,
            step=global_step)

        env_steps_per_sec = (global_step_val - timed_at_step) / env_time_acc
        logging.info('Env: %.3f steps/sec', env_steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='env_steps_per_sec',
            data=env_steps_per_sec,
            step=global_step)

        timed_at_step = global_step_val
        time_acc = 0
        train_time_acc = 0
        env_time_acc = 0

      for train_metric in train_metrics + [kl_metric]:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        start_time = time.time()
        for name, eval_tf_env, eval_metrics in eval_vec:
          results = metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=global_step,
              summary_writer=eval_summary_writer,
              summary_prefix='Metrics-%s' % name,
          )
          if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step_val)
          metric_utils.log_metrics(eval_metrics, prefix=name)
        logging.info('Evaluation: %d min', (time.time() - start_time) / 60)
        # NOTE(review): uses the eval_tf_env left over from the loop above;
        # with multiple eval envs this would only probe the last one.
        for prob_dropout in eval_dropout:
          rpc_utils.eval_dropout_fn(
              eval_tf_env, actor_net, global_step, prob_dropout=prob_dropout)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)