def testCounterMetricIncrements(self):
    """Checks that CounterMetric counts calls and restarts from zero on reset."""
    metric = py_metrics.CounterMetric()

    # A freshly constructed counter reports zero.
    self.assertEqual(0, metric.result())

    # Every call bumps the reported value by exactly one.
    for expected in (1, 2):
        metric()
        self.assertEqual(expected, metric.result())

    # reset() returns the counter to zero ...
    metric.reset()
    self.assertEqual(0, metric.result())

    # ... and counting resumes from there.
    metric()
    self.assertEqual(1, metric.result())
def train(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        env_name=None,
        num_parallel_environments=1,  # pylint: disable=unused-argument
        agent_class=None,
        initial_collect_random=True,  # pylint: disable=unused-argument
        initial_collect_driver_class=None,
        collect_driver_class=None,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=300,
        online_critic=False,
        # Params for eval
        run_eval=False,
        num_eval_episodes=30,
        eval_interval=1000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval for SC-SAC.

    Builds the environment, agent, replay buffer(s), metrics and
    checkpointers, performs an initial random collection when no replay
    buffer checkpoint exists, then alternates collection and training until
    the global step exceeds `num_global_steps`, `early_termination_fn`
    returns a truthy value, or the loss diverges.

    Args:
      root_dir: Base directory; summaries and checkpoints are written under
        its 'train' (and, when `run_eval`, 'eval') subdirectories.
      load_root_dir: Optional directory of a previous run. When set, the
        policy is loaded from its 'train' subdirectory via
        `misc.load_pi_ckpt` and the train checkpointer is not restored.
      env_load_fn: Callable taking `env_name` and returning an environment;
        the result is wrapped in a `TFPyEnvironment` if it is not one.
      env_name: Name passed to `env_load_fn`.
      num_parallel_environments: Unused.
      agent_class: Callable invoked as `agent_class(time_step_spec,
        action_spec, train_step_counter=...)`, plus `resample_metric=...`
        when `online_critic` is set.
      initial_collect_random: Unused.
      initial_collect_driver_class: Driver class run once with a random
        policy when no replay buffer checkpoint exists.
      collect_driver_class: Driver class used for collection in the loop.
      num_global_steps: The loop runs while the global step is `<=` this.
      train_steps_per_iteration: Number of `train_step` calls per loop
        iteration.
      train_metrics: Optional list of python metrics; each is wrapped in a
        `TFPyMetric` and appended to the built-in train and eval metrics.
      train_sc_steps: Number of safety critic train steps per critic update.
      train_sc_interval: The safety critic is trained when the global step
        is a multiple of this (only when `online_critic`).
      online_critic: If True, collect and evaluate with the agent's safe
        policy, keep a second (online) replay buffer and train the safety
        critic from it.
      run_eval: If True, evaluate before training and every `eval_interval`
        global steps.
      num_eval_episodes: Episodes per evaluation; also used as the buffer
        size of the average-return / episode-length metrics.
      eval_interval: Evaluate when the global step is a multiple of this.
      eval_metrics_callback: Optional callable invoked as
        `callback(results, global_step_value)` after each evaluation.
      train_checkpoint_interval: Save interval for the train checkpointer.
      policy_checkpoint_interval: Save interval for the policy and safety
        critic checkpointers.
      rb_checkpoint_interval: Save interval for the replay buffer(s).
      keep_rb_checkpoint: If False, `misc.cleanup_checkpoints` is called on
        the replay buffer checkpoint directory at the end.
      log_interval: Log loss and steps/sec when the global step is a
        multiple of this.
      summary_interval: Summaries are recorded when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of the summary writers, in seconds.
      early_termination_fn: Optional zero-argument callable; the loop stops
        once it returns a truthy value.
      env_metric_factories: Unused.

    Returns:
      The mean training loss of the last completed loop iteration, or None
      if the loop body never ran.

    Raises:
      ValueError: If the loss was NaN, infinite or above `MAX_LOSS` for more
        than `TERMINATE_AFTER_DIVERGED_LOSS_STEPS` consecutive iterations.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    train_metrics = train_metrics or []

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    # Stays None if the training loop below never executes its body, so the
    # final `return total_loss` cannot hit an unbound local.
    total_loss = None

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = env_load_fn(env_name)
        if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
            tf_env = tf_py_environment.TFPyEnvironment(tf_env)

        if run_eval:
            eval_py_env = env_load_fn(env_name)
            eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        print('obs spec:', observation_spec)
        print('action spec:', action_spec)

        if online_critic:
            # Fixed: the class is spelled `TFPyMetric`, as used for the
            # train/eval metrics elsewhere in this function.
            resample_metric = tf_py_metric.TFPyMetric(
                py_metrics.CounterMetric('unsafe_ac_samples'))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step,
                                   resample_metric=resample_metric)
        else:
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.info('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, max_length=1000000)
        logging.info('RB capacity: %i', replay_buffer.capacity)
        logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            # Second buffer that feeds safety critic training; it is cleared
            # after every critic train step (see critic_train_step).
            online_replay_buffer = (
                tf_uniform_replay_buffer.TFUniformReplayBuffer(
                    collect_data_spec, max_length=10000))
            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = common.function(online_replay_buffer.clear)
            agent_observers.append(online_replay_buffer.add_batch)

        # From here on `train_metrics` is the full TF metric list; the
        # caller-supplied python metrics are wrapped and appended.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        if not online_critic:
            eval_policy = tf_agent.policy
        else:
            eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        if not online_critic:
            collect_policy = tf_agent.collect_policy
        else:
            collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics,
                                              'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        safety_critic_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'safety_critic'),
            safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
            global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        # NOTE(review): nesting reconstructed from a whitespace-collapsed
        # source; only the train checkpointer is assumed to be conditional
        # on `load_root_dir is None` -- confirm against the original file.
        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        safety_critic_checkpointer.initialize_or_restore()

        collect_driver = collect_driver_class(tf_env,
                                              collect_policy,
                                              observers=agent_observers +
                                              train_metrics)

        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            common.function(
                initial_collect_driver_class(tf_env,
                                             initial_collect_policy,
                                             observers=agent_observers +
                                             train_metrics).run)()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            tf.print(
                replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
                output_stream=logging.info)

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)
            if FLAGS.viz_pm:
                eval_fig_dir = osp.join(eval_dir, 'figs')
                if not tf.io.gfile.isdir(eval_fig_dir):
                    tf.io.gfile.makedirs(eval_fig_dir)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3, num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                experience, buf_info = next(online_iterator)
                if env_name in [
                        'IndianWell', 'IndianWell2', 'IndianWell3',
                        'DrunkSpider', 'DrunkSpiderShort'
                ]:
                    safe_rew = experience.observation['task_agn_rew']
                else:
                    safe_rew = agents.process_replay_buffer(
                        online_replay_buffer, as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience, safe_rew)
                clear_rb()
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
        if online_critic:
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_samples')
            resample_metric.reset()

        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            if online_critic:
                mean_resample_ac(resample_metric.result())
                resample_metric.reset()
                if time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()
                    tf.compat.v2.summary.scalar(name='unsafe_ac_samples',
                                                data=resample_ac_freq,
                                                step=global_step)

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)

            if online_critic:
                if global_step.numpy() % train_sc_interval == 0:
                    for _ in range(train_sc_steps):
                        sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    break
            else:
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            # The first two train metrics (episodes, environment steps) are
            # passed as the step metrics for the summaries.
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                safety_critic_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)

            if run_eval and global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
                if FLAGS.viz_pm:
                    savepath = 'step{}.png'.format(global_step_val)
                    savepath = osp.join(eval_fig_dir, savepath)
                    misc.record_episode_vis_summary(eval_tf_env, eval_policy,
                                                    savepath)

    if not keep_rb_checkpoint:
        misc.cleanup_checkpoints(rb_ckpt_dir)

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))

    return total_loss
