def testReset(self):
  batched_avg_return_metric = batched_py_metric.BatchedPyMetric(
      py_metrics.AverageReturnMetric)
  tf_avg_return_metric = tf_py_metric.TFPyMetric(batched_avg_return_metric)
  deps = []
  # Run the first episode.
  for i in range(3):
    with tf.control_dependencies(deps):
      traj = tf_avg_return_metric(self._ts[i])
    deps = tf.nest.flatten(traj)
  # Reset the metric.
  with tf.control_dependencies(deps):
    reset_op = tf_avg_return_metric.reset()
  deps = [reset_op]
  # Run the second episode.
  for i in range(3, 6):
    with tf.control_dependencies(deps):
      traj = tf_avg_return_metric(self._ts[i])
    deps = tf.nest.flatten(traj)
  # Test that the result is the reward for the second episode only.
  with tf.control_dependencies(deps):
    result = tf_avg_return_metric.result()
  result_ = self.evaluate(result)
  self.assertEqual(result_, 13)
def testMetricPrefix(self):
  batched_avg_return_metric = batched_py_metric.BatchedPyMetric(
      py_metrics.AverageReturnMetric, prefix='CustomPrefix')
  self.assertEqual(batched_avg_return_metric.prefix, 'CustomPrefix')
  tf_avg_return_metric = tf_py_metric.TFPyMetric(batched_avg_return_metric)
  self.assertEqual(tf_avg_return_metric._prefix, 'CustomPrefix')
def _build_metrics(self, buffer_size=10, batch_size=None):
  python_metrics = [
      tf_py_metric.TFPyMetric(
          py_metrics.AverageReturnMetric(
              buffer_size=buffer_size, batch_size=batch_size)),
      tf_py_metric.TFPyMetric(
          py_metrics.AverageEpisodeLengthMetric(
              buffer_size=buffer_size, batch_size=batch_size)),
  ]
  if batch_size is None:
    batch_size = 1
  tensorflow_metrics = [
      tf_metrics.AverageReturnMetric(
          buffer_size=buffer_size, batch_size=batch_size),
      tf_metrics.AverageEpisodeLengthMetric(
          buffer_size=buffer_size, batch_size=batch_size),
  ]
  return python_metrics, tensorflow_metrics
def testMetricIsComputedCorrectly(self, num_time_steps, expected_reward):
  batched_avg_return_metric = batched_py_metric.BatchedPyMetric(
      py_metrics.AverageReturnMetric)
  tf_avg_return_metric = tf_py_metric.TFPyMetric(batched_avg_return_metric)
  deps = []
  for i in range(num_time_steps):
    with tf.control_dependencies(deps):
      traj = tf_avg_return_metric(self._ts[i])
    deps = tf.nest.flatten(traj)
  with tf.control_dependencies(deps):
    result = tf_avg_return_metric.result()
  result_ = self.evaluate(result)
  self.assertEqual(result_, expected_reward)
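
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original tests): every test above
# follows the same wrap-and-evaluate pattern -- batch a Python metric, wrap it
# as a TF op with TFPyMetric, feed it trajectories under control dependencies,
# then evaluate result(). Assuming the same imports the tests use
# (batched_py_metric, py_metrics, tf_py_metric, tf), the pattern in isolation:
def _example_wrapped_average_return(trajectories):
  """Returns the average-return tensor after processing `trajectories`."""
  metric = tf_py_metric.TFPyMetric(
      batched_py_metric.BatchedPyMetric(py_metrics.AverageReturnMetric))
  deps = []
  for traj in trajectories:
    # Control dependencies sequence the underlying py_func calls in graph
    # mode, exactly as testReset does above.
    with tf.control_dependencies(deps):
      processed_traj = metric(traj)
    deps = tf.nest.flatten(processed_traj)
  with tf.control_dependencies(deps):
    return metric.result()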
def train(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    env_name=None,
    num_parallel_environments=1,  # pylint: disable=unused-argument
    agent_class=None,
    initial_collect_random=True,  # pylint: disable=unused-argument
    initial_collect_driver_class=None,
    collect_driver_class=None,
    num_global_steps=1000000,
    train_steps_per_iteration=1,
    train_metrics=None,
    # Safety Critic training args
    train_sc_steps=10,
    train_sc_interval=300,
    online_critic=False,
    # Params for eval
    run_eval=False,
    num_eval_episodes=30,
    eval_interval=1000,
    eval_metrics_callback=None,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    env_metric_factories=None):  # pylint: disable=unused-argument
  """A simple train and eval for SC-SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  train_metrics = train_metrics or []

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = env_load_fn(env_name)
    if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
      tf_env = tf_py_environment.TFPyEnvironment(tf_env)

    if run_eval:
      eval_py_env = env_load_fn(env_name)
      eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    print('obs spec:', observation_spec)
    print('action spec:', action_spec)

    if online_critic:
      resample_metric = tf_py_metric.TFPyMetric(
          py_metrics.CounterMetric('unsafe_ac_samples'))
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          train_step_counter=global_step,
          resample_metric=resample_metric)
    else:
      tf_agent = agent_class(
          time_step_spec, action_spec, train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    collect_data_spec = tf_agent.collect_data_spec

    logging.info('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, max_length=1000000)
    logging.info('RB capacity: %i', replay_buffer.capacity)
    logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

    agent_observers = [replay_buffer.add_batch]
    if online_critic:
      online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          collect_data_spec, max_length=10000)
      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
          ckpt_dir=online_rb_ckpt_dir,
          max_to_keep=1,
          replay_buffer=online_replay_buffer)
      clear_rb = common.function(online_replay_buffer.clear)
      agent_observers.append(online_replay_buffer.add_batch)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if not online_critic:
      eval_policy = tf_agent.policy
    else:
      eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        time_step_spec, action_spec)
    if not online_critic:
      collect_policy = tf_agent.collect_policy
    else:
      collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    safety_critic_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'safety_critic'),
        safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
        global_step=global_step)
    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer)

    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_pi_ckpt(load_train_dir, tf_agent)  # Loads tf_agent.

    if load_root_dir is None:
      train_checkpointer.initialize_or_restore()
      rb_checkpointer.initialize_or_restore()
      safety_critic_checkpointer.initialize_or_restore()

    collect_driver = collect_driver_class(
        tf_env, collect_policy, observers=agent_observers + train_metrics)

    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

    if not rb_checkpointer.checkpoint_exists:
      logging.info('Performing initial collection ...')
      common.function(
          initial_collect_driver_class(
              tf_env,
              initial_collect_policy,
              observers=agent_observers + train_metrics).run)()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      tf.print(
          replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
          output_stream=logging.info)

    if run_eval:
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='Metrics',
      )
      if eval_metrics_callback is not None:
        eval_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)
      if FLAGS.viz_pm:
        eval_fig_dir = osp.join(eval_dir, 'figs')
        if not tf.io.gfile.isdir(eval_fig_dir):
          tf.io.gfile.makedirs(eval_fig_dir)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...].
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3, num_steps=2).prefetch(3)
    iterator = iter(dataset)
    if online_critic:
      online_dataset = online_replay_buffer.as_dataset(
          num_parallel_calls=3, num_steps=2).prefetch(3)
      online_iterator = iter(online_dataset)

      @common.function
      def critic_train_step():
        """Builds critic training step."""
        experience, buf_info = next(online_iterator)
        if env_name in [
            'IndianWell', 'IndianWell2', 'IndianWell3', 'DrunkSpider',
            'DrunkSpiderShort'
        ]:
          safe_rew = experience.observation['task_agn_rew']
        else:
          safe_rew = agents.process_replay_buffer(
              online_replay_buffer, as_tensor=True)
          safe_rew = tf.gather(safe_rew, tf.squeeze(buf_info.ids), axis=1)
        ret = tf_agent.train_sc(experience, safe_rew)
        clear_rb()
        return ret

    @common.function
    def train_step():
      experience, _ = next(iterator)
      ret = tf_agent.train(experience)
      return ret

    if not early_termination_fn:
      early_termination_fn = lambda: False

    loss_diverged = False
    # How many consecutive steps the loss has been diverged for.
    loss_divergence_counter = 0
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
    if online_critic:
      mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_samples')
      resample_metric.reset()

    while (global_step.numpy() <= num_global_steps and
           not early_termination_fn()):
      # Collect and train.
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      if online_critic:
        mean_resample_ac(resample_metric.result())
        resample_metric.reset()
        if time_step.is_last():
          resample_ac_freq = mean_resample_ac.result()
          mean_resample_ac.reset_states()
          tf.compat.v2.summary.scalar(
              name='unsafe_ac_samples',
              data=resample_ac_freq,
              step=global_step)

      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)

      if online_critic:
        if global_step.numpy() % train_sc_interval == 0:
          for _ in range(train_sc_steps):
            sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()
      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
          total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(), total_loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
        safety_critic_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        if online_critic:
          online_rb_checkpointer.save(global_step=global_step_val)
        rb_checkpointer.save(global_step=global_step_val)

      if run_eval and global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)
        if FLAGS.viz_pm:
          savepath = 'step{}.png'.format(global_step_val)
          savepath = osp.join(eval_fig_dir, savepath)
          misc.record_episode_vis_summary(eval_tf_env, eval_policy, savepath)

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end after the cleanup.
    raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
        total_loss, global_step.numpy()))

  return total_loss
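
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the divergence guard
# inside the training loop above is a consecutive-failure counter. Factored
# out under the same assumptions the loop makes (MAX_LOSS and
# TERMINATE_AFTER_DIVERGED_LOSS_STEPS are module-level constants), one
# iteration of the check reduces to:
def _example_check_loss_diverged(total_loss, divergence_counter):
  """Returns (loss_diverged, updated_counter) for one training iteration."""
  if (math.isnan(total_loss) or math.isinf(total_loss) or
      total_loss > MAX_LOSS):
    divergence_counter += 1
    diverged = divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS
    return diverged, divergence_counter
  # A finite, in-range loss resets the counter.
  return False, 0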
def __init__(self,
             root_dir,
             env_load_fn=suite_gym.load,
             env_name='CartPole-v0',
             num_parallel_environments=1,
             agent_class=None,
             num_eval_episodes=30,
             write_summaries=True,
             summaries_flush_secs=10,
             eval_metrics_callback=None,
             env_metric_factories=None):
  """Evaluate policy checkpoints as they are produced.

  Args:
    root_dir: Main directory for experiment files.
    env_load_fn: Function to load the environment specified by env_name.
    env_name: Name of environment to evaluate in.
    num_parallel_environments: Number of environments to evaluate on in
      parallel.
    agent_class: TFAgent class to instantiate for evaluation.
    num_eval_episodes: Number of episodes to average evaluation over.
    write_summaries: Whether to write summaries to the file system.
    summaries_flush_secs: How frequently to flush summaries (in seconds).
    eval_metrics_callback: A function that will be called with evaluation
      results for every checkpoint.
    env_metric_factories: An iterable of metric factories. Use this for eval
      metrics that need access to the evaluated environment. A metric factory
      is a function that takes an environment and buffer_size as keyword
      arguments and returns an instance of py_metric.

  Raises:
    ValueError: When num_parallel_environments > num_eval_episodes or
      agent_class is not set.
  """
  if not agent_class:
    raise ValueError('The `agent_class` parameter of Evaluator must be set.')
  if num_parallel_environments > num_eval_episodes:
    raise ValueError('num_parallel_environments should not be greater than '
                     'num_eval_episodes')

  self._num_eval_episodes = num_eval_episodes
  self._eval_metrics_callback = eval_metrics_callback
  # Flag that controls the eval cycle. If set, evaluation will exit the eval
  # loop before the max checkpoint number is reached.
  self._terminate_early = False
  # Save root dir to self so derived classes have access to it.
  self._root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(self._root_dir, 'train')
  self._eval_dir = os.path.join(self._root_dir, 'eval')
  self._global_step = tf.compat.v1.train.get_or_create_global_step()
  self._env_name = env_name

  if num_parallel_environments == 1:
    eval_env = env_load_fn(env_name)
  else:
    eval_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * num_parallel_environments)

  if isinstance(eval_env, py_environment.PyEnvironment):
    self._eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)
    self._eval_py_env = eval_env
  else:
    self._eval_tf_env = eval_env
    self._eval_py_env = None  # Can't generically convert to PyEnvironment.

  self._eval_metrics = [
      tf_metrics.AverageReturnMetric(
          buffer_size=self._num_eval_episodes,
          batch_size=self._eval_tf_env.batch_size),
      tf_metrics.AverageEpisodeLengthMetric(
          buffer_size=self._num_eval_episodes,
          batch_size=self._eval_tf_env.batch_size),
  ]
  if env_metric_factories:
    if not self._eval_py_env:
      raise ValueError('The `env_metric_factories` parameter of Evaluator '
                       'can only be used with a PyEnvironment environment.')
    for metric_factory in env_metric_factories:
      py_metric = metric_factory(
          environment=self._eval_py_env, buffer_size=self._num_eval_episodes)
      self._eval_metrics.append(tf_py_metric.TFPyMetric(py_metric))

  if write_summaries:
    self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        self._eval_dir, flush_millis=summaries_flush_secs * 1000)
    self._eval_summary_writer.set_as_default()
  else:
    self._eval_summary_writer = None

  environment_specs.set_observation_spec(self._eval_tf_env.observation_spec())
  environment_specs.set_action_spec(self._eval_tf_env.action_spec())

  # Agent params configured with gin.
  self._agent = agent_class(self._eval_tf_env.time_step_spec(),
                            self._eval_tf_env.action_spec())
  self._eval_policy = greedy_policy.GreedyPolicy(self._agent.policy)
  self._eval_policy.action = common.function(self._eval_policy.action)

  # Run the agent on dummy data to force instantiation of the network. Keras
  # doesn't create variables until you first use the layer. This is needed
  # for checkpoint restoration to work.
  dummy_obs = tensor_spec.sample_spec_nest(
      self._eval_tf_env.observation_spec(),
      outer_dims=(self._eval_tf_env.batch_size,))
  self._eval_policy.action(
      ts.restart(dummy_obs, batch_size=self._eval_tf_env.batch_size),
      self._eval_policy.get_initial_state(self._eval_tf_env.batch_size))

  self._policy_checkpoint = tf.train.Checkpoint(
      policy=self._agent.policy, global_step=self._global_step)
  self._policy_checkpoint_dir = os.path.join(train_dir, 'policy')
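
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): constructing the
# Evaluator mirrors the __init__ signature above. The `agent_class` argument
# is a stand-in; any TFAgent subclass compatible with the checkpointed policy
# would do, and the root_dir below is a hypothetical experiment directory.
def _example_build_evaluator(agent_class):
  """Builds an Evaluator for CartPole with default metrics and summaries."""
  return Evaluator(
      root_dir='/tmp/experiment',  # hypothetical experiment directory
      env_load_fn=suite_gym.load,
      env_name='CartPole-v0',
      agent_class=agent_class,
      num_eval_episodes=30)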
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]
  eval_summary_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Create the environment.
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_py_env = suite_mujoco.load(env_name)

    # Get the data specs from the environment.
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    eval_py_policy = py_tf_policy.PyTFPolicy(
        greedy_policy.GreedyPolicy(tf_agent.policy))

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    collect_policy = tf_agent.collect_policy
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Prepare the replay buffer as a dataset with invalid transitions
    # filtered out.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=5 * batch_size,
        num_steps=2).apply(tf.data.experimental.unbatch()).filter(
            _filter_invalid_transition).batch(batch_size).prefetch(
                batch_size * 5)
    dataset_iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = dataset_iterator.get_next()

    train_op = tf_agent.train(trajectories)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    with tf.compat.v1.Session() as sess:
      # Initialize graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)

      # Initialize training.
      sess.run(dataset_iterator.initializer)
      common.initialize_uninitialized_variables(sess)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      global_step_val = sess.run(global_step)

      if global_step_val == 0:
        # Initial eval of the randomly initialized policy.
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            eval_py_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            callback=eval_metrics_callback,
            log=True,
        )
        sess.run(eval_summary_flush_op)

        # Run initial collect.
        logging.info('Global step %d: Running initial collect op.',
                     global_step_val)
        sess.run(initial_collect_op)

        # Checkpoint the initial replay buffer contents.
        rb_checkpointer.save(global_step=global_step_val)

        logging.info('Finished initial collect.')
      else:
        logging.info('Global step %d: Skipping initial collect op.',
                     global_step_val)

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          total_loss, _ = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       total_loss.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
          sess.run(eval_summary_flush_op)

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)
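
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the dataset pipeline
# inside train_eval oversamples two-step windows, drops windows that start on
# an episode boundary (a (boundary, first) pair is not a trainable
# transition), and rebatches. The same pipeline, isolated; `replay_buffer` is
# any TFUniformReplayBuffer as constructed above:
def _example_transition_dataset(replay_buffer, batch_size):
  """Two-step transition dataset with boundary-starting windows filtered."""

  def _filter_invalid_transition(trajectories, unused_info):
    # Keep the window only if its first step is not an episode boundary.
    return ~trajectories.is_boundary()[0]

  return (replay_buffer.as_dataset(
      sample_batch_size=5 * batch_size,
      num_steps=2).apply(tf.data.experimental.unbatch()).filter(
          _filter_invalid_transition).batch(batch_size).prefetch(
              batch_size * 5))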
def train_eval( ############################################## # types of params: # 0: specific to algorithm (gin file 0) # 1: specific to environment (gin file 1) # 2: specific to experiment (gin file 2 + command line) # Note: there are other important params # in eg ModelDistributionNetwork that the gin files specify # like sparse vs dense rewards, latent dimensions, etc. ############################################## # basic params for running/logging experiment root_dir, # 2 experiment_name, # 2 num_iterations=int(1e7), # 2 seed=1, # 2 gpu_allow_growth=False, # 2 gpu_memory_limit=None, # 2 verbose=True, # 2 policy_checkpoint_freq_in_iter=100, # policies needed for future eval # 2 train_checkpoint_freq_in_iter=0, #default don't save # 2 rb_checkpoint_freq_in_iter=0, #default don't save # 2 logging_freq_in_iter=10, # printing to terminal # 2 summary_freq_in_iter=10, # saving to tb # 2 num_images_per_summary=2, # 2 summaries_flush_secs=10, # 2 max_episode_len_override=None, # 2 num_trials_to_render=1, # 2 # environment, action mode, etc. env_name='HalfCheetah-v2', # 1 action_repeat=1, # 1 action_mode='joint_position', # joint_position or joint_delta_position # 1 double_camera=False, # camera input # 1 universe='gym', # default task_reward_dim=1, # default # dims for all networks actor_fc_layers=(256, 256), # 1 critic_obs_fc_layers=None, # 1 critic_action_fc_layers=None, # 1 critic_joint_fc_layers=(256, 256), # 1 num_repeat_when_concatenate=None, # 1 # networks critic_input='state', # 0 actor_input='state', # 0 # specifying tasks and eval episodes_per_trial=1, # 2 num_train_tasks=10, # 2 num_eval_tasks=10, # 2 num_eval_trials=10, # 2 eval_interval=10, # 2 eval_on_holdout_tasks=True, # 2 # data collection/buffer init_collect_trials_per_task=None, # 2 collect_trials_per_task=None, # 2 num_tasks_to_collect_per_iter=5, # 2 replay_buffer_capacity=int(1e5), # 2 # training init_model_train_ratio=0.8, # 2 model_train_ratio=1, # 2 model_train_freq=1, # 2 ac_train_ratio=1, # 2 ac_train_freq=1, # 2 num_tasks_per_train=5, # 2 train_trials_per_task=5, # 2 model_bs_in_steps=256, # 2 ac_bs_in_steps=128, # 2 # default AC learning rates, gamma, etc. 
target_update_tau=0.005, target_update_period=1, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, model_learning_rate=1e-4, td_errors_loss_fn=functools.partial( tf.compat.v1.losses.mean_squared_error, weights=0.5), gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, log_image_strips=False, stop_model_training=1E10, eval_only=False, # evaluate checkpoints ONLY log_image_observations=False, load_offline_data=False, # whether to use offline data offline_data_dir=None, # replay buffer's dir offline_episode_len=None, # episode len of episodes stored in rb offline_ratio=0, # ratio of data that is from offline buffer ): g = tf.Graph() # register all gym envs max_steps_dict = { "HalfCheetahVel-v0": 50, "SawyerReach-v0": 40, "SawyerReachMT-v0": 40, "SawyerPeg-v0": 40, "SawyerPegMT-v0": 40, "SawyerPegMT4box-v0": 40, "SawyerShelfMT-v0": 40, "SawyerKitchenMT-v0": 40, "SawyerShelfMT-v2": 40, "SawyerButtons-v0": 40, } if max_episode_len_override: max_steps_dict[env_name] = max_episode_len_override register_all_gym_envs(max_steps_dict) # set max_episode_len based on our env max_episode_len = max_steps_dict[env_name] ###################################################### # Calculate additional params ###################################################### # convert to number of steps env_steps_per_trial = episodes_per_trial * max_episode_len real_env_steps_per_trial = episodes_per_trial * (max_episode_len + 1) env_steps_per_iter = num_tasks_to_collect_per_iter * collect_trials_per_task * env_steps_per_trial per_task_collect_steps = collect_trials_per_task * env_steps_per_trial # initial collect + train init_collect_env_steps = num_train_tasks * init_collect_trials_per_task * env_steps_per_trial init_model_train_steps = int(init_collect_env_steps * init_model_train_ratio) # collect + train collect_env_steps_per_iter = num_tasks_to_collect_per_iter * per_task_collect_steps model_train_steps_per_iter = int(env_steps_per_iter * model_train_ratio) ac_train_steps_per_iter = int(env_steps_per_iter * ac_train_ratio) # other global_steps_per_iter = collect_env_steps_per_iter + model_train_steps_per_iter + ac_train_steps_per_iter sample_episodes_per_task = train_trials_per_task * episodes_per_trial # number of episodes to sample from each replay model_bs_in_trials = model_bs_in_steps // real_env_steps_per_trial # assertions that make sure parameters make sense assert model_bs_in_trials > 0, "model batch size need to be at least as big as one full real trial" assert num_tasks_to_collect_per_iter <= num_train_tasks, "when sampling replace=False" assert num_tasks_per_train * train_trials_per_task >= model_bs_in_trials, "not enough data for one batch model train" assert num_tasks_per_train * train_trials_per_task * env_steps_per_trial >= ac_bs_in_steps, "not enough data for one batch ac train" ###################################################### # Print a summary of params ###################################################### MELD_summary_string = f"""\n\n\n ============================================================== ============================================================== \n MELD algorithm summary: * each trial consists of {episodes_per_trial} episodes * episode length: {max_episode_len}, trial length: {env_steps_per_trial} * {num_train_tasks} train tasks, {num_eval_tasks} eval tasks, hold-out: {eval_on_holdout_tasks} * environment: {env_name} For each of {num_train_tasks} tasks: Do {init_collect_trials_per_task} trials of initial collect (total 
{init_collect_env_steps} env steps) Do {init_model_train_steps} steps of initial model training For i in range(inf): For each of {num_tasks_to_collect_per_iter} randomly selected tasks: Do {collect_trials_per_task} trials of collect (which is {collect_trials_per_task*env_steps_per_trial} env steps per task) (for a total of {num_tasks_to_collect_per_iter*collect_trials_per_task*env_steps_per_trial} env steps in the iteration) if i % model_train_freq(={model_train_freq}): Do {model_train_steps_per_iter} steps of model training - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials. - pick randomly {model_bs_in_trials} trials, train model on whole trials. if i % ac_train_freq(={ac_train_freq}): Do {ac_train_steps_per_iter} steps of ac training - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials. - pick randomly {ac_bs_in_steps} transitions, not including between trial transitions, to train ac. * Other important params: Evaluate policy every {eval_interval} iters, equivalent to {global_steps_per_iter*eval_interval/1000:.1f}k global steps Average evaluation across {num_eval_trials} trials Save summary to tensorboard every {summary_freq_in_iter} iters, equivalent to {global_steps_per_iter*summary_freq_in_iter/1000:.1f}k global steps Checkpoint: - training checkpoint every {train_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*train_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint - policy checkpoint every {policy_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*policy_checkpoint_freq_in_iter//1000}k global steps, keep all checkpoints - replay buffer checkpoint every {rb_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*rb_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint \n ============================================================= ============================================================= """ print(MELD_summary_string) time.sleep(1) ###################################################### # Seed + name + GPU configs + directories for saving ###################################################### np.random.seed(int(seed)) experiment_name += "_seed" + str(seed) gpus = tf.config.experimental.list_physical_devices('GPU') if gpu_allow_growth: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpu_memory_limit: for gpu in gpus: tf.config.experimental.set_virtual_device_configuration( gpu, [ tf.config.experimental.VirtualDeviceConfiguration( memory_limit=gpu_memory_limit) ]) train_eval_dir = get_train_eval_dir(root_dir, universe, env_name, experiment_name) train_dir = os.path.join(train_eval_dir, 'train') eval_dir = os.path.join(train_eval_dir, 'eval') eval_dir_2 = os.path.join(train_eval_dir, 'eval2') ###################################################### # Train and Eval Summary Writers ###################################################### train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_summary_flush_op = eval_summary_writer.flush() eval_logger = Logger(eval_dir_2) ###################################################### # Train and Eval 
metrics ###################################################### eval_buffer_size = num_eval_trials * episodes_per_trial * max_episode_len # across all eval trials in each evaluation eval_metrics = [] for position in range( episodes_per_trial ): # have metrics for each episode position, to track whether it is learning eval_metrics_pos = [ py_metrics.AverageReturnMetric(name='c_AverageReturnEval_' + str(position), buffer_size=eval_buffer_size), py_metrics.AverageEpisodeLengthMetric( name='f_AverageEpisodeLengthEval_' + str(position), buffer_size=eval_buffer_size), custom_metrics.AverageScoreMetric( name="d_AverageScoreMetricEval_" + str(position), buffer_size=eval_buffer_size), ] eval_metrics.extend(eval_metrics_pos) train_buffer_size = num_train_tasks * episodes_per_trial train_metrics = [ tf_metrics.NumberOfEpisodes(name='NumberOfEpisodes'), tf_metrics.EnvironmentSteps(name='EnvironmentSteps'), tf_py_metric.TFPyMetric( py_metrics.AverageReturnMetric(name="a_AverageReturnTrain", buffer_size=train_buffer_size)), tf_py_metric.TFPyMetric( py_metrics.AverageEpisodeLengthMetric( name="e_AverageEpisodeLengthTrain", buffer_size=train_buffer_size)), tf_py_metric.TFPyMetric( custom_metrics.AverageScoreMetric(name="b_AverageScoreTrain", buffer_size=train_buffer_size)), ] global_step = tf.compat.v1.train.get_or_create_global_step( ) # will be use to record number of model grad steps + ac grad steps + env_step log_cond = get_log_condition_tensor( global_step, init_collect_trials_per_task, env_steps_per_trial, num_train_tasks, init_model_train_steps, collect_trials_per_task, num_tasks_to_collect_per_iter, model_train_steps_per_iter, ac_train_steps_per_iter, summary_freq_in_iter, eval_interval) with tf.compat.v2.summary.record_if(log_cond): ###################################################### # Create env ###################################################### py_env, eval_py_env, train_tasks, eval_tasks = load_environments( universe, action_mode, env_name=env_name, observations_whitelist=['state', 'pixels', "env_info"], action_repeat=action_repeat, num_train_tasks=num_train_tasks, num_eval_tasks=num_eval_tasks, eval_on_holdout_tasks=eval_on_holdout_tasks, return_multiple_tasks=True, ) override_reward_func = None if load_offline_data: py_env.set_task_dict(train_tasks) override_reward_func = py_env.override_reward_func tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True) # Get data specs from env time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() original_control_timestep = get_control_timestep(eval_py_env) # fps control_timestep = original_control_timestep * float(action_repeat) render_fps = int(np.round(1.0 / original_control_timestep)) ###################################################### # Latent variable model ###################################################### if verbose: print("-- start constructing model networks --") model_net = ModelDistributionNetwork( double_camera=double_camera, observation_spec=observation_spec, num_repeat_when_concatenate=num_repeat_when_concatenate, task_reward_dim=task_reward_dim, episodes_per_trial=episodes_per_trial, max_episode_len=max_episode_len ) # rest of arguments provided via gin if verbose: print("-- finish constructing AC networks --") ###################################################### # Compressor Network for Actor/Critic # The model's compressor is also used by the AC # compressor function: images --> features ###################################################### 
compressor_net = model_net.compressor ###################################################### # Specs for Actor and Critic ###################################################### if actor_input == 'state': actor_state_size = observation_spec['state'].shape[0] elif actor_input == 'latentSample': actor_state_size = model_net.state_size elif actor_input == "latentDistribution": actor_state_size = 2 * model_net.state_size # mean and (diagonal) variance of gaussian, of two latents else: raise NotImplementedError actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ), dtype=tf.float32) if critic_input == 'state': critic_state_size = observation_spec['state'].shape[0] elif critic_input == 'latentSample': critic_state_size = model_net.state_size elif critic_input == "latentDistribution": critic_state_size = 2 * model_net.state_size # mean and (diagonal) variance of gaussian, of two latents else: raise NotImplementedError critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ), dtype=tf.float32) ###################################################### # Actor and Critic Networks ###################################################### if verbose: print("-- start constructing Actor and Critic networks --") actor_net = actor_distribution_network.ActorDistributionNetwork( actor_input_spec, action_spec, fc_layer_params=actor_fc_layers, ) critic_net = critic_network.CriticNetwork( (critic_input_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) if verbose: print("-- finish constructing AC networks --") print("-- start constructing agent --") ###################################################### # Create the agent ###################################################### which_posterior_overwrite = None which_reward_overwrite = None meld_agent = MeldAgent( # specs time_step_spec=time_step_spec, action_spec=action_spec, # step counter train_step_counter= global_step, # will count number of model training steps # networks actor_network=actor_net, critic_network=critic_net, model_network=model_net, compressor_network=compressor_net, # optimizers actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), model_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=model_learning_rate), # target update target_update_tau=target_update_tau, target_update_period=target_update_period, # inputs critic_input=critic_input, actor_input=actor_input, # bs stuff model_batch_size=model_bs_in_steps, ac_batch_size=ac_bs_in_steps, # other num_tasks_per_train=num_tasks_per_train, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, control_timestep=control_timestep, num_images_per_summary=num_images_per_summary, task_reward_dim=task_reward_dim, episodes_per_trial=episodes_per_trial, # offline data override_reward_func=override_reward_func, offline_ratio=offline_ratio, ) if verbose: print("-- finish constructing agent --") ###################################################### # Replay buffers + observers to add data to them ###################################################### replay_buffers = [] replay_observers = [] for _ in range(num_train_tasks): replay_buffer_episodic = episodic_replay_buffer.EpisodicReplayBuffer( 
meld_agent.collect_policy. trajectory_spec, # spec of each point stored in here (i.e. Trajectory) capacity=replay_buffer_capacity, completed_only= True, # in as_dataset, if num_steps is None, this means return full episodes # device='GPU:0', # gpu not supported for some reason begin_episode_fn=lambda traj: traj.is_first()[ 0], # first step of seq we add should be is_first end_episode_fn=lambda traj: traj.is_last()[ 0], # last step of seq we add should be is_last dataset_drop_remainder= True, #`as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items ) replay_buffer = StatefulEpisodicReplayBuffer( replay_buffer_episodic) # adding num_episodes here is bad replay_buffers.append(replay_buffer) replay_observers.append([replay_buffer.add_sequence]) if load_offline_data: # for each task, has a separate replay buffer for relabeled data replay_buffers_withRelabel = [] replay_observers_withRelabel = [] for _ in range(num_train_tasks): replay_buffer_episodic_withRelabel = episodic_replay_buffer.EpisodicReplayBuffer( meld_agent.collect_policy. trajectory_spec, # spec of each point stored in here (i.e. Trajectory) capacity=replay_buffer_capacity, completed_only= True, # in as_dataset, if num_steps is None, this means return full episodes # device='GPU:0', # gpu not supported for some reason begin_episode_fn=lambda traj: traj.is_first()[ 0], # first step of seq we add should be is_first end_episode_fn=lambda traj: traj.is_last()[ 0], # last step of seq we add should be is_last dataset_drop_remainder=True, # `as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items ) replay_buffer_withRelabel = StatefulEpisodicReplayBuffer( replay_buffer_episodic_withRelabel ) # adding num_episodes here is bad replay_buffers_withRelabel.append(replay_buffer_withRelabel) replay_observers_withRelabel.append( [replay_buffer_withRelabel.add_sequence]) if verbose: print("-- finish constructing replay buffers --") print("-- start constructing policies and collect ops --") ###################################################### # Policies ##################################################### # init collect policy (random) init_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) # eval eval_py_policy = py_tf_policy.PyTFPolicy(meld_agent.policy) ################################################################################ # Collect ops : use policies to get data + have the observer put data into corresponding RB ################################################################################ #init collection (with random policy) init_collect_ops = [] for task_idx in range(num_train_tasks): # put init data into the rb + track with the train metric observers = replay_observers[task_idx] + train_metrics # initial collect op init_collect_op = DynamicTrialDriver( tf_env, init_collect_policy, num_trials_to_collect=init_collect_trials_per_task, observers=observers, episodes_per_trial= episodes_per_trial, # policy state will not be reset within these episodes max_episode_len=max_episode_len, ).run() # collect one trial init_collect_ops.append(init_collect_op) # data collection for training (with collect policy) collect_ops = [] for task_idx in range(num_train_tasks): collect_op = DynamicTrialDriver( tf_env, meld_agent.collect_policy, num_trials_to_collect=collect_trials_per_task, observers=replay_observers[task_idx] + train_metrics, # put data into 1st RB + track with 1st pol metrics episodes_per_trial= 
episodes_per_trial, # policy state will not be reset within these episodes max_episode_len=max_episode_len, ).run() # collect one trial collect_ops.append(collect_op) if verbose: print("-- finish constructing policies and collect ops --") print("-- start constructing replay buffer->training pipeline --") ###################################################### # replay buffer --> dataset --> iterate to get trajecs for training ###################################################### # get some data from all task replay buffers (even though won't actually train on all of them) dataset_iterators = [] all_tasks_trajectories_fromdense = [] for task_idx in range(num_train_tasks): dataset = replay_buffers[task_idx].as_dataset( sample_batch_size= sample_episodes_per_task, # number of episodes to sample num_steps=max_episode_len + 1 ).prefetch( 3 ) # +1 to include the last state: a trajectory with n transition has n+1 states # iterator to go through the data dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) dataset_iterators.append(dataset_iterator) # get sample_episodes_per_task sequences, each of length num_steps trajectories_task_i, _ = dataset_iterator.get_next() all_tasks_trajectories_fromdense.append(trajectories_task_i) if load_offline_data: # have separate dataset for relabel data dataset_iterators_withRelabel = [] all_tasks_trajectories_fromdense_withRelabel = [] for task_idx in range(num_train_tasks): dataset = replay_buffers_withRelabel[task_idx].as_dataset( sample_batch_size= sample_episodes_per_task, # number of episodes to sample num_steps=offline_episode_len + 1 ).prefetch( 3 ) # +1 to include the last state: a trajectory with n transition has n+1 states # iterator to go through the data dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) dataset_iterators_withRelabel.append(dataset_iterator) # get sample_episodes_per_task sequences, each of length num_steps trajectories_task_i, _ = dataset_iterator.get_next() all_tasks_trajectories_fromdense_withRelabel.append( trajectories_task_i) if verbose: print("-- finish constructing replay buffer->training pipeline --") print("-- start constructing model and AC training ops --") ###################################### # Decoding latent samples into rewards ###################################### latent_samples_1_ph = tf.compat.v1.placeholder( dtype=tf.float32, shape=(None, None, meld_agent._model_network.latent1_size)) latent_samples_2_ph = tf.compat.v1.placeholder( dtype=tf.float32, shape=(None, None, meld_agent._model_network.latent2_size)) decode_rews_op = meld_agent._model_network.decode_latents_into_reward( latent_samples_1_ph, latent_samples_2_ph) ###################################### # Model/Actor/Critic train + summary ops ###################################### # train AC on data from replay buffer if load_offline_data: ac_train_op = meld_agent.train_ac_meld( all_tasks_trajectories_fromdense, all_tasks_trajectories_fromdense_withRelabel) else: ac_train_op = meld_agent.train_ac_meld( all_tasks_trajectories_fromdense) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) if verbose: print("-- finish constructing AC training ops --") ############################ # Model train + summary ops ############################ # train model on data from replay buffer if load_offline_data: model_train_op, check_step_types = meld_agent.train_model_meld( all_tasks_trajectories_fromdense, 
all_tasks_trajectories_fromdense_withRelabel) else: model_train_op, check_step_types = meld_agent.train_model_meld( all_tasks_trajectories_fromdense) model_summary_ops, model_summary_ops_2 = [], [] for summary_op in tf.compat.v1.summary.all_v2_summary_ops(): if summary_op not in summary_ops: model_summary_ops.append(summary_op) if verbose: print("-- finish constructing model training ops --") print("-- start constructing checkpointers --") ######################## # Eval metrics ######################## with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) ######################## # Create savers ######################## train_config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=False) eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir, summarize_config=False) ######################## # Create checkpointers ######################## train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=meld_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), max_to_keep=1) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=meld_agent.policy, global_step=global_step, max_to_keep=99999999999 ) # keep many policy checkpoints, in case of future eval rb_checkpointers = [] for buffer_idx in range(len(replay_buffers)): rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffers/', "task" + str(buffer_idx)), max_to_keep=1, replay_buffer=replay_buffers[buffer_idx]) rb_checkpointers.append(rb_checkpointer) if load_offline_data: # for LOADING data not for checkpointing. No new data going in anyways rb_checkpointers_withRelabel = [] for buffer_idx in range(len(replay_buffers_withRelabel)): ckpt_dir = os.path.join(offline_data_dir, "task" + str(buffer_idx)) rb_checkpointer = common.Checkpointer( ckpt_dir=ckpt_dir, max_to_keep=99999999999, replay_buffer=replay_buffers_withRelabel[buffer_idx]) rb_checkpointers_withRelabel.append(rb_checkpointer) # Notice: these replay buffers need to follow the same sequence of tasks as the current one if verbose: print("-- finish constructing checkpointers --") print("-- start main training loop --") with tf.compat.v1.Session() as sess: ######################## # Initialize ######################## if eval_only: sess.run(eval_summary_writer.init()) load_eval_log( train_eval_dir=train_eval_dir, meld_agent=meld_agent, global_step=global_step, sess=sess, eval_metrics=eval_metrics, eval_py_env=eval_py_env, eval_py_policy=eval_py_policy, num_eval_trials=num_eval_trials, max_episode_len=max_episode_len, episodes_per_trial=episodes_per_trial, log_image_strips=log_image_strips, num_trials_to_render=num_trials_to_render, train_tasks= train_tasks, # in case want to eval on a train task eval_tasks=eval_tasks, model_net=model_net, render_fps=render_fps, decode_rews_op=decode_rews_op, latent_samples_1_ph=latent_samples_1_ph, latent_samples_2_ph=latent_samples_2_ph, ) return # Initialize checkpointing train_checkpointer.initialize_or_restore(sess) for rb_checkpointer in rb_checkpointers: rb_checkpointer.initialize_or_restore(sess) if load_offline_data: for rb_checkpointer in rb_checkpointers_withRelabel: rb_checkpointer.initialize_or_restore(sess) # Initialize dataset iterators for dataset_iterator in dataset_iterators: sess.run(dataset_iterator.initializer) if load_offline_data: for dataset_iterator in 
dataset_iterators_withRelabel:
  sess.run(dataset_iterator.initializer)

# Initialize variables
common.initialize_uninitialized_variables(sess)

# Initialize summary writers
sess.run(train_summary_writer.init())
sess.run(eval_summary_writer.init())

# Initialize savers
train_config_saver.after_create_session(sess)
eval_config_saver.after_create_session(sess)

# Get value of step counter
global_step_val = sess.run(global_step)

if verbose:
  print("====== finished initialization ======")

################################################################
# If this is start of new exp (i.e., 1st step) and not continuing old exp,
# eval rand policy + do initial data collection
################################################################
fresh_start = (global_step_val == 0)
if fresh_start:
  ########################
  # Evaluate initial policy
  ########################
  if eval_interval:
    logging.info(
        '\n\nDoing evaluation of initial policy on %d trials with randomly '
        'sampled tasks', num_eval_trials)
    perform_eval_and_summaries_meld(
        eval_metrics,
        eval_py_env,
        eval_py_policy,
        num_eval_trials,
        max_episode_len,
        episodes_per_trial,
        log_image_strips=log_image_strips,
        num_trials_to_render=num_eval_tasks,
        eval_tasks=eval_tasks,
        latent1_size=model_net.latent1_size,
        latent2_size=model_net.latent2_size,
        logger=eval_logger,
        global_step_val=global_step_val,
        render_fps=render_fps,
        decode_rews_op=decode_rews_op,
        latent_samples_1_ph=latent_samples_1_ph,
        latent_samples_2_ph=latent_samples_2_ph,
        log_image_observations=log_image_observations,
    )
    sess.run(eval_summary_flush_op)
    logging.info('Done with evaluation of initial (random) policy.\n\n')

  ########################
  # Initial data collection
  ########################
  logging.info(
      '\n\nGlobal step %d: Beginning init collect op with random policy. '
      'Collecting %dx {%d, %d} trials for each task', global_step_val,
      init_collect_trials_per_task, max_episode_len, episodes_per_trial)
  init_increment_global_step_op = global_step.assign_add(
      env_steps_per_trial * init_collect_trials_per_task)
  for task_idx in range(num_train_tasks):
    logging.info('on task %d / %d', task_idx + 1, num_train_tasks)
    py_env.set_task_for_env(train_tasks[task_idx])
    # global step is incremented at the granularity of a task
    sess.run([init_collect_ops[task_idx], init_increment_global_step_op])
  rb_checkpointer.save(global_step=global_step_val)
  logging.info('Finished init collect.\n\n')
else:
  logging.info(
      '\n\nGlobal step %d from loaded experiment: Skipping init collect op.\n\n',
      global_step_val)

#########################
# Create calls
#########################
# [1] calls for running the policies to collect training data
collect_calls = []
increment_global_step_op = global_step.assign_add(
    env_steps_per_trial * collect_trials_per_task)
for task_idx in range(num_train_tasks):
  collect_calls.append(
      sess.make_callable([collect_ops[task_idx], increment_global_step_op]))

# [2] call for doing a training step (A + C)
ac_train_step_call = sess.make_callable([ac_train_op, summary_ops])

# [3] call for doing a training step (model)
model_train_step_call = sess.make_callable(
    [model_train_op, check_step_types, model_summary_ops])

# [4] call for evaluating what global_step number we're on
global_step_call = sess.make_callable(global_step)

# reset keeping track of steps/time
timed_at_step = global_step_call()
time_acc = 0
steps_per_second_ph = tf.compat.v1.placeholder(
    tf.float32, shape=(), name='steps_per_sec_ph')
with train_summary_writer.as_default(), tf.compat.v2.summary.record_if(True):
  steps_per_second_summary = tf.compat.v2.summary.scalar(
      name='global_steps_per_sec', data=steps_per_second_ph, step=global_step)

#################################
# init model training
#################################
if fresh_start:
  logging.info(
      '\n\nPerforming %d steps of init model training, each step on %d '
      'random tasks', init_model_train_steps, num_tasks_per_train)
  for i in range(init_model_train_steps):
    temp_start = time.time()
    if i % 100 == 0:
      print(".... init model training ", i, "/", init_model_train_steps)
    # init model training
    total_loss_value_model, check_step_types, _ = model_train_step_call()
    if PRINT_TIMING:
      print("single model train step: ", time.time() - temp_start)

if verbose:
  print("\n\n\n-- start training loop --\n")

#################################
# Training Loop
#################################
for iteration in range(num_iterations):
  # start timing this iteration; time_acc accumulates per-iteration work and
  # is reset whenever steps/sec is logged below
  start_time = time.time()
  if iteration > 0:
    g.finalize()
  global_step_val = global_step_call()
  logging.info("Iteration: %d, Global step: %d\n", iteration, global_step_val)

  ####################
  # collect data
  ####################
  logging.info(
      '\nStarting batch data collection. Collecting %d {%d, %d} trials for '
      'each of %d tasks', collect_trials_per_task, max_episode_len,
      episodes_per_trial, num_tasks_to_collect_per_iter)
  # randomly select tasks to collect this iteration
  list_of_collect_task_idxs = np.random.choice(
      len(train_tasks), num_tasks_to_collect_per_iter, replace=False)
  for count, task_idx in enumerate(list_of_collect_task_idxs):
    logging.info('on randomly selected task %d / %d', count + 1,
                 num_tasks_to_collect_per_iter)
    # set task for the env
    py_env.set_task_for_env(train_tasks[task_idx])
    # collect data with collect policy
    _, policy_state_val = collect_calls[task_idx]()
  logging.info('Finish data collection. Global step: %d\n', global_step_call())

  ####################
  # train model
  ####################
  if (iteration == 0) or ((iteration % model_train_freq == 0) and
                          (global_step_val < stop_model_training)):
    logging.info(
        '\n\nPerforming %d steps of model training, each on %d random tasks',
        model_train_steps_per_iter, num_tasks_per_train)
    for model_iter in range(model_train_steps_per_iter):
      temp_start_2 = time.time()
      # train model
      total_loss_value_model, _, _ = model_train_step_call()
      if PRINT_TIMING:
        print("2: single model train step: ", time.time() - temp_start_2)
    logging.info('Finish model training. Global step: %d\n',
                 global_step_call())
  else:
    print("SKIPPING MODEL TRAINING")

  ####################
  # train actor critic
  ####################
  if iteration % ac_train_freq == 0:
    logging.info(
        '\n\nPerforming %d steps of AC training, each on %d random tasks \n\n',
        ac_train_steps_per_iter, num_tasks_per_train)
    for ac_iter in range(ac_train_steps_per_iter):
      temp_start_2_ac = time.time()
      # train ac
      total_loss_value_ac, _ = ac_train_step_call()
      if PRINT_TIMING:
        print("2: single AC train step: ", time.time() - temp_start_2_ac)
    logging.info('Finish AC training. Global step: %d\n', global_step_call())

  # add up time spent in this iteration
  time_acc += time.time() - start_time

  ####################
  # logging/summaries
  ####################
  ### Eval
  if eval_interval and (iteration % eval_interval == 0):
    logging.info(
        '\n\nDoing evaluation of trained policy on %d trials with randomly '
        'sampled tasks', num_eval_trials)
    perform_eval_and_summaries_meld(
        eval_metrics,
        eval_py_env,
        eval_py_policy,
        num_eval_trials,
        max_episode_len,
        episodes_per_trial,
        log_image_strips=log_image_strips,
        num_trials_to_render=num_trials_to_render,  # hardcoded: or gif will get too long
        eval_tasks=eval_tasks,
        latent1_size=model_net.latent1_size,
        latent2_size=model_net.latent2_size,
        logger=eval_logger,
        global_step_val=global_step_call(),
        render_fps=render_fps,
        decode_rews_op=decode_rews_op,
        latent_samples_1_ph=latent_samples_1_ph,
        latent_samples_2_ph=latent_samples_2_ph,
        log_image_observations=log_image_observations,
    )

  ### steps_per_second_summary
  global_step_val = global_step_call()
  if logging_freq_in_iter and (iteration % logging_freq_in_iter == 0):
    # log step number + speed (steps/sec)
    logging.info('step = %d, loss = %f', global_step_val,
                 total_loss_value_ac.loss + total_loss_value_model.loss)
    steps_per_sec = (global_step_val - timed_at_step) / time_acc
    logging.info('%.3f env_steps/sec', steps_per_sec)
    sess.run(steps_per_second_summary,
             feed_dict={steps_per_second_ph: steps_per_sec})
    # reset keeping track of steps/time
    timed_at_step = global_step_val
    time_acc = 0

  ### train_checkpoint
  if train_checkpoint_freq_in_iter and (
      iteration % train_checkpoint_freq_in_iter == 0):
    train_checkpointer.save(global_step=global_step_val)

  ### policy_checkpointer
  if policy_checkpoint_freq_in_iter and (
      iteration % policy_checkpoint_freq_in_iter == 0):
    policy_checkpointer.save(global_step=global_step_val)

  ### rb_checkpointer
  if rb_checkpoint_freq_in_iter and (
      iteration % rb_checkpoint_freq_in_iter == 0):
    for rb_checkpointer in rb_checkpointers:
      rb_checkpointer.save(global_step=global_step_val)
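# Illustrative sketch (not part of the original script): the loop above
# pre-compiles its session fetches with Session.make_callable, which avoids
# the per-call fetch-parsing overhead of sess.run in hot loops. The toy
# increment op and all names below are hypothetical stand-ins for the real
# collect/train fetches.
def _make_callable_sketch():
  import tensorflow as tf
  tf.compat.v1.disable_eager_execution()
  step = tf.compat.v1.train.get_or_create_global_step()
  increment_op = step.assign_add(1)
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    step_call = sess.make_callable(increment_op)  # compiled once, reused per step
    for _ in range(3):
      step_call()  # cheaper than sess.run(increment_op) on every iteration
    return sess.run(step)  # -> 3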
def train_eval(
    load_root_dir,
    env_load_fn=None,
    gym_env_wrappers=(),
    monitor=False,
    env_name=None,
    agent_class=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    safety_critic_joint_fc_layers=None,
    safety_critic_lr=3e-4,
    safety_critic_bias_init_val=None,
    safety_critic_kernel_scale=None,
    n_envs=None,
    target_safety=0.2,
    fail_weight=None,
    # Params for train
    num_global_steps=10000,
    batch_size=256,
    # Params for eval
    run_eval=False,
    eval_metrics=(),
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    debug_summaries=False,
    seed=None):
  if isinstance(agent_class, str):
    assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
        agent_class)
    agent_class = ALGOS.get(agent_class)

  train_ckpt_dir = osp.join(load_root_dir, 'train')
  rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

  py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
  tf_env = tf_py_environment.TFPyEnvironment(py_env)

  if monitor:
    vid_path = os.path.join(load_root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  if run_eval:
    eval_dir = os.path.join(load_root_dir, 'eval')
    n_envs = n_envs or num_eval_episodes
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                       buffer_size=num_eval_episodes,
                                       batch_size=n_envs),
        tf_metrics.AverageEpisodeLengthMetric(prefix='EvalMetrics',
                                              buffer_size=num_eval_episodes,
                                              batch_size=n_envs)
    ] + [
        tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
        for m in eval_metrics
    ]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)]
            * n_envs))
    if seed:
      seeds = [seed * n_envs + i for i in range(n_envs)]
      try:
        eval_tf_env.pyenv.seed(seeds)
      except Exception:
        pass

  global_step = tf.compat.v1.train.get_or_create_global_step()

  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=agents.normal_projection_net)
  critic_net = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=critic_joint_fc_layers)

  if agent_class in SAFETY_AGENTS:
    safety_critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)
    tf_agent = agent_class(time_step_spec, action_spec,
                           actor_network=actor_net,
                           critic_network=critic_net,
                           safety_critic_network=safety_critic_net,
                           train_step_counter=global_step,
                           debug_summaries=False)
  else:
    tf_agent = agent_class(time_step_spec, action_spec,
                           actor_network=actor_net,
                           critic_network=critic_net,
                           train_step_counter=global_step,
                           debug_summaries=False)

  collect_data_spec = tf_agent.collect_data_spec
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec, batch_size=1, max_length=1000000)
  replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

  tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
  if agent_class in SAFETY_AGENTS:
    target_safety = target_safety or tf_agent._target_safety
  loaded_train_steps = global_step.numpy()
  logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
               loaded_train_steps)
  global_step.assign(0)
  tf.summary.experimental.set_step(global_step)

  thresholds = [target_safety, 0.5]
  sc_metrics = [
      tf.keras.metrics.AUC(name='safety_critic_auc'),
      tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                      threshold=0.5),
      tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                     thresholds=thresholds),
      tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                      thresholds=thresholds),
      tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                     thresholds=thresholds),
      tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                      thresholds=thresholds)
  ]

  if seed:
    tf.compat.v1.set_random_seed(seed)

  summaries_flush_secs = 10
  timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
  offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
  config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                           summarize_config=True)
  tf.function(config_saver.after_create_session)()

  sc_summary_writer = tf.compat.v2.summary.create_file_writer(
      offline_train_dir, flush_millis=summaries_flush_secs * 1000)
  sc_summary_writer.set_as_default()

  if safety_critic_kernel_scale is not None:
    ki = tf.compat.v1.variance_scaling_initializer(
        scale=safety_critic_kernel_scale, mode='fan_in',
        distribution='truncated_normal')
  else:
    ki = tf.compat.v1.keras.initializers.VarianceScaling(
        scale=1. / 3., mode='fan_in', distribution='uniform')

  if safety_critic_bias_init_val is not None:
    bi = tf.constant_initializer(safety_critic_bias_init_val)
  else:
    bi = None
  sc_net_off = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=safety_critic_joint_fc_layers,
      kernel_initializer=ki,
      value_bias_initializer=bi,
      name='SafetyCriticOffline')
  sc_net_off.create_variables()
  target_sc_net_off = common.maybe_copy_target_network_with_checks(
      sc_net_off, None, 'TargetSafetyCriticNetwork')
  optimizer = tf.keras.optimizers.Adam(safety_critic_lr)

  sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
  sc_checkpointer = common.Checkpointer(
      ckpt_dir=sc_net_off_ckpt_dir,
      safety_critic=sc_net_off,
      target_safety_critic=target_sc_net_off,
      optimizer=optimizer,
      global_step=global_step,
      max_to_keep=5)
  sc_checkpointer.initialize_or_restore()

  resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
  eval_policy = agents.SafeActorPolicyRSVar(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=actor_net,
      safety_critic_network=sc_net_off,
      safety_threshold=target_safety,
      resample_counter=resample_counter,
      training=True)

  dataset = replay_buffer.as_dataset(
      num_parallel_calls=3, num_steps=2,
      sample_batch_size=batch_size // 2).prefetch(3)
  data = iter(dataset)
  full_data = replay_buffer.gather_all()

  fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
  fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
  init_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
  before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
  after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
  before_fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
  after_init_step = nest_utils.fast_map_structure(
      lambda *x:
tf.boolean_mask(*x, after_init_mask), full_data) filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask)) filter_mask = tf.pad( filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]]) n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy() failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=n_failures, dataset_window_shift=1) data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask) sc_dataset_neg = failure_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size // 2, num_steps=2).prefetch(3) neg_data = iter(sc_dataset_neg) get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0] eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step, after_init_step, get_action) losses = [] mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss') target_update = train_utils.get_target_updater(sc_net_off, target_sc_net_off) with tf.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): while global_step.numpy() < num_global_steps: pos_experience, _ = next(data) neg_experience, _ = next(neg_data) exp = data_utils.concat_batches(pos_experience, neg_experience, collect_data_spec) boundary_mask = tf.logical_not(exp.is_boundary()[:, 0]) exp = nest_utils.fast_map_structure( lambda *x: tf.boolean_mask(*x, boundary_mask), exp) safe_rew = exp.observation['task_agn_rew'][:, 1] if fail_weight: weights = tf.where(tf.cast(safe_rew, tf.bool), fail_weight / 0.5, (1 - fail_weight) / 0.5) else: weights = None train_loss, sc_loss, lam_loss = train_step( exp, safe_rew, tf_agent, sc_net=sc_net_off, target_sc_net=target_sc_net_off, metrics=sc_metrics, weights=weights, target_safety=target_safety, optimizer=optimizer, target_update=target_update, debug_summaries=debug_summaries) global_step.assign_add(1) global_step_val = global_step.numpy() losses.append( (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy())) mean_loss(train_loss) with tf.name_scope('Losses'): tf.compat.v2.summary.scalar(name='sc_loss', data=sc_loss, step=global_step_val) tf.compat.v2.summary.scalar(name='lam_loss', data=lam_loss, step=global_step_val) if global_step_val % summary_interval == 0: tf.compat.v2.summary.scalar(name=mean_loss.name, data=mean_loss.result(), step=global_step_val) if global_step_val % summary_interval == 0: with tf.name_scope('Metrics'): for metric in sc_metrics: if len(tf.squeeze(metric.result()).shape) == 0: tf.compat.v2.summary.scalar(name=metric.name, data=metric.result(), step=global_step_val) else: fmt_str = '_{}'.format(thresholds[0]) tf.compat.v2.summary.scalar( name=metric.name + fmt_str, data=metric.result()[0], step=global_step_val) fmt_str = '_{}'.format(thresholds[1]) tf.compat.v2.summary.scalar( name=metric.name + fmt_str, data=metric.result()[1], step=global_step_val) metric.reset_states() if global_step_val % eval_interval == 0: eval_sc(sc_net_off, step=global_step_val) if run_eval: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='EvalMetrics', ) if train_metrics_callback is not None: train_metrics_callback(results, global_step_val) metric_utils.log_metrics(eval_metrics) with eval_summary_writer.as_default(): for eval_metric in eval_metrics[2:]: eval_metric.tf_summaries( train_step=global_step, step_metrics=eval_metrics[:2]) if monitor and global_step_val % monitor_interval == 0: monitor_time_step = monitor_py_env.reset() 
monitor_policy_state = eval_policy.get_initial_state(1) ep_len = 0 monitor_start = time.time() while not monitor_time_step.is_last(): monitor_action = eval_policy.action( monitor_time_step, monitor_policy_state) action, monitor_policy_state = monitor_action.action, monitor_action.state monitor_time_step = monitor_py_env.step(action) ep_len += 1 logging.debug( 'saved rollout at timestep %d, rollout length: %d, %4.2f sec', global_step_val, ep_len, time.time() - monitor_start) if global_step_val % train_checkpoint_interval == 0: sc_checkpointer.save(global_step=global_step_val)
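# Illustrative sketch (not part of the original code): how the tf.roll call
# above converts a per-step failure mask into a mask over the transitions
# that immediately precede each failure. Toy tensors, shape [batch=1, time=6].
def _before_fail_mask_sketch():
  import tensorflow as tf
  fail_mask = tf.constant([[False, False, False, True, False, False]])
  # Shifting by -1 along the time axis marks t=2, the step *before* the
  # failure at t=3.
  return tf.roll(fail_mask, shift=[-1], axis=[1])
  # -> [[False, False, True, False, False, False]]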
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function(initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] 
dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): experience, _ = next(iterator) train_loss = tf_agent.train(experience) time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss.loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
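# Illustrative sketch (not from the original script): with num_steps=2,
# replay_buffer.as_dataset yields adjacent (t, t+1) pairs shaped
# [batch, 2, ...], the layout tf_agent.train expects for one-step TD updates.
# Uses a toy spec, not the agent's real collect_data_spec.
def _two_step_sampling_sketch():
  import tensorflow as tf
  from tf_agents.replay_buffers import tf_uniform_replay_buffer
  from tf_agents.specs import tensor_spec
  spec = tensor_spec.TensorSpec([3], tf.float32, name='obs')
  rb = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=100)
  for i in range(4):  # a few dummy frames
    rb.add_batch(tf.constant([[float(i)] * 3]))
  item, _ = next(iter(rb.as_dataset(sample_batch_size=2, num_steps=2)))
  return item.shape  # TensorShape([2, 2, 3]): [batch, time, feature]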
def train_eval( root_dir, experiment_name, train_eval_dir=None, universe='gym', env_name='HalfCheetah-v2', domain_name='cheetah', task_name='run', action_repeat=1, num_iterations=int(1e7), actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), model_network_ctor=model_distribution_network.ModelDistributionNetwork, critic_input='state', actor_input='state', compressor_descriptor='preprocessor_32_3', # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=int(1e5), # increase if necessary since buffers with images are huge # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, model_train_steps_per_iteration=1, initial_model_train_steps=100000, batch_size=256, model_batch_size=32, sequence_length=4, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, model_learning_rate=1e-4, td_errors_loss_fn=functools.partial( tf.compat.v1.losses.mean_squared_error, weights=0.5), gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for summaries and logging num_images_per_summary=1, train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=0, # enable if necessary since buffers with images are huge log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, gpu_allow_growth=False, gpu_memory_limit=None): """A simple train and eval for SLAC.""" gpus = tf.config.experimental.list_physical_devices('GPU') if gpu_allow_growth: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpu_memory_limit: for gpu in gpus: tf.config.experimental.set_virtual_device_configuration( gpu, [ tf.config.experimental.VirtualDeviceConfiguration( memory_limit=gpu_memory_limit) ]) if train_eval_dir is None: train_eval_dir = get_train_eval_dir(root_dir, universe, env_name, domain_name, task_name, experiment_name) train_dir = os.path.join(train_eval_dir, 'train') eval_dir = os.path.join(train_eval_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric( name='AverageEpisodeLengthEvalPolicy', buffer_size=num_eval_episodes), ] eval_greedy_metrics = [ py_metrics.AverageReturnMetric(name='AverageReturnEvalGreedyPolicy', buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric( name='AverageEpisodeLengthEvalGreedyPolicy', buffer_size=num_eval_episodes), ] eval_summary_flush_op = eval_summary_writer.flush() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create the environment. trainable_model = model_train_steps_per_iteration != 0 state_only = (actor_input == 'state' and critic_input == 'state' and not trainable_model and initial_model_train_steps == 0) # Save time from unnecessarily rendering observations. 
observations_whitelist = ['state'] if state_only else None py_env, eval_py_env = load_environments( universe, env_name=env_name, domain_name=domain_name, task_name=task_name, observations_whitelist=observations_whitelist, action_repeat=action_repeat) tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True) original_control_timestep = get_control_timestep(eval_py_env) control_timestep = original_control_timestep * float(action_repeat) fps = int(np.round(1.0 / control_timestep)) render_fps = int(np.round(1.0 / original_control_timestep)) # Get the data specs from the environment time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() if model_train_steps_per_iteration not in (0, train_steps_per_iteration): raise NotImplementedError model_net = model_network_ctor(observation_spec, action_spec) if compressor_descriptor == 'model': compressor_net = model_net.compressor elif re.match('preprocessor_(\d+)_(\d+)', compressor_descriptor): m = re.match('preprocessor_(\d+)_(\d+)', compressor_descriptor) filters, n_layers = m.groups() filters = int(filters) n_layers = int(n_layers) compressor_net = compressor_network.Preprocessor(filters, n_layers=n_layers) elif re.match('compressor_(\d+)', compressor_descriptor): m = re.match('compressor_(\d+)', compressor_descriptor) filters, = m.groups() filters = int(filters) compressor_net = compressor_network.Compressor(filters) elif re.match('softlearning_(\d+)_(\d+)', compressor_descriptor): m = re.match('softlearning_(\d+)_(\d+)', compressor_descriptor) filters, n_layers = m.groups() filters = int(filters) n_layers = int(n_layers) compressor_net = compressor_network.SoftlearningPreprocessor( filters, n_layers=n_layers) elif compressor_descriptor == 'd4pg': compressor_net = compressor_network.D4pgPreprocessor() else: raise NotImplementedError(compressor_descriptor) actor_state_size = 0 for _actor_input in actor_input.split('__'): if _actor_input == 'state': state_size, = observation_spec['state'].shape actor_state_size += state_size elif _actor_input == 'latent': actor_state_size += model_net.state_size elif _actor_input == 'feature': actor_state_size += compressor_net.feature_size elif _actor_input in ('sequence_feature', 'sequence_action_feature'): actor_state_size += compressor_net.feature_size * sequence_length if _actor_input == 'sequence_action_feature': actor_state_size += tf.compat.dimension_value( action_spec.shape[0]) * (sequence_length - 1) else: raise NotImplementedError actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ), dtype=tf.float32) critic_state_size = 0 for _critic_input in critic_input.split('__'): if _critic_input == 'state': state_size, = observation_spec['state'].shape critic_state_size += state_size elif _critic_input == 'latent': critic_state_size += model_net.state_size elif _critic_input == 'feature': critic_state_size += compressor_net.feature_size elif _critic_input in ('sequence_feature', 'sequence_action_feature'): critic_state_size += compressor_net.feature_size * sequence_length if _critic_input == 'sequence_action_feature': critic_state_size += tf.compat.dimension_value( action_spec.shape[0]) * (sequence_length - 1) else: raise NotImplementedError critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ), dtype=tf.float32) actor_net = actor_distribution_network.ActorDistributionNetwork( actor_input_spec, action_spec, fc_layer_params=actor_fc_layers) critic_net = critic_network.CriticNetwork( (critic_input_spec, action_spec), 
observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = slac_agent.SlacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, model_network=model_net, compressor_network=compressor_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), model_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=model_learning_rate), sequence_length=sequence_length, target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, trainable_model=trainable_model, critic_input=critic_input, actor_input=actor_input, model_batch_size=model_batch_size, control_timestep=control_timestep, num_images_per_summary=num_images_per_summary, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) # Make the replay buffer. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) eval_greedy_py_policy = py_tf_policy.PyTFPolicy( greedy_policy.GreedyPolicy(tf_agent.policy)) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric( py_metrics.AverageReturnMetric(buffer_size=1)), tf_py_metric.TFPyMetric( py_metrics.AverageEpisodeLengthMetric(buffer_size=1)), ] collect_policy = tf_agent.collect_policy initial_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) initial_policy_state = initial_collect_policy.get_initial_state(1) initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps).run( policy_state=initial_policy_state) policy_state = collect_policy.get_initial_state(1) collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run( policy_state=policy_state) # Prepare replay buffer as dataset with invalid transitions filtered. 
def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[-2] dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=sequence_length + 1).unbatch().filter(_filter_invalid_transition).batch( batch_size, drop_remainder=True).prefetch(3) dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) trajectories, unused_info = dataset_iterator.get_next() train_op = tf_agent.train(trajectories) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) if initial_model_train_steps: with tf.name_scope('initial'): model_train_op = tf_agent.train_model(trajectories) model_summary_ops = [] for summary_op in tf.compat.v1.summary.all_v2_summary_ops(): if summary_op not in summary_ops: model_summary_ops.append(summary_op) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics + eval_greedy_metrics: eval_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if eval_interval: eval_images_ph = tf.compat.v1.placeholder(dtype=tf.uint8, shape=[None] * 5) eval_images_summary = gif_utils.gif_summary_v2( 'ObservationVideoEvalPolicy', eval_images_ph, 1, fps) eval_render_images_summary = gif_utils.gif_summary_v2( 'VideoEvalPolicy', eval_images_ph, 1, render_fps) eval_greedy_images_summary = gif_utils.gif_summary_v2( 'ObservationVideoEvalGreedyPolicy', eval_images_ph, 1, fps) eval_greedy_render_images_summary = gif_utils.gif_summary_v2( 'VideoEvalGreedyPolicy', eval_images_ph, 1, render_fps) train_config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=False) eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir, summarize_config=False) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), max_to_keep=2) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step, max_to_keep=2) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) with tf.compat.v1.Session() as sess: # Initialize graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # Initialize training. sess.run(dataset_iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) train_config_saver.after_create_session(sess) eval_config_saver.after_create_session(sess) global_step_val = sess.run(global_step) if global_step_val == 0: if eval_interval: # Initial eval of randomly initialized policy for _eval_metrics, _eval_py_policy, \ _eval_render_images_summary, _eval_images_summary in ( (eval_metrics, eval_py_policy, eval_render_images_summary, eval_images_summary), (eval_greedy_metrics, eval_greedy_py_policy, eval_greedy_render_images_summary, eval_greedy_images_summary)): compute_summaries( _eval_metrics, eval_py_env, _eval_py_policy, num_episodes=num_eval_episodes, num_episodes_to_render=num_images_per_summary, images_ph=eval_images_ph, render_images_summary=_eval_render_images_summary, images_summary=_eval_images_summary) sess.run(eval_summary_flush_op) # Run initial collect. 
logging.info('Global step %d: Running initial collect op.', global_step_val) sess.run(initial_collect_op) # Checkpoint the initial replay buffer contents. rb_checkpointer.save(global_step=global_step_val) logging.info('Finished initial collect.') else: logging.info('Global step %d: Skipping initial collect op.', global_step_val) policy_state_val = sess.run(policy_state) collect_call = sess.make_callable(collect_op, feed_list=[policy_state]) train_step_call = sess.make_callable([train_op, summary_ops]) if initial_model_train_steps: model_train_step_call = sess.make_callable( [model_train_op, model_summary_ops]) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') # steps_per_second summary should always be recorded since it's only called every log_interval steps with tf.compat.v2.summary.record_if(True): steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for iteration in range(global_step_val, initial_model_train_steps + num_iterations): start_time = time.time() if iteration < initial_model_train_steps: total_loss_val, _ = model_train_step_call() else: time_step_val, policy_state_val = collect_call( policy_state_val) for _ in range(train_steps_per_iteration): total_loss_val, _ = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if log_interval and global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss_val.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if (train_checkpoint_interval and global_step_val % train_checkpoint_interval == 0): train_checkpointer.save(global_step=global_step_val) if iteration < initial_model_train_steps: continue if eval_interval and global_step_val % eval_interval == 0: for _eval_metrics, _eval_py_policy, \ _eval_render_images_summary, _eval_images_summary in ( (eval_metrics, eval_py_policy, eval_render_images_summary, eval_images_summary), (eval_greedy_metrics, eval_greedy_py_policy, eval_greedy_render_images_summary, eval_greedy_images_summary)): compute_summaries( _eval_metrics, eval_py_env, _eval_py_policy, num_episodes=num_eval_episodes, num_episodes_to_render=num_images_per_summary, images_ph=eval_images_ph, render_images_summary=_eval_render_images_summary, images_summary=_eval_images_summary) sess.run(eval_summary_flush_op) if (policy_checkpoint_interval and global_step_val % policy_checkpoint_interval == 0): policy_checkpointer.save(global_step=global_step_val) if (rb_checkpoint_interval and global_step_val % rb_checkpoint_interval == 0): rb_checkpointer.save(global_step=global_step_val)
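# Illustrative sketch (not from the original script): how the regex dispatch
# above parses a descriptor such as 'preprocessor_32_3' into
# (filters=32, n_layers=3). Raw pattern strings are used here; the inline
# patterns above are not raw strings, so their escapes trigger
# invalid-escape-sequence warnings on Python 3.
def _compressor_descriptor_sketch():
  import re
  m = re.match(r'preprocessor_(\d+)_(\d+)', 'preprocessor_32_3')
  filters, n_layers = (int(g) for g in m.groups())
  return filters, n_layers  # -> (32, 3)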
def train( root_dir, env_load_fn=suite_gym.load, env_name='CartPole-v0', env_name_eval=None, num_parallel_environments=1, agent_class=None, initial_collect_random=True, initial_collect_driver_class=None, collect_driver_class=None, num_global_steps=100000, train_steps_per_iteration=1, clear_rb_after_train_steps=None, # Defaults to True for ON_POLICY_AGENTS train_metrics=None, # Params for eval run_eval=False, num_eval_episodes=30, eval_interval=1000, eval_metrics_callback=None, # Params for checkpoints, summaries, and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, keep_rb_checkpoint=False, train_sequence_length=1, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, early_termination_fn=None, env_metric_factories=None): eval_interval_counter = IntervalCounter(eval_interval) train_checkpoint_interval_counter = IntervalCounter( train_checkpoint_interval) policy_checkpoint_interval_counter = IntervalCounter( policy_checkpoint_interval) rb_checkpoint_interval_counter = IntervalCounter(rb_checkpoint_interval) log_interval_counter = IntervalCounter(log_interval) summary_interval_counter = IntervalCounterTf(summary_interval) if not agent_class: raise ValueError( 'The `agent_class` parameter of trainer.train must be set.') root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') saved_model_dir = os.path.join(root_dir, 'policy_saved_model') if not tf.io.gfile.exists(saved_model_dir): tf.io.gfile.makedirs(saved_model_dir) train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() def make_possibly_parallel_environment(env_name_): """Returns a function creating env_name_, possibly a parallel one.""" if num_parallel_environments == 1: return env_load_fn(env_name_) else: return parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name_)] * num_parallel_environments) def make_tf_py_envs(env): """Convert env to tf if needed.""" if isinstance(env, py_environment.PyEnvironment): tf_env = tf_py_environment.TFPyEnvironment(env) py_env = env else: tf_env = env py_env = None # Can't generically convert to PyEnvironment. return tf_env, py_env eval_py_env = None if run_eval: if env_name_eval is None: env_name_eval = env_name eval_env = make_possibly_parallel_environment(env_name_eval) eval_tf_env, eval_py_env = make_tf_py_envs(eval_env) eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(batch_size=eval_tf_env.batch_size, buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric( batch_size=eval_tf_env.batch_size, buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() # should_summary = tf.constant(summary_interval_counter.should_trigger(global_step.numpy())) # print("should_summary", should_summary) with tf.compat.v2.summary.record_if( lambda: summary_interval_counter.should_trigger(global_step)): env = make_possibly_parallel_environment(env_name) tf_env, py_env = make_tf_py_envs(env) environment_specs.set_observation_spec(tf_env.observation_spec()) environment_specs.set_action_spec(tf_env.action_spec()) # Agent params configured with gin. 
agent = agent_class(tf_env.time_step_spec(), tf_env.action_spec(), train_step_counter=global_step) agent.initialize() if clear_rb_after_train_steps is None: # Default is to clear RB for ON_POLICY_AGENTS, only. clear_rb_after_train_steps = isinstance(agent, ON_POLICY_AGENTS) if run_eval: eval_policy = greedy_policy.GreedyPolicy(agent.policy) if not train_metrics: train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(batch_size=tf_env.batch_size, buffer_size=log_interval * tf_env.batch_size), tf_metrics.AverageEpisodeLengthMetric( batch_size=tf_env.batch_size, buffer_size=log_interval * tf_env.batch_size), ] else: train_metrics = list(train_metrics) if env_metric_factories: for metric_factory in env_metric_factories: py_metric = metric_factory(environment=py_env) train_metrics.append(tf_py_metric.TFPyMetric(py_metric)) logging.info('Allocating replay buffer ...') # Add to replay buffer and other agent specific observers. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( agent.collect_data_spec) logging.info('RB capacity: %i', replay_buffer.capacity) agent_observers = [replay_buffer.add_batch] initial_collect_policy = agent.collect_policy if initial_collect_random: initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec(), info_spec=agent.collect_policy.info_spec) collect_policy = agent.collect_policy collect_driver = collect_driver_class(tf_env, collect_policy, observers=agent_observers + train_metrics) rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer') train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, max_to_keep=1, agent=agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), max_to_keep=None, policy=agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer) saved_model = policy_saver.PolicySaver(greedy_policy.GreedyPolicy( agent.policy), train_step=global_step) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() collect_driver.run = common.function(collect_driver.run) agent.train = common.function(agent.train) if not rb_checkpointer.checkpoint_exists: logging.info('Performing initial collection ...') common.function( initial_collect_driver_class(tf_env, initial_collect_policy, observers=agent_observers + train_metrics).run)() if run_eval: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) # This is only used for PPO Agents. # The dataset is repeated for `train_steps_per_iteration` which represents # the number of epochs we loop through during training. def get_data_iter_repeated(replay_buffer): dataset = replay_buffer.as_dataset( sample_batch_size=num_parallel_environments, num_steps=train_sequence_length + 1, num_parallel_calls=3, single_deterministic_pass=True).repeat( train_steps_per_iteration) if len([1 for _ in dataset]) == 0: logging.warning('PPO Agent replay buffer as dataset is empty') return iter(dataset) # For off policy agents, one iterator is created for the entire training # process. 
# This is different from PPO agents whose iterators are reset
# in the training loop.
if not isinstance(agent, ON_POLICY_AGENTS):
  dataset = replay_buffer.as_dataset(
      num_parallel_calls=3, num_steps=train_sequence_length + 1).prefetch(3)
  iterator = iter(dataset)

time_step = None
policy_state = collect_policy.get_initial_state(tf_env.batch_size)
timed_at_step = global_step.numpy()
time_acc = 0

def save_policy(global_step_value):
  """Saves policy using both checkpoint saver and saved model."""
  policy_checkpointer.save(global_step=global_step_value)
  saved_model_path = os.path.join(
      saved_model_dir, 'policy_' + ('%d' % global_step_value).zfill(8))
  saved_model.save(saved_model_path)

if global_step.numpy() == 0:
  # Save an initial checkpoint so the evaluator runs for global_step=0.
  save_policy(global_step.numpy())

@common.function
def train_step(data_iterator):
  experience = next(data_iterator)[0]
  return agent.train(experience)

@common.function
def train_with_gather_all():
  return agent.train(replay_buffer.gather_all())

if not early_termination_fn:
  early_termination_fn = lambda: False

loss_diverged = False
# How many consecutive steps was loss diverged for.
loss_divergence_counter = 0

# Save operative config as late as possible to include used configurables.
if global_step.numpy() == 0:
  config_filename = os.path.join(
      train_dir, 'operative_config-{}.gin'.format(global_step.numpy()))
  with tf.io.gfile.GFile(config_filename, 'wb') as f:
    f.write(gin.operative_config_str())

logging.info('Training ...')
while (global_step.numpy() <= num_global_steps and
       not early_termination_fn()):
  # Collect and train.
  start_time = time.time()
  time_step, policy_state = collect_driver.run(
      time_step=time_step, policy_state=policy_state)

  if isinstance(agent, PPO_AGENTS):
    iterator = get_data_iter_repeated(replay_buffer)
  for _ in range(train_steps_per_iteration):
    if isinstance(agent, REINFORCE_AGENTS):
      total_loss = train_with_gather_all()
    else:
      total_loss = train_step(iterator)
  total_loss = total_loss.loss
  # Check for exploding losses.
if (math.isnan(total_loss) or math.isinf(total_loss) or total_loss > MAX_LOSS): loss_divergence_counter += 1 if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS: loss_diverged = True break else: loss_divergence_counter = 0 if clear_rb_after_train_steps: replay_buffer.clear() time_acc += time.time() - start_time should_log = log_interval_counter.should_trigger( global_step.numpy()) if should_log: logging.info('step = %d, loss = %f', global_step.numpy(), total_loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if should_log: hpt = hypertune.HyperTune() hpt.report_hyperparameter_tuning_metric( hyperparameter_metric_tag=train_metric.name, metric_value=train_metric.result(), global_step=global_step) print("Reported", train_metric.name, global_step.numpy()) if train_checkpoint_interval_counter.should_trigger( global_step.numpy()): train_checkpointer.save(global_step=global_step.numpy()) print("train_checkpoint", global_step.numpy()) if policy_checkpoint_interval_counter.should_trigger( global_step.numpy()): save_policy(global_step.numpy()) print("policy_checkpoint", global_step.numpy()) if rb_checkpoint_interval_counter.should_trigger( global_step.numpy()): rb_checkpointer.save(global_step=global_step.numpy()) print("rb_checkpoint", global_step.numpy()) if run_eval and eval_interval_counter.should_trigger( global_step.numpy()): results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) if not keep_rb_checkpoint: cleanup_checkpoints(rb_ckpt_dir) if py_env: py_env.close() if eval_py_env: eval_py_env.close() # Save final operative config that will also have all configurables used in # the training loop for the first time. config_filename = os.path.join(train_dir, 'operative_config-final.gin') with tf.io.gfile.GFile(config_filename, 'wb') as f: f.write(gin.operative_config_str()) if loss_diverged: # Raise an error at the very end after the cleanup. raise ValueError('Loss diverged to {} at step {}, terminating.'.format( total_loss, global_step.numpy()))
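# Illustrative sketch (not from the original script) of the divergence guard
# in the loop above: training terminates only after the loss has been
# NaN/inf/huge for several consecutive steps, so a single bad batch does not
# kill a long run. The threshold defaults are hypothetical stand-ins for
# MAX_LOSS and TERMINATE_AFTER_DIVERGED_LOSS_STEPS.
def _divergence_guard_sketch(losses, max_loss=1e9, terminate_after=100):
  import math
  counter = 0
  for loss in losses:
    if math.isnan(loss) or math.isinf(loss) or loss > max_loss:
      counter += 1
      if counter > terminate_after:
        return True  # diverged: the caller should stop training
    else:
      counter = 0  # a healthy loss resets the streak
  return False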
def train_eval(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    gym_env_wrappers=(),
    monitor=False,
    env_name=None,
    agent_class=None,
    initial_collect_driver_class=None,
    collect_driver_class=None,
    online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver,
    num_global_steps=1000000,
    train_steps_per_iteration=1,
    train_metrics=None,
    eval_metrics=None,
    train_metrics_callback=None,
    # Params for SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    train_sc_steps=10,
    train_sc_interval=1000,
    online_critic=False,
    n_envs=None,
    finetune_sc=False,
    # Ensemble Critic training args
    n_critics=30,
    critic_learning_rate=3e-4,
    # Wcpg Critic args
    critic_preprocessing_layer_size=256,
    actor_preprocessing_layer_size=256,
    # Params for train
    batch_size=256,
    # Params for eval
    run_eval=False,
    num_eval_episodes=1,
    max_episode_len=500,
    eval_interval=10000,
    eval_metrics_callback=None,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    monitor_interval=1000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    debug_summaries=False,
    seed=None,
    eager_debug=False,
    env_metric_factories=None):  # pylint: disable=unused-argument
  """A simple train and eval for SC-SAC."""
  n_envs = n_envs or num_eval_episodes
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  train_metrics = train_metrics or []
  eval_metrics = eval_metrics or []
  sc_metrics = eval_metrics or []

  if online_critic:
    sc_dir = os.path.join(root_dir, 'sc')
    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        sc_dir, flush_millis=summaries_flush_secs * 1000)
    sc_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                       batch_size=n_envs,
                                       name='SafeAverageReturn'),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=n_envs,
            name='SafeAverageEpisodeLength')
    ] + [tf_py_metric.TFPyMetric(m) for m in sc_metrics]
    sc_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)]
            * n_envs))
    if seed:
      sc_tf_env.seed([seed + i for i in range(n_envs)])

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                       batch_size=n_envs),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes,
                                              batch_size=n_envs),
    ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)]
            * n_envs))
    if seed:
      eval_tf_env.seed([seed + n_envs + i for i in range(n_envs)])

  if monitor:
    vid_path = os.path.join(root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    if seed:
      tf_env.seed([seed + 2 * n_envs + i for i in range(n_envs)])

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()
    logging.debug('obs spec: %s', observation_spec)
    logging.debug('action spec: %s', action_spec)

    if agent_class is not wcpg_agent.WcpgAgent:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=agents.normal_projection_net)
      critic_net = agents.CriticNetwork(
          (observation_spec, action_spec),
          joint_fc_layer_params=critic_joint_fc_layers)
    else:
      alpha_spec = tensor_spec.BoundedTensorSpec(shape=(),
                                                 dtype=tf.float32,
                                                 minimum=0.,
                                                 maximum=1.,
                                                 name='alpha')
      input_tensor_spec = (observation_spec, action_spec, alpha_spec)
      critic_preprocessing_layers = (
          tf.keras.layers.Dense(critic_preprocessing_layer_size),
          tf.keras.layers.Dense(critic_preprocessing_layer_size),
          tf.keras.layers.Lambda(lambda x: x))
      critic_net = agents.DistributionalCriticNetwork(
          input_tensor_spec, joint_fc_layer_params=critic_joint_fc_layers)
      actor_preprocessing_layers = (
          tf.keras.layers.Dense(actor_preprocessing_layer_size),
          tf.keras.layers.Dense(actor_preprocessing_layer_size),
          tf.keras.layers.Lambda(lambda x: x))
      actor_net = agents.WcpgActorNetwork(
          input_tensor_spec, preprocessing_layers=actor_preprocessing_layers)

    if agent_class in SAFETY_AGENTS:
      safety_critic_net = agents.CriticNetwork(
          (observation_spec, action_spec),
          joint_fc_layer_params=critic_joint_fc_layers)
      tf_agent = agent_class(time_step_spec, action_spec,
                             actor_network=actor_net,
                             critic_network=critic_net,
                             safety_critic_network=safety_critic_net,
                             train_step_counter=global_step,
                             debug_summaries=debug_summaries)
    elif agent_class is ensemble_sac_agent.EnsembleSacAgent:
      critic_nets, critic_optimizers = [critic_net], [
          tf.keras.optimizers.Adam(critic_learning_rate)
      ]
      for _ in range(n_critics - 1):
        critic_nets.append(
            agents.CriticNetwork((observation_spec, action_spec),
                                 joint_fc_layer_params=critic_joint_fc_layers))
        critic_optimizers.append(
            tf.keras.optimizers.Adam(critic_learning_rate))
      tf_agent = agent_class(time_step_spec, action_spec,
                             actor_network=actor_net,
                             critic_network=critic_nets,
                             critic_optimizers=critic_optimizers,
                             debug_summaries=debug_summaries)
    else:  # assume is using SacAgent
      logging.debug(critic_net.input_tensor_spec)
      tf_agent = agent_class(time_step_spec, action_spec,
                             actor_network=actor_net,
                             critic_network=critic_net,
                             train_step_counter=global_step,
                             debug_summaries=debug_summaries)
    tf_agent.initialize()

    # Make the replay buffer.
    collect_data_spec = tf_agent.collect_data_spec
    logging.debug('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=1000000) logging.debug('RB capacity: %i', replay_buffer.capacity) logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec) agent_observers = [replay_buffer.add_batch] if online_critic: online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=max_episode_len * num_eval_episodes) agent_observers.append(online_replay_buffer.add_batch) online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer') online_rb_checkpointer = common.Checkpointer( ckpt_dir=online_rb_ckpt_dir, max_to_keep=1, replay_buffer=online_replay_buffer) clear_rb = online_replay_buffer.clear train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics] if not online_critic: eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy else: eval_policy = tf_agent.policy # pylint: disable=protected-access collect_policy = tf_agent.collect_policy # pylint: disable=protected-access online_collect_policy = tf_agent._safe_policy initial_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=eval_policy, global_step=global_step) if agent_class in SAFETY_AGENTS: safety_critic_checkpointer = common.Checkpointer( ckpt_dir=sc_dir, safety_critic=tf_agent._safety_critic_network, # pylint: disable=protected-access global_step=global_step) rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer') rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer) if load_root_dir: load_root_dir = os.path.expanduser(load_root_dir) load_train_dir = os.path.join(load_root_dir, 'train') misc.load_pi_ckpt(load_train_dir, tf_agent) # loads tf_agent if load_root_dir is None: train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() if agent_class in SAFETY_AGENTS: safety_critic_checkpointer.initialize_or_restore() env_metrics = [] if env_metric_factories: for env_metric in env_metric_factories: env_metrics.append( tf_py_metric.TFPyMetric(env_metric([py_env.gym]))) # TODO: get env factory with parallel py envs # if run_eval: # eval_metrics.append(env_metric([env.gym for env in eval_tf_env.pyenv._envs])) # if online_critic: # sc_metrics.append(env_metric([env.gym for env in sc_tf_env.pyenv._envs])) collect_driver = collect_driver_class(tf_env, collect_policy, observers=agent_observers + train_metrics + env_metrics) if online_critic: logging.debug('online driver class: %s', online_driver_class) if online_driver_class is safe_dynamic_episode_driver.SafeDynamicEpisodeDriver: online_temp_buffer = episodic_replay_buffer.EpisodicReplayBuffer( collect_data_spec) online_temp_buffer_stateful = episodic_replay_buffer.StatefulEpisodicReplayBuffer( online_temp_buffer, num_episodes=num_eval_episodes) online_driver = safe_dynamic_episode_driver.SafeDynamicEpisodeDriver( sc_tf_env, online_collect_policy, online_temp_buffer, online_replay_buffer, 
observers=[online_temp_buffer_stateful.add_batch] + sc_metrics, num_episodes=num_eval_episodes) else: online_driver = online_driver_class( sc_tf_env, online_collect_policy, observers=[online_replay_buffer.add_batch] + sc_metrics, num_episodes=num_eval_episodes) online_driver.run = common.function(online_driver.run) if not eager_debug: config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=True) tf.function(config_saver.after_create_session)() if agent_class is sac_agent.SacAgent: collect_driver.run = common.function(collect_driver.run) if eager_debug: tf.config.experimental_run_functions_eagerly(True) if not rb_checkpointer.checkpoint_exists: logging.info('Performing initial collection ...') initial_collect_driver_class(tf_env, initial_collect_policy, observers=agent_observers + train_metrics + env_metrics).run() last_id = replay_buffer._get_last_id() # pylint: disable=protected-access logging.info('Data saved after initial collection: %d steps', last_id) if online_critic: last_id = online_replay_buffer._get_last_id() # pylint: disable=protected-access logging.debug( 'Data saved in online buffer after initial collection: %d steps', last_id) if run_eval: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='EvalMetrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) if online_critic: online_dataset = online_replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) online_iterator = iter(online_dataset) critic_metrics = [ tf.keras.metrics.AUC(name='safety_critic_auc'), tf.keras.metrics.TruePositives(name='safety_critic_tp'), tf.keras.metrics.FalsePositives(name='safety_critic_fp'), tf.keras.metrics.TrueNegatives(name='safety_critic_tn'), tf.keras.metrics.FalseNegatives(name='safety_critic_fn'), tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc') ] @common.function def critic_train_step(): """Builds critic training step.""" start_time = time.time() experience, buf_info = next(online_iterator) if env_name.split('-')[0] in SAFETY_ENVS: safe_rew = experience.observation['task_agn_rew'][:, 1] else: safe_rew = misc.process_replay_buffer(online_replay_buffer, as_tensor=True) safe_rew = tf.gather(safe_rew, tf.squeeze(buf_info.ids), axis=1) ret = tf_agent.train_sc(experience, safe_rew, metrics=critic_metrics, weights=None) logging.debug('critic train step: {} sec'.format(time.time() - start_time)) return ret @common.function def train_step(): experience, _ = next(iterator) ret = tf_agent.train(experience) return ret if not early_termination_fn: early_termination_fn = lambda: False loss_diverged = False # How many consecutive steps was loss diverged for. 
loss_divergence_counter = 0 mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss') if online_critic: logging.debug('starting safety critic pretraining') safety_eps = tf_agent._safe_policy._safety_threshold tf_agent._safe_policy._safety_threshold = 0.6 resample_counter = online_collect_policy._resample_counter mean_resample_ac = tf.keras.metrics.Mean( name='mean_unsafe_ac_freq') # don't fine-tune safety critic if (global_step.numpy() == 0 and load_root_dir is None): for _ in range(train_sc_steps): sc_loss, lambda_loss = critic_train_step() # pylint: disable=unused-variable tf_agent._safe_policy._safety_threshold = safety_eps logging.debug('starting policy pretraining') while (global_step.numpy() <= num_global_steps and not early_termination_fn()): # Collect and train. start_time = time.time() current_step = global_step.numpy() if online_critic: mean_resample_ac(resample_counter.result()) resample_counter.reset() if time_step is None or time_step.is_last(): resample_ac_freq = mean_resample_ac.result() mean_resample_ac.reset_states() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) logging.debug('policy eval: {} sec'.format(time.time() - start_time)) train_time = time.time() for _ in range(train_steps_per_iteration): train_loss = train_step() mean_train_loss(train_loss.loss) if current_step == 0: logging.debug('train policy: {} sec'.format(time.time() - train_time)) if online_critic and current_step % train_sc_interval == 0: batch_time_step = sc_tf_env.reset() batch_policy_state = online_collect_policy.get_initial_state( sc_tf_env.batch_size) online_driver.run(time_step=batch_time_step, policy_state=batch_policy_state) for _ in range(train_sc_steps): sc_loss, lambda_loss = critic_train_step() # pylint: disable=unused-variable metric_utils.log_metrics(sc_metrics) with sc_summary_writer.as_default(): for sc_metric in sc_metrics: sc_metric.tf_summaries(train_step=global_step, step_metrics=sc_metrics[:2]) tf.compat.v2.summary.scalar(name='resample_ac_freq', data=resample_ac_freq, step=global_step) total_loss = mean_train_loss.result() mean_train_loss.reset_states() # Check for exploding losses. 
if (math.isnan(total_loss) or math.isinf(total_loss) or total_loss > MAX_LOSS): loss_divergence_counter += 1 if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS: loss_diverged = True logging.debug( 'Loss diverged, critic_loss: %s, actor_loss: %s, alpha_loss: %s', train_loss.extra.critic_loss, train_loss.extra.actor_loss, train_loss.extra.alpha_loss) break else: loss_divergence_counter = 0 time_acc += time.time() - start_time if current_step % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), total_loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 train_results = [] for train_metric in train_metrics: if isinstance(train_metric, (metrics.AverageEarlyFailureMetric, metrics.AverageFallenMetric, metrics.AverageSuccessMetric)): # Plot failure as a fn of return train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:3]) else: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) train_results.append( (train_metric.name, train_metric.result().numpy())) if env_metrics: for env_metric in env_metrics: env_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) train_results.append( (env_metric.name, env_metric.result().numpy())) if online_critic: for critic_metric in critic_metrics: train_results.append( (critic_metric.name, critic_metric.result().numpy())) critic_metric.reset_states() if train_metrics_callback is not None: train_metrics_callback(collections.OrderedDict(train_results), global_step.numpy()) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if agent_class in SAFETY_AGENTS: safety_critic_checkpointer.save( global_step=global_step_val) if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval == 0: if online_critic: online_rb_checkpointer.save(global_step=global_step_val) rb_checkpointer.save(global_step=global_step_val) elif online_critic: clear_rb() if run_eval and global_step_val % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='EvalMetrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step_val) metric_utils.log_metrics(eval_metrics) if monitor and current_step % monitor_interval == 0: monitor_time_step = monitor_py_env.reset() monitor_policy_state = eval_policy.get_initial_state(1) ep_len = 0 monitor_start = time.time() while not monitor_time_step.is_last(): monitor_action = eval_policy.action( monitor_time_step, monitor_policy_state) action, monitor_policy_state = monitor_action.action, monitor_action.state monitor_time_step = monitor_py_env.step(action) ep_len += 1 monitor_py_env.reset() logging.debug( 'saved rollout at timestep {}, rollout length: {}, {} sec'. format(global_step_val, ep_len, time.time() - monitor_start)) logging.debug('iteration time: {} sec'.format(time.time() - start_time)) if not keep_rb_checkpoint: misc.cleanup_checkpoints(rb_ckpt_dir) if loss_diverged: # Raise an error at the very end after the cleanup. 
raise ValueError('Loss diverged to {} at step {}, terminating.'.format( total_loss, global_step.numpy())) return total_loss
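# A minimal sketch of how train() might be wired up; the loader, agent class
# and driver factories below are hypothetical placeholders (not defined in
# this module), shown only to illustrate the expected call shape:
#
#   import functools
#   from tf_agents.drivers import dynamic_step_driver
#
#   train(
#       root_dir='~/experiments/sc_sac',
#       env_load_fn=my_env_load_fn,           # hypothetical loader returning a PyEnvironment
#       env_name='MyEnv-v0',                  # hypothetical env id
#       agent_class=MySafeSacAgent,           # hypothetical SAFETY_AGENTS member
#       initial_collect_driver_class=functools.partial(
#           dynamic_step_driver.DynamicStepDriver, num_steps=1000),
#       collect_driver_class=functools.partial(
#           dynamic_step_driver.DynamicStepDriver, num_steps=1),
#       num_global_steps=100000)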
def train_eval( root_dir, load_root_dir=None, env_load_fn=None, gym_env_wrappers=(), monitor=False, env_name=None, agent_class=None, initial_collect_driver_class=None, collect_driver_class=None, online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver, num_global_steps=1000000, rb_size=None, train_steps_per_iteration=1, train_metrics=None, eval_metrics=None, train_metrics_callback=None, # SacAgent args actor_fc_layers=(256, 256), critic_joint_fc_layers=(256, 256), # Safety Critic training args sc_rb_size=None, target_safety=None, train_sc_steps=10, train_sc_interval=1000, online_critic=False, n_envs=None, finetune_sc=False, pretraining=True, lambda_schedule_nsteps=0, lambda_initial=0., lambda_final=1., kstep_fail=0, # Ensemble Critic training args num_critics=None, critic_learning_rate=3e-4, # Wcpg Critic args critic_preprocessing_layer_size=256, # Params for train batch_size=256, # Params for eval run_eval=False, num_eval_episodes=10, eval_interval=1000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, keep_rb_checkpoint=False, log_interval=1000, summary_interval=1000, monitor_interval=5000, summaries_flush_secs=10, early_termination_fn=None, debug_summaries=False, seed=None, eager_debug=False, env_metric_factories=None, wandb=False): # pylint: disable=unused-argument """Train and eval script for SQRL.""" if isinstance(agent_class, str): assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(agent_class) agent_class = ALGOS.get(agent_class) n_envs = n_envs or num_eval_episodes root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') # =====================================================================# # Setup summary metrics, file writers, and create env # # =====================================================================# train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() train_metrics = train_metrics or [] eval_metrics = eval_metrics or [] updating_sc = online_critic and (not load_root_dir or finetune_sc) logging.debug('updating safety critic: %s', updating_sc) if seed: tf.compat.v1.set_random_seed(seed) if agent_class in SAFETY_AGENTS: if online_critic: sc_tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * n_envs )) if seed: seeds = [seed * n_envs + i for i in range(n_envs)] try: sc_tf_env.pyenv.seed(seeds) except Exception: # not all environments support seeding pass if run_eval: eval_dir = os.path.join(root_dir, 'eval') eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=n_envs), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes, batch_size=n_envs), ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics] eval_tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * n_envs )) if seed: try: for i, pyenv in enumerate(eval_tf_env.pyenv.envs): pyenv.seed(seed * n_envs + i) except Exception: # not all environments support seeding pass elif 'Drunk' in env_name: # Just visualizes trajectories in drunk spider environment eval_tf_env = tf_py_environment.TFPyEnvironment( env_load_fn(env_name)) else: eval_tf_env = None if monitor: vid_path = os.path.join(root_dir, 'rollouts') monitor_env_wrapper = misc.monitor_freq(1, vid_path)
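# misc.monitor_freq(1, vid_path) presumably builds a gym Monitor-style wrapper
# factory that records every episode (period 1) to vid_path; it is applied
# last, on top of the task wrappers below, so recordings capture the fully
# wrapped environment.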
monitor_env = gym.make(env_name) for wrapper in gym_env_wrappers: monitor_env = wrapper(monitor_env) monitor_env = monitor_env_wrapper(monitor_env) # auto_reset must be False to ensure Monitor works correctly monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False) global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): py_env = env_load_fn(env_name) tf_env = tf_py_environment.TFPyEnvironment(py_env) if seed: try: for i, pyenv in enumerate(tf_env.pyenv.envs): pyenv.seed(seed * n_envs + i) except Exception: # not all environments support seeding pass time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() logging.debug('obs spec: %s', observation_spec) logging.debug('action spec: %s', action_spec) # =====================================================================# # Setup agent class # # =====================================================================# if agent_class == wcpg_agent.WcpgAgent: alpha_spec = tensor_spec.BoundedTensorSpec(shape=(1,), dtype=tf.float32, minimum=0., maximum=1., name='alpha') input_tensor_spec = (observation_spec, action_spec, alpha_spec) critic_net = agents.DistributionalCriticNetwork( input_tensor_spec, preprocessing_layer_size=critic_preprocessing_layer_size, joint_fc_layer_params=critic_joint_fc_layers) actor_net = agents.WcpgActorNetwork((observation_spec, alpha_spec), action_spec) else: actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=agents.normal_projection_net) critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) if agent_class in SAFETY_AGENTS: logging.debug('Making SQRL agent') if lambda_schedule_nsteps > 0: lambda_update_every_nsteps = num_global_steps // lambda_schedule_nsteps step_size = (lambda_final - lambda_initial) / lambda_schedule_nsteps # one increment per scheduled update lambda_scheduler = lambda lam: common.periodically( body=lambda: tf.group(lam.assign(lam + step_size)), period=lambda_update_every_nsteps) else: lambda_scheduler = None safety_critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) ts = target_safety thresholds = [ts, 0.5] sc_metrics = [tf.keras.metrics.AUC(name='safety_critic_auc'), tf.keras.metrics.TruePositives(name='safety_critic_tp', thresholds=thresholds), tf.keras.metrics.FalsePositives(name='safety_critic_fp', thresholds=thresholds), tf.keras.metrics.TrueNegatives(name='safety_critic_tn', thresholds=thresholds), tf.keras.metrics.FalseNegatives(name='safety_critic_fn', thresholds=thresholds), tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc', threshold=0.5)] tf_agent = agent_class( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, safety_critic_network=safety_critic_net, train_step_counter=global_step, debug_summaries=debug_summaries, safety_pretraining=pretraining, train_critic_online=online_critic, initial_log_lambda=lambda_initial, log_lambda=(lambda_scheduler is None), lambda_scheduler=lambda_scheduler) elif agent_class is ensemble_sac_agent.EnsembleSacAgent: critic_nets, critic_optimizers = [critic_net], [tf.keras.optimizers.Adam(critic_learning_rate)] for _ in range(num_critics - 1): critic_nets.append(agents.CriticNetwork((observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers))
critic_optimizers.append(tf.keras.optimizers.Adam(critic_learning_rate)) tf_agent = agent_class( time_step_spec, action_spec, actor_network=actor_net, critic_networks=critic_nets, critic_optimizers=critic_optimizers, debug_summaries=debug_summaries ) else: # agent is either SacAgent or WcpgAgent logging.debug('critic input_tensor_spec: %s', critic_net.input_tensor_spec) tf_agent = agent_class( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, train_step_counter=global_step, debug_summaries=debug_summaries) tf_agent.initialize() # =====================================================================# # Setup replay buffer # # =====================================================================# collect_data_spec = tf_agent.collect_data_spec logging.debug('Allocating replay buffer ...') # Add to replay buffer and other agent specific observers. rb_size = rb_size or 1000000 replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=rb_size) logging.debug('RB capacity: %i', replay_buffer.capacity) logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec) if agent_class in SAFETY_AGENTS: sc_rb_size = sc_rb_size or num_eval_episodes * 500 sc_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( collect_data_spec, batch_size=1, max_length=sc_rb_size, dataset_window_shift=1) num_episodes = tf_metrics.NumberOfEpisodes() num_env_steps = tf_metrics.EnvironmentSteps() return_metric = tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size) train_metrics = [ num_episodes, num_env_steps, return_metric, tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics] if 'Minitaur' in env_name and not pretraining: goal_vel = gin.query_parameter("%GOAL_VELOCITY") early_termination_fn = train_utils.MinitaurTerminationFn( speed_metric=train_metrics[-2], total_falls_metric=train_metrics[-3], env_steps_metric=num_env_steps, goal_speed=goal_vel) if env_metric_factories: for env_metric in env_metric_factories: train_metrics.append(tf_py_metric.TFPyMetric(env_metric(tf_env.pyenv.envs))) if run_eval: eval_metrics.append(env_metric([env for env in eval_tf_env.pyenv._envs])) # =====================================================================# # Setup collect policies # # =====================================================================# if not online_critic: eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy if not pretraining and agent_class in SAFETY_AGENTS: collect_policy = tf_agent.safe_policy else: eval_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy collect_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy online_collect_policy = tf_agent.safe_policy # if pretraining else tf_agent.collect_policy if pretraining: online_collect_policy._training = False if not load_root_dir: initial_collect_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec) else: initial_collect_policy = collect_policy if agent_class == wcpg_agent.WcpgAgent: initial_collect_policy = agents.WcpgPolicyWrapper(initial_collect_policy) # =====================================================================# # Setup Checkpointing # # =====================================================================# train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, 
metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer') rb_checkpointer = common.Checkpointer( ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer) if online_critic: online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer') online_rb_checkpointer = common.Checkpointer( ckpt_dir=online_rb_ckpt_dir, max_to_keep=1, replay_buffer=sc_buffer) # loads agent, replay buffer, and online sc/buffer if online_critic if load_root_dir: load_root_dir = os.path.expanduser(load_root_dir) load_train_dir = os.path.join(load_root_dir, 'train') misc.load_agent_ckpt(load_train_dir, tf_agent) if len(os.listdir(os.path.join(load_train_dir, 'replay_buffer'))) > 1: load_rb_ckpt_dir = os.path.join(load_train_dir, 'replay_buffer') misc.load_rb_ckpt(load_rb_ckpt_dir, replay_buffer) if online_critic: load_online_sc_ckpt_dir = os.path.join(load_root_dir, 'sc') load_online_rb_ckpt_dir = os.path.join(load_train_dir, 'online_replay_buffer') if osp.exists(load_online_rb_ckpt_dir): misc.load_rb_ckpt(load_online_rb_ckpt_dir, sc_buffer) if osp.exists(load_online_sc_ckpt_dir): misc.load_safety_critic_ckpt(load_online_sc_ckpt_dir, safety_critic_net) elif agent_class in SAFETY_AGENTS: offline_run = sorted(os.listdir(os.path.join(load_train_dir, 'offline')))[-1] load_sc_ckpt_dir = os.path.join(load_train_dir, 'offline', offline_run, 'safety_critic') if osp.exists(load_sc_ckpt_dir): sc_net_off = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=(512, 512), name='SafetyCriticOffline') sc_net_off.create_variables() target_sc_net_off = common.maybe_copy_target_network_with_checks( sc_net_off, None, 'TargetSafetyCriticNetwork') sc_optimizer = tf.keras.optimizers.Adam(critic_learning_rate) _ = misc.load_safety_critic_ckpt( load_sc_ckpt_dir, safety_critic_net=sc_net_off, target_safety_critic=target_sc_net_off, optimizer=sc_optimizer) tf_agent._safety_critic_network = sc_net_off tf_agent._target_safety_critic_network = target_sc_net_off tf_agent._safety_critic_optimizer = sc_optimizer else: train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() if online_critic: online_rb_checkpointer.initialize_or_restore() if agent_class in SAFETY_AGENTS: sc_dir = os.path.join(root_dir, 'sc') safety_critic_checkpointer = common.Checkpointer( ckpt_dir=sc_dir, safety_critic=tf_agent._safety_critic_network, # pylint: disable=protected-access target_safety_critic=tf_agent._target_safety_critic_network, optimizer=tf_agent._safety_critic_optimizer, global_step=global_step) if not (load_root_dir and not online_critic): safety_critic_checkpointer.initialize_or_restore() agent_observers = [replay_buffer.add_batch] + train_metrics collect_driver = collect_driver_class( tf_env, collect_policy, observers=agent_observers) collect_driver.run = common.function_in_tf1()(collect_driver.run) if online_critic: logging.debug('online driver class: %s', online_driver_class) online_agent_observers = [num_episodes, num_env_steps, sc_buffer.add_batch] online_driver = online_driver_class( sc_tf_env, online_collect_policy, observers=online_agent_observers, num_episodes=num_eval_episodes) online_driver.run = common.function_in_tf1()(online_driver.run) if eager_debug: tf.config.experimental_run_functions_eagerly(True) else: config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=True) 
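# Presumably, invoking the gin config-saver hook once via tf.function mirrors
# the TF1 after_create_session behavior in eager mode: it writes the operative
# gin config to train_dir as a summary at startup.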
tf.function(config_saver.after_create_session)() if global_step.numpy() == 0: logging.info('Performing initial collection ...') init_collect_observers = list(agent_observers) # copy so sc_buffer.add_batch is not also appended to the collect driver's observers if agent_class in SAFETY_AGENTS: init_collect_observers += [sc_buffer.add_batch] initial_collect_driver_class( tf_env, initial_collect_policy, observers=init_collect_observers).run() last_id = replay_buffer._get_last_id() # pylint: disable=protected-access logging.info('Data saved after initial collection: %d steps', last_id) if agent_class in SAFETY_AGENTS: last_id = sc_buffer._get_last_id() # pylint: disable=protected-access logging.debug('Data saved in sc_buffer after initial collection: %d steps', last_id) if run_eval: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='EvalMetrics', ) if train_metrics_callback is not None: train_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 train_step = train_utils.get_train_step(tf_agent, replay_buffer, batch_size) if agent_class in SAFETY_AGENTS: critic_train_step = train_utils.get_critic_train_step( tf_agent, replay_buffer, sc_buffer, batch_size=batch_size, updating_sc=updating_sc, metrics=sc_metrics) if early_termination_fn is None: early_termination_fn = lambda: False loss_diverged = False # How many consecutive steps the loss has been diverged for. loss_divergence_counter = 0 mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss') if agent_class in SAFETY_AGENTS: resample_counter = collect_policy._resample_counter mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_freq') sc_metrics.append(mean_resample_ac) if online_critic: logging.debug('starting safety critic pretraining') # only pretrain the safety critic from scratch; don't fine-tune it if global_step.numpy() == 0: for _ in range(train_sc_steps): sc_loss, lambda_loss = critic_train_step() critic_results = [('sc_loss', sc_loss.numpy()), ('lambda_loss', lambda_loss.numpy())] for critic_metric in sc_metrics: res = critic_metric.result().numpy() if not res.shape: critic_results.append((critic_metric.name, res)) else: for r, thresh in zip(res, thresholds): name = '_'.join([critic_metric.name, str(thresh)]) critic_results.append((name, r)) critic_metric.reset_states() if train_metrics_callback: train_metrics_callback(collections.OrderedDict(critic_results), step=global_step.numpy()) logging.debug('Starting main train loop...') curr_ep = [] global_step_val = global_step.numpy() while global_step_val <= num_global_steps and not early_termination_fn(): start_time = time.time() # MEASURE ACTION RESAMPLING FREQUENCY if agent_class in SAFETY_AGENTS: if pretraining and global_step_val == num_global_steps // 2: if online_critic: online_collect_policy._training = True collect_policy._training = True if online_critic or collect_policy._training: mean_resample_ac(resample_counter.result()) resample_counter.reset() if time_step is None or time_step.is_last(): resample_ac_freq = mean_resample_ac.result() mean_resample_ac.reset_states() tf.compat.v2.summary.scalar( name='resample_ac_freq', data=resample_ac_freq, step=global_step) # RUN COLLECTION time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) # get the last transition written by the collect driver traj = replay_buffer._data_table.read(replay_buffer._get_last_id() % replay_buffer._capacity) # pylint: disable=protected-access
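# curr_ep accumulates the raw transitions of the episode in progress so the
# failure-relabeling block below can copy a whole episode (or just its last
# kstep_fail steps) into the safety-critic buffer once the episode ends.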
curr_ep.append(traj) if time_step.is_last(): if agent_class in SAFETY_AGENTS: if time_step.observation['task_agn_rew']: if kstep_fail: # apply the task-agnostic failure reward over the last kstep_fail steps for traj in curr_ep[-kstep_fail:]: traj.observation['task_agn_rew'] = 1. sc_buffer.add_batch(traj) else: for traj in curr_ep: sc_buffer.add_batch(traj) curr_ep = [] if agent_class == wcpg_agent.WcpgAgent: collect_policy._alpha = None # reset WCPG alpha if (global_step_val + 1) % log_interval == 0: logging.debug('policy eval: %4.2f sec', time.time() - start_time) # PERFORMS TRAIN STEP ON ALGORITHM (OFF-POLICY) for _ in range(train_steps_per_iteration): train_loss = train_step() mean_train_loss(train_loss.loss) current_step = global_step.numpy() total_loss = mean_train_loss.result() mean_train_loss.reset_states() if train_metrics_callback and current_step % summary_interval == 0: train_metrics_callback( collections.OrderedDict([(k, v.numpy()) for k, v in train_loss.extra._asdict().items()]), step=current_step) train_metrics_callback( {'train_loss': total_loss.numpy()}, step=current_step) # TRAIN AND/OR EVAL SAFETY CRITIC if agent_class in SAFETY_AGENTS and current_step % train_sc_interval == 0: if online_critic: batch_time_step = sc_tf_env.reset() # run online critic training collect & update batch_policy_state = online_collect_policy.get_initial_state( sc_tf_env.batch_size) online_driver.run(time_step=batch_time_step, policy_state=batch_policy_state) for _ in range(train_sc_steps): sc_loss, lambda_loss = critic_train_step() # log safety_critic loss results critic_results = [('sc_loss', sc_loss.numpy()), ('lambda_loss', lambda_loss.numpy())] metric_utils.log_metrics(sc_metrics) for critic_metric in sc_metrics: res = critic_metric.result().numpy() if not res.shape: critic_results.append((critic_metric.name, res)) else: for r, thresh in zip(res, thresholds): name = '_'.join([critic_metric.name, str(thresh)]) critic_results.append((name, r)) critic_metric.reset_states() if train_metrics_callback and current_step % summary_interval == 0: train_metrics_callback(collections.OrderedDict(critic_results), step=current_step) # Check for exploding losses.
if (math.isnan(total_loss) or math.isinf(total_loss) or total_loss > MAX_LOSS): loss_divergence_counter += 1 if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS: loss_diverged = True logging.info('Loss diverged, critic_loss: %s, actor_loss: %s', train_loss.extra.critic_loss, train_loss.extra.actor_loss) break else: loss_divergence_counter = 0 time_acc += time.time() - start_time # LOGGING AND METRICS if current_step % log_interval == 0: metric_utils.log_metrics(train_metrics) logging.info('step = %d, loss = %f', current_step, total_loss) steps_per_sec = (current_step - timed_at_step) / time_acc logging.info('%4.2f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = current_step time_acc = 0 train_results = [] for metric in train_metrics[2:]: if isinstance(metric, (metrics.AverageEarlyFailureMetric, metrics.AverageFallenMetric, metrics.AverageSuccessMetric)): # Plot failure as a fn of return metric.tf_summaries( train_step=global_step, step_metrics=[num_env_steps, num_episodes, return_metric]) else: metric.tf_summaries( train_step=global_step, step_metrics=[num_episodes, num_env_steps]) train_results.append((metric.name, metric.result().numpy())) if train_metrics_callback and current_step % summary_interval == 0: train_metrics_callback(collections.OrderedDict(train_results), step=global_step.numpy()) if current_step % train_checkpoint_interval == 0: train_checkpointer.save(global_step=current_step) if current_step % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=current_step) if agent_class in SAFETY_AGENTS: safety_critic_checkpointer.save(global_step=current_step) if online_critic: online_rb_checkpointer.save(global_step=current_step) if rb_checkpoint_interval and current_step % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=current_step) if wandb and current_step % eval_interval == 0 and "Drunk" in env_name: misc.record_point_mass_episode(eval_tf_env, eval_policy, current_step) if online_critic: misc.record_point_mass_episode(eval_tf_env, tf_agent.safe_policy, current_step, 'safe-trajectory') if run_eval and current_step % eval_interval == 0: eval_results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='EvalMetrics', ) if train_metrics_callback is not None: train_metrics_callback(eval_results, current_step) metric_utils.log_metrics(eval_metrics) with eval_summary_writer.as_default(): for eval_metric in eval_metrics[2:]: eval_metric.tf_summaries(train_step=global_step, step_metrics=eval_metrics[:2]) if monitor and current_step % monitor_interval == 0: monitor_time_step = monitor_py_env.reset() monitor_policy_state = eval_policy.get_initial_state(1) ep_len = 0 monitor_start = time.time() while not monitor_time_step.is_last(): monitor_action = eval_policy.action(monitor_time_step, monitor_policy_state) action, monitor_policy_state = monitor_action.action, monitor_action.state monitor_time_step = monitor_py_env.step(action) ep_len += 1 logging.debug('saved rollout at timestep %d, rollout length: %d, %4.2f sec', current_step, ep_len, time.time() - monitor_start) global_step_val = current_step if early_termination_fn(): # Early stopped, save all checkpoints if not saved if global_step_val % train_checkpoint_interval != 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval != 0:
policy_checkpointer.save(global_step=global_step_val) if agent_class in SAFETY_AGENTS: safety_critic_checkpointer.save(global_step=global_step_val) if online_critic: online_rb_checkpointer.save(global_step=global_step_val) if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval != 0: rb_checkpointer.save(global_step=global_step_val) if not keep_rb_checkpoint: misc.cleanup_checkpoints(rb_ckpt_dir) if loss_diverged: # Raise an error at the very end after the cleanup. raise ValueError('Loss diverged to {} at step {}, terminating.'.format( total_loss, global_step.numpy())) return total_loss
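# A self-contained sketch of the divergence guard shared by both training
# loops above, assuming MAX_LOSS and TERMINATE_AFTER_DIVERGED_LOSS_STEPS are
# the module-level constants referenced there:
#
#   def update_divergence_counter(total_loss, counter,
#                                 max_loss=MAX_LOSS,
#                                 patience=TERMINATE_AFTER_DIVERGED_LOSS_STEPS):
#       """Returns (loss_diverged, new_counter) for one training iteration."""
#       if math.isnan(total_loss) or math.isinf(total_loss) or total_loss > max_loss:
#           counter += 1
#           return counter > patience, counter
#       # Any healthy loss resets the consecutive-divergence count.
#       return False, 0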