def __init__(
        self,
        root_dir,
        env_name,
        num_iterations=200,
        max_episode_frames=108000,  # ALE frames
        terminal_on_life_loss=False,
        conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)),
        fc_layer_params=(512, ),
        # Params for collect
        initial_collect_steps=80000,  # ALE frames
        epsilon_greedy=0.01,
        epsilon_decay_period=1000000,  # ALE frames
        replay_buffer_capacity=1000000,
        # Params for train
        train_steps_per_iteration=1000000,  # ALE frames
        update_period=16,  # ALE frames
        target_update_tau=1.0,
        target_update_period=32000,  # ALE frames
        batch_size=32,
        learning_rate=2.5e-4,
        n_step_update=2,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        do_eval=True,
        eval_steps_per_iteration=500000,  # ALE frames
        eval_epsilon_greedy=0.001,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None):
    """A simple Atari train and eval for DQN.

    Builds the environment, the categorical DQN agent, the replay buffer and
    its dataset, the train op, the metrics with their summaries, and the
    checkpointers. Nothing is run here; the ops and objects are stored on
    `self` for later use.

    Args:
      root_dir: Directory to write log files to.
      env_name: Fully-qualified name of the Atari environment (i.e. Pong-v0).
      num_iterations: Number of train/eval iterations to run.
      max_episode_frames: Maximum length of a single episode, in ALE frames.
      terminal_on_life_loss: Whether to simulate an episode termination when a
        life is lost.
      conv_layer_params: Params for convolutional layers of QNetwork.
      fc_layer_params: Params for fully connected layers of QNetwork.
      initial_collect_steps: Number of ALE frames to process before
        beginning to train. Since this is in ALE frames, there will be
        initial_collect_steps/4 items in the replay buffer when training
        starts.
      epsilon_greedy: Final epsilon value to decay to for training.
      epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
        epsilon_greedy (defined above).
      replay_buffer_capacity: Maximum number of items to store in the replay
        buffer.
      train_steps_per_iteration: Number of ALE frames to run through for each
        iteration of training.
      update_period: Run a train operation every update_period ALE frames.
      target_update_tau: Coefficient for soft target network updates (1.0 ==
        hard updates).
      target_update_period: Period, in ALE frames, to copy the live network to
        the target network.
      batch_size: Number of frames to include in each training batch.
      learning_rate: RMS optimizer learning rate.
      n_step_update: The number of steps to consider when computing TD error
        and TD loss. Applies standard single-step updates when set to 1.
      gamma: Discount for future rewards.
      reward_scale_factor: Scaling factor for rewards.
      gradient_clipping: Norm length to clip gradients.
      do_eval: If True, run an eval every iteration. If False, skip eval.
      eval_steps_per_iteration: Number of ALE frames to run through for each
        iteration of evaluation.
      eval_epsilon_greedy: Epsilon value to use for the evaluation policy (0
        == totally greedy policy).
      log_interval: Log stats to the terminal every log_interval training
        steps.
      summary_interval: Write TF summaries every summary_interval training
        steps.
      summaries_flush_secs: Flush summaries to disk every summaries_flush_secs
        seconds.
      debug_summaries: If True, write additional summaries for debugging (see
        dqn_agent for which summaries are written).
      summarize_grads_and_vars: Include gradients in summaries.
      eval_metrics_callback: A callback function that takes (metric_dict,
        global_step) as parameters. Called after every eval with the results
        of the evaluation.
    """
    # Frame-denominated arguments are converted by dividing by
    # ATARI_FRAME_SKIP. NOTE(review): `/` is true division under Python 3,
    # so these are presumably stored as floats -- confirm callers expect that.
    self._update_period = update_period / ATARI_FRAME_SKIP
    self._train_steps_per_iteration = (train_steps_per_iteration /
                                       ATARI_FRAME_SKIP)
    self._do_eval = do_eval
    self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP
    self._eval_epsilon_greedy = eval_epsilon_greedy
    self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP
    self._summary_interval = summary_interval
    self._num_iterations = num_iterations
    self._log_interval = log_interval
    self._eval_metrics_callback = eval_metrics_callback

    # `terminal_on_life_loss` is forwarded to the Atari preprocessing
    # wrapper through a gin binding rather than a constructor argument.
    with gin.unlock_config():
        gin.bind_parameter(('tf_agents.environments.atari_preprocessing.'
                            'AtariPreprocessing.terminal_on_life_loss'),
                           terminal_on_life_loss)

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()
    self._train_summary_writer = train_summary_writer

    # The eval writer and eval metrics only exist when evaluation is enabled.
    self._eval_summary_writer = None
    if self._do_eval:
        self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        self._eval_metrics = [
            py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                           buffer_size=np.inf),
            py_metrics.AverageEpisodeLengthMetric(
                name='PhaseAverageEpisodeLength', buffer_size=np.inf),
        ]

    self._global_step = tf.compat.v1.train.get_or_create_global_step()
    # Summaries created in this scope are recorded only when the global step
    # is a multiple of the summary interval.
    # NOTE(review): block nesting below was reconstructed from a
    # whitespace-collapsed source -- confirm against the original file.
    with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
            self._global_step % self._summary_interval, 0)):
        self._env = suite_atari.load(
            env_name,
            max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP,
            gym_env_wrappers=suite_atari.
            DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
        self._env = batched_py_environment.BatchedPyEnvironment(
            [self._env])

        observation_spec = tensor_spec.from_spec(
            self._env.observation_spec())
        time_step_spec = ts.time_step_spec(observation_spec)
        action_spec = tensor_spec.from_spec(self._env.action_spec())

        with tf.device('/cpu:0'):
            # Training epsilon decays from 1.0 to `epsilon_greedy`; the decay
            # length is the frame period divided by the frame skip and by the
            # (already converted) update period.
            epsilon = tf.compat.v1.train.polynomial_decay(
                1.0,
                self._global_step,
                epsilon_decay_period / ATARI_FRAME_SKIP /
                self._update_period,
                end_learning_rate=epsilon_greedy)

        with tf.device('/gpu:0'):
            optimizer = tf.compat.v1.train.RMSPropOptimizer(
                learning_rate=learning_rate,
                decay=0.95,
                momentum=0.0,
                epsilon=0.00001,
                centered=True)
            categorical_q_net = AtariCategoricalQNetwork(
                observation_spec,
                action_spec,
                conv_layer_params=conv_layer_params,
                fc_layer_params=fc_layer_params)
            agent = categorical_dqn_agent.CategoricalDqnAgent(
                time_step_spec,
                action_spec,
                categorical_q_network=categorical_q_net,
                optimizer=optimizer,
                epsilon_greedy=epsilon,
                n_step_update=n_step_update,
                target_update_tau=target_update_tau,
                target_update_period=(target_update_period /
                                      ATARI_FRAME_SKIP /
                                      self._update_period),
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=self._global_step)

            self._collect_policy = py_tf_policy.PyTFPolicy(
                agent.collect_policy)

            if self._do_eval:
                # Evaluation wraps the agent's policy in epsilon-greedy with
                # its own (fixed) epsilon.
                self._eval_policy = py_tf_policy.PyTFPolicy(
                    epsilon_greedy_policy.EpsilonGreedyPolicy(
                        policy=agent.policy,
                        epsilon=self._eval_epsilon_greedy))

            # The replay buffer spec is built from the python environment's
            # specs, not from the tensor specs used for the agent above.
            py_observation_spec = self._env.observation_spec()
            py_time_step_spec = ts.time_step_spec(py_observation_spec)
            py_action_spec = policy_step.PolicyStep(
                self._env.action_spec())
            data_spec = trajectory.from_transition(py_time_step_spec,
                                                   py_action_spec,
                                                   py_time_step_spec)
            self._replay_buffer = py_hashed_replay_buffer.PyHashedReplayBuffer(
                data_spec=data_spec, capacity=replay_buffer_capacity)

        with tf.device('/cpu:0'):
            # Each sampled item spans n_step_update + 1 consecutive steps.
            ds = self._replay_buffer.as_dataset(
                sample_batch_size=batch_size, num_steps=n_step_update + 1)
            ds = ds.prefetch(4)
            ds = ds.apply(
                tf.data.experimental.prefetch_to_device('/gpu:0'))

        with tf.device('/gpu:0'):
            self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
            experience = self._ds_itr.get_next()
            self._train_op = agent.train(experience)

            self._env_steps_metric = py_metrics.EnvironmentSteps()
            self._step_metrics = [
                py_metrics.NumberOfEpisodes(),
                self._env_steps_metric,
            ]
            self._train_metrics = self._step_metrics + [
                py_metrics.AverageReturnMetric(buffer_size=10),
                py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
            ]
            # The _train_phase_metrics average over an entire train iteration,
            # rather than the rolling average of the last 10 episodes.
            self._train_phase_metrics = [
                py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                               buffer_size=np.inf),
                py_metrics.AverageEpisodeLengthMetric(
                    name='PhaseAverageEpisodeLength', buffer_size=np.inf),
            ]
            self._iteration_metric = py_metrics.CounterMetric(
                name='Iteration')

            # Summaries written from python should run every time they are
            # generated.
            with tf.compat.v2.summary.record_if(True):
                # The steps/sec value is fed through a placeholder into the
                # summary op created here.
                self._steps_per_second_ph = tf.compat.v1.placeholder(
                    tf.float32, shape=(), name='steps_per_sec_ph')
                self._steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=self._steps_per_second_ph,
                    step=self._global_step)

                for metric in self._train_metrics:
                    metric.tf_summaries(train_step=self._global_step,
                                        step_metrics=self._step_metrics)

                # Phase metrics use the iteration counter as their step
                # metric.
                for metric in self._train_phase_metrics:
                    metric.tf_summaries(
                        train_step=self._global_step,
                        step_metrics=(self._iteration_metric, ))
                self._iteration_metric.tf_summaries(
                    train_step=self._global_step)

                if self._do_eval:
                    # Eval metric summaries go to the eval writer instead of
                    # the default (train) writer.
                    with self._eval_summary_writer.as_default():
                        for metric in self._eval_metrics:
                            metric.tf_summaries(
                                train_step=self._global_step,
                                step_metrics=(self._iteration_metric, ))

            self._train_checkpointer = common.Checkpointer(
                ckpt_dir=train_dir,
                agent=agent,
                global_step=self._global_step,
                optimizer=optimizer,
                metrics=metric_utils.MetricsGroup(
                    self._train_metrics + self._train_phase_metrics +
                    [self._iteration_metric], 'train_metrics'))
            self._policy_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'policy'),
                policy=agent.policy,
                global_step=self._global_step)
            self._rb_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
                max_to_keep=1,
                replay_buffer=self._replay_buffer)

            self._init_agent_op = agent.initialize()
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=[],
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=[],
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):
    """Trains a fresh safety critic offline from a saved agent and replay buffer.

    Loads the agent and replay buffer checkpoints found under
    `load_root_dir/train`, builds a new 'SafetyCriticOffline' network, and
    trains it with `train_step` on batches that concatenate samples from the
    full replay buffer with samples from a buffer holding only the steps at
    or just before a failure. Summaries and safety critic checkpoints are
    written under `load_root_dir/train/offline/<UTC timestamp>`.

    Args:
      load_root_dir: Root directory of the run to load from; also the parent
        of the 'eval' and 'rollouts' output directories.
      env_load_fn: Callable invoked as `env_load_fn(env_name,
        gym_env_wrappers=...)` returning a python environment.
      gym_env_wrappers: Wrappers forwarded to `env_load_fn` and applied to
        the monitor environment. The default list is not mutated by this
        function itself.
      monitor: If True, record a rollout of the eval policy every
        `monitor_interval` steps.
      env_name: Environment name for `env_load_fn` / `gym.make`.
      agent_class: Agent class, or a string key into `ALGOS`.
      train_metrics_callback: Optional callable invoked as
        `callback(results, global_step_value)` after each policy evaluation.
      actor_fc_layers: Fully connected layer sizes of the actor network.
      critic_joint_fc_layers: Joint layer sizes of the critic network (also
        used for the loaded agent's safety critic network).
      safety_critic_joint_fc_layers: Joint layer sizes of the offline safety
        critic network.
      safety_critic_lr: Adam learning rate for the offline safety critic.
      safety_critic_bias_init_val: If not None, constant initial value for
        the safety critic's value bias.
      safety_critic_kernel_scale: If not None, scale of the truncated-normal
        variance-scaling kernel initializer; otherwise a uniform
        variance-scaling initializer with scale 1/3 is used.
      n_envs: Number of parallel eval environments; defaults to
        `num_eval_episodes` when evaluation is enabled.
      target_safety: Safety threshold for the eval policy and the first
        metric threshold. For agents in `SAFETY_AGENTS` a falsy value falls
        back to the loaded agent's `_target_safety`.
      fail_weight: If truthy, per-sample weights are `fail_weight / 0.5` for
        failure samples and `(1 - fail_weight) / 0.5` otherwise.
      num_global_steps: Number of training steps to run.
      batch_size: Total batch size; half is drawn from each buffer.
      run_eval: If True, evaluate the eval policy every `eval_interval`
        steps.
      eval_metrics: Extra python metrics wrapped in `TFPyMetric` for
        evaluation. The default list is not mutated by this function.
      num_eval_episodes: Episodes per policy evaluation.
      eval_interval: Run `eval_sc` (and policy eval) at multiples of this.
      train_checkpoint_interval: Save the safety critic at multiples of this.
      summary_interval: Summaries are recorded at multiples of this.
      monitor_interval: Record a monitored rollout at multiples of this.
      summaries_flush_secs: Flush period of all summary writers, in seconds.
      debug_summaries: Forwarded to `train_step`.
      seed: If truthy, seeds the eval environments (best effort) and the
        TensorFlow random seed.
    """
    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env,
                                                auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            # Seeding stays best-effort, but the failure is now reported and
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception as e:  # pylint: disable=broad-except
                logging.warning('Could not seed eval environments: %s', e)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    # The loaded step count is only logged; offline training counts from 0.
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed:
        tf.compat.v1.set_random_seed(seed)

    # Fixed: `summaries_flush_secs` used to be overwritten with 10 here, so
    # the caller's value was ignored for the safety critic summary writer.
    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

    # Half of every training batch comes from the full replay buffer ...
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        num_steps=2,
        sample_batch_size=batch_size // 2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    # Slices of the full buffer handed to `log_utils.eval_fn`: failure steps,
    # episode-initial steps, and their neighbours (mask rolled by one along
    # axis 1).
    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    # ... and the other half from a buffer containing only the steps at or
    # immediately before a failure.
    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask,
        [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    sc_dataset_neg = failure_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size // 2,
        num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    # The per-step `losses` list that used to be built here was never read
    # and grew without bound, so it has been removed.
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    # NOTE(review): block nesting inside this loop was reconstructed from a
    # whitespace-collapsed source -- confirm against the original file.
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            # Drop samples whose first step is an episode boundary.
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5,
                                   (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            mean_loss(train_loss)

            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(
                                name=metric.name,
                                data=metric.result(),
                                step=global_step_val)
                        else:
                            # Thresholded metrics return one value per entry
                            # of `thresholds`; write each under its own tag.
                            fmt_str = '_{}'.format(thresholds[0])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[0],
                                step=global_step_val)
                            fmt_str = '_{}'.format(thresholds[1])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[1],
                                step=global_step_val)
                        metric.reset_states()
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)