Example #1
0
    def testCounterMetricIncrements(self):
        """CounterMetric counts calls and returns to zero after reset()."""
        metric = py_metrics.CounterMetric()
        self.assertEqual(0, metric.result())

        # Each call bumps the count by one.
        for expected_count in (1, 2):
            metric()
            self.assertEqual(expected_count, metric.result())

        # reset() zeroes the counter; counting then resumes from zero.
        metric.reset()
        self.assertEqual(0, metric.result())
        metric()
        self.assertEqual(1, metric.result())
Example #2
0
def train(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        env_name=None,
        num_parallel_environments=1,  # pylint: disable=unused-argument
        agent_class=None,
        initial_collect_random=True,  # pylint: disable=unused-argument
        initial_collect_driver_class=None,
        collect_driver_class=None,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=300,
        online_critic=False,
        # Params for eval
        run_eval=False,
        num_eval_episodes=30,
        eval_interval=1000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval loop for SC-SAC (SAC with a safety critic).

    Args:
      root_dir: Output directory. Train summaries and checkpoints go under
        `root_dir/train`, eval summaries under `root_dir/eval`.
      load_root_dir: Optional directory of a previous run. When set, the agent
        is loaded with `misc.load_pi_ckpt` from `load_root_dir/train` and the
        train checkpointer in `root_dir` is NOT restored.
      env_load_fn: Callable `env_load_fn(env_name)` returning an environment;
        non-`TFPyEnvironment` results are wrapped in `TFPyEnvironment`.
      env_name: Environment name passed to `env_load_fn`.
      num_parallel_environments: Unused.
      agent_class: Callable building the agent from
        `(time_step_spec, action_spec, train_step_counter=...)`, plus
        `resample_metric=...` when `online_critic` is True.
      initial_collect_random: Unused; initial collection always uses a
        `RandomTFPolicy`.
      initial_collect_driver_class: Driver class used for initial collection
        when no replay-buffer checkpoint exists.
      collect_driver_class: Driver class used for the main collection loop.
      num_global_steps: The loop runs while the global step is
        `<= num_global_steps`.
      train_steps_per_iteration: Agent train steps per collect iteration.
      train_metrics: Optional list of py metrics. Each is wrapped in
        `tf_py_metric.TFPyMetric` and appended to the default train metrics
        (and, when `run_eval`, to the eval metrics).
      train_sc_steps: Safety-critic train steps per safety-critic update.
      train_sc_interval: Global-step interval between safety-critic updates
        (only used when `online_critic` is True).
      online_critic: If True, also collect into a small online replay buffer,
        periodically train the safety critic from it, and use the agent's
        `_safe_policy` for collection and evaluation.
      run_eval: Whether to run evaluation before and during training.
      num_eval_episodes: Episodes per evaluation; also the buffer size of the
        average-return / average-episode-length train metrics.
      eval_interval: Global-step interval between evaluations.
      eval_metrics_callback: Optional `callback(results, global_step)` called
        after each evaluation.
      train_checkpoint_interval: Global-step interval for train checkpoints.
      policy_checkpoint_interval: Global-step interval for policy and
        safety-critic checkpoints.
      rb_checkpoint_interval: Global-step interval for replay-buffer
        checkpoints.
      keep_rb_checkpoint: If False, replay-buffer checkpoints are removed with
        `misc.cleanup_checkpoints` after training.
      log_interval: Global-step interval for loss and steps/sec logging.
      summary_interval: Summaries are recorded only on global steps that are
        multiples of this value.
      summaries_flush_secs: Flush interval for the summary writers.
      early_termination_fn: Optional no-arg callable; training stops as soon
        as it returns True.
      env_metric_factories: Unused.

    Returns:
      The mean train loss of the last loop iteration, or None if the loop
      body never ran.

    Raises:
      ValueError: If the mean train loss was NaN, infinite, or above
        `MAX_LOSS` for more than `TERMINATE_AFTER_DIVERGED_LOSS_STEPS`
        consecutive iterations. The error is raised after cleanup.
    """

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    # User-supplied py metrics; wrapped as TF metrics further below.
    train_metrics = train_metrics or []

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Only record summaries on global steps that are multiples of
    # summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = env_load_fn(env_name)
        if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
            tf_env = tf_py_environment.TFPyEnvironment(tf_env)

        if run_eval:
            eval_py_env = env_load_fn(env_name)
            eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        print('obs spec:', observation_spec)
        print('action spec:', action_spec)

        if online_critic:
            # Counter handed to the agent as `resample_metric`. Its value is
            # averaged per episode and logged as 'unsafe_ac_samples'. The
            # wrapper class is `TFPyMetric`, as used elsewhere in this
            # function.
            resample_metric = tf_py_metric.TFPyMetric(
                py_metrics.CounterMetric('unsafe_ac_samples'))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step,
                                   resample_metric=resample_metric)
        else:
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.info('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, max_length=1000000)
        logging.info('RB capacity: %i', replay_buffer.capacity)
        logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            # Separate, smaller buffer that feeds the safety critic. It is
            # cleared inside critic_train_step (see below).
            online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                collect_data_spec, max_length=10000)

            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = common.function(online_replay_buffer.clear)
            agent_observers.append(online_replay_buffer.add_batch)

        # Default train metrics; the first two are used as step metrics for
        # the train summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        # With an online critic, both evaluation and collection use the
        # agent's safe policy.
        if not online_critic:
            eval_policy = tf_agent.policy
        else:
            eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        if not online_critic:
            collect_policy = tf_agent.collect_policy
        else:
            collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        safety_critic_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'safety_critic'),
            safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
            global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        # Restore the full train state only when not warm-starting from
        # load_root_dir; replay buffer and safety critic are always restored.
        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        safety_critic_checkpointer.initialize_or_restore()

        collect_driver = collect_driver_class(tf_env,
                                              collect_policy,
                                              observers=agent_observers +
                                              train_metrics)

        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Seed the replay buffer with random-policy data unless a replay
        # buffer checkpoint already exists.
        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            common.function(
                initial_collect_driver_class(tf_env,
                                             initial_collect_policy,
                                             observers=agent_observers +
                                             train_metrics).run)()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            tf.print(
                replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
                output_stream=logging.info)

        # Baseline evaluation before any training in this run.
        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)
            if FLAGS.viz_pm:
                eval_fig_dir = os.path.join(eval_dir, 'figs')
                if not tf.io.gfile.isdir(eval_fig_dir):
                    tf.io.gfile.makedirs(eval_fig_dir)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3, num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                experience, buf_info = next(online_iterator)
                # These environments expose the safety signal directly in the
                # observation dict; others derive it from the online buffer.
                if env_name in [
                        'IndianWell', 'IndianWell2', 'IndianWell3',
                        'DrunkSpider', 'DrunkSpiderShort'
                ]:
                    safe_rew = experience.observation['task_agn_rew']
                else:
                    safe_rew = agents.process_replay_buffer(
                        online_replay_buffer, as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience, safe_rew)
                # NOTE(review): the online buffer is cleared after every
                # critic step, so with train_sc_steps > 1 later steps may
                # sample from an (almost) empty buffer -- confirm intended.
                clear_rb()
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0
        # Mean train loss of the latest iteration. It stays None if the loop
        # body never runs, so the final return cannot hit an unbound name.
        total_loss = None
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
        if online_critic:
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_samples')
            resample_metric.reset()

        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            if online_critic:
                mean_resample_ac(resample_metric.result())
                resample_metric.reset()
                # NOTE(review): using is_last() as a Python bool assumes a
                # single environment (batch size 1) -- confirm.
                if time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()
                    tf.compat.v2.summary.scalar(name='unsafe_ac_samples',
                                                data=resample_ac_freq,
                                                step=global_step)

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)

            if online_critic:
                if global_step.numpy() % train_sc_interval == 0:
                    for _ in range(train_sc_steps):
                        sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    break
            else:
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                safety_critic_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)

            if run_eval and global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
                if FLAGS.viz_pm:
                    savepath = 'step{}.png'.format(global_step_val)
                    savepath = os.path.join(eval_fig_dir, savepath)
                    misc.record_episode_vis_summary(eval_tf_env, eval_policy,
                                                    savepath)

    if not keep_rb_checkpoint:
        misc.cleanup_checkpoints(rb_ckpt_dir)

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))

    return total_loss
Example #3
0
    def __init__(
            self,
            root_dir,
            env_name,
            num_iterations=200,
            max_episode_frames=108000,  # ALE frames
            terminal_on_life_loss=False,
            conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3),
                                                                  1)),
            fc_layer_params=(512, ),
            # Params for collect
            initial_collect_steps=80000,  # ALE frames
            epsilon_greedy=0.01,
            epsilon_decay_period=1000000,  # ALE frames
            replay_buffer_capacity=1000000,
            # Params for train
            train_steps_per_iteration=1000000,  # ALE frames
            update_period=16,  # ALE frames
            target_update_tau=1.0,
            target_update_period=32000,  # ALE frames
            batch_size=32,
            learning_rate=2.5e-4,
            n_step_update=2,
            gamma=0.99,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for eval
            do_eval=True,
            eval_steps_per_iteration=500000,  # ALE frames
            eval_epsilon_greedy=0.001,
            # Params for checkpoints, summaries, and logging
            log_interval=1000,
            summary_interval=1000,
            summaries_flush_secs=10,
            debug_summaries=True,
            summarize_grads_and_vars=True,
            eval_metrics_callback=None):
        """A simple Atari train and eval for (categorical) DQN.

        Builds the environment, a CategoricalDqnAgent, the replay buffer and
        its input pipeline, metrics, summaries and checkpointers. No training
        happens here.

        Args:
          root_dir: Directory to write log files to.
          env_name: Fully-qualified name of the Atari environment
            (e.g. Pong-v0).
          num_iterations: Number of train/eval iterations to run.
          max_episode_frames: Maximum length of a single episode, in ALE
            frames.
          terminal_on_life_loss: Whether to simulate an episode termination
            when a life is lost.
          conv_layer_params: Params for the convolutional layers of the
            Q-network.
          fc_layer_params: Params for the fully connected layers of the
            Q-network.
          initial_collect_steps: Number of ALE frames to process before
            beginning to train. Since this is in ALE frames, there will be
            initial_collect_steps/4 items in the replay buffer when training
            starts.
          epsilon_greedy: Final epsilon value to decay to for training.
          epsilon_decay_period: Period over which to decay epsilon, from 1.0
            to epsilon_greedy (defined above), in ALE frames.
          replay_buffer_capacity: Maximum number of items to store in the
            replay buffer.
          train_steps_per_iteration: Number of ALE frames to run through for
            each iteration of training.
          update_period: Run a train operation every update_period ALE frames.
          target_update_tau: Coefficient for soft target network updates
            (1.0 == hard updates).
          target_update_period: Period, in ALE frames, to copy the live
            network to the target network.
          batch_size: Number of experience sequences sampled per training
            batch; each sequence has n_step_update + 1 consecutive steps.
          learning_rate: RMSProp optimizer learning rate.
          n_step_update: The number of steps to consider when computing TD
            error and TD loss. Applies standard single-step updates when set
            to 1.
          gamma: Discount for future rewards.
          reward_scale_factor: Scaling factor for rewards.
          gradient_clipping: Norm length to clip gradients.
          do_eval: If True, run an eval every iteration. If False, skip eval.
          eval_steps_per_iteration: Number of ALE frames to run through for
            each iteration of evaluation.
          eval_epsilon_greedy: Epsilon value to use for the evaluation policy
            (0 == totally greedy policy).
          log_interval: Log stats to the terminal every log_interval training
            steps.
          summary_interval: Write TF summaries every summary_interval training
            steps.
          summaries_flush_secs: Flush summaries to disk every
            summaries_flush_secs seconds.
          debug_summaries: If True, write additional summaries for debugging
            (see dqn_agent for which summaries are written).
          summarize_grads_and_vars: Include gradients and variables in
            summaries.
          eval_metrics_callback: A callback function that takes (metric_dict,
            global_step) as parameters. Called after every eval with the
            results of the evaluation.
        """
        # Convert ALE-frame arguments to environment steps by dividing by
        # ATARI_FRAME_SKIP (defined elsewhere in this file). `/` is true
        # division, so these attributes are floats.
        self._update_period = update_period / ATARI_FRAME_SKIP
        self._train_steps_per_iteration = (train_steps_per_iteration /
                                           ATARI_FRAME_SKIP)
        self._do_eval = do_eval
        self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP
        self._eval_epsilon_greedy = eval_epsilon_greedy
        self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP
        self._summary_interval = summary_interval
        self._num_iterations = num_iterations
        self._log_interval = log_interval
        self._eval_metrics_callback = eval_metrics_callback

        # Bind terminal_on_life_loss for AtariPreprocessing through gin;
        # unlock_config() lets the binding happen even if the config is
        # locked.
        with gin.unlock_config():
            gin.bind_parameter(('tf_agents.environments.atari_preprocessing.'
                                'AtariPreprocessing.terminal_on_life_loss'),
                               terminal_on_life_loss)

        root_dir = os.path.expanduser(root_dir)
        train_dir = os.path.join(root_dir, 'train')
        eval_dir = os.path.join(root_dir, 'eval')

        train_summary_writer = tf.compat.v2.summary.create_file_writer(
            train_dir, flush_millis=summaries_flush_secs * 1000)
        train_summary_writer.set_as_default()
        self._train_summary_writer = train_summary_writer

        # Eval writer/metrics exist only when do_eval is True;
        # self._eval_metrics is not defined otherwise.
        self._eval_summary_writer = None
        if self._do_eval:
            self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
                eval_dir, flush_millis=summaries_flush_secs * 1000)
            self._eval_metrics = [
                py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                               buffer_size=np.inf),
                py_metrics.AverageEpisodeLengthMetric(
                    name='PhaseAverageEpisodeLength', buffer_size=np.inf),
            ]

        self._global_step = tf.compat.v1.train.get_or_create_global_step()
        # Graph summaries built in this scope record only on global steps
        # that are multiples of summary_interval.
        with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
                self._global_step % self._summary_interval, 0)):
            # Single Python Atari env (frame-stacking wrappers) in a batch of 1.
            self._env = suite_atari.load(
                env_name,
                max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP,
                gym_env_wrappers=suite_atari.
                DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
            self._env = batched_py_environment.BatchedPyEnvironment(
                [self._env])

            observation_spec = tensor_spec.from_spec(
                self._env.observation_spec())
            time_step_spec = ts.time_step_spec(observation_spec)
            action_spec = tensor_spec.from_spec(self._env.action_spec())

            # Collect epsilon decays from 1.0 to epsilon_greedy (linear with
            # polynomial_decay's default power). The global step counts train
            # steps, so decay_steps is converted from ALE frames to env steps
            # and then divided by the env steps per train step.
            with tf.device('/cpu:0'):
                epsilon = tf.compat.v1.train.polynomial_decay(
                    1.0,
                    self._global_step,
                    epsilon_decay_period / ATARI_FRAME_SKIP /
                    self._update_period,
                    end_learning_rate=epsilon_greedy)

            with tf.device('/gpu:0'):
                optimizer = tf.compat.v1.train.RMSPropOptimizer(
                    learning_rate=learning_rate,
                    decay=0.95,
                    momentum=0.0,
                    epsilon=0.00001,
                    centered=True)
                categorical_q_net = AtariCategoricalQNetwork(
                    observation_spec,
                    action_spec,
                    conv_layer_params=conv_layer_params,
                    fc_layer_params=fc_layer_params)
                # target_update_period is converted the same way as the
                # epsilon decay period: ALE frames -> train steps.
                agent = categorical_dqn_agent.CategoricalDqnAgent(
                    time_step_spec,
                    action_spec,
                    categorical_q_network=categorical_q_net,
                    optimizer=optimizer,
                    epsilon_greedy=epsilon,
                    n_step_update=n_step_update,
                    target_update_tau=target_update_tau,
                    target_update_period=(target_update_period /
                                          ATARI_FRAME_SKIP /
                                          self._update_period),
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=self._global_step)

                # Python-callable wrappers around the agent's TF policies.
                self._collect_policy = py_tf_policy.PyTFPolicy(
                    agent.collect_policy)

                if self._do_eval:
                    self._eval_policy = py_tf_policy.PyTFPolicy(
                        epsilon_greedy_policy.EpsilonGreedyPolicy(
                            policy=agent.policy,
                            epsilon=self._eval_epsilon_greedy))

                # The replay buffer stores Python-side trajectories, so its
                # data spec is built from the py env specs.
                py_observation_spec = self._env.observation_spec()
                py_time_step_spec = ts.time_step_spec(py_observation_spec)
                py_action_spec = policy_step.PolicyStep(
                    self._env.action_spec())
                data_spec = trajectory.from_transition(py_time_step_spec,
                                                       py_action_spec,
                                                       py_time_step_spec)
                self._replay_buffer = py_hashed_replay_buffer.PyHashedReplayBuffer(
                    data_spec=data_spec, capacity=replay_buffer_capacity)

            # Sample batches of n_step_update + 1 consecutive steps (enough
            # for an n-step TD target) on CPU, then prefetch them to the GPU.
            with tf.device('/cpu:0'):
                ds = self._replay_buffer.as_dataset(
                    sample_batch_size=batch_size, num_steps=n_step_update + 1)
                ds = ds.prefetch(4)
                ds = ds.apply(
                    tf.data.experimental.prefetch_to_device('/gpu:0'))

            with tf.device('/gpu:0'):
                # TF1-style graph: a one-shot iterator feeds a single train op
                # that is built here and run elsewhere.
                self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
                experience = self._ds_itr.get_next()
                self._train_op = agent.train(experience)

                self._env_steps_metric = py_metrics.EnvironmentSteps()
                self._step_metrics = [
                    py_metrics.NumberOfEpisodes(),
                    self._env_steps_metric,
                ]
                self._train_metrics = self._step_metrics + [
                    py_metrics.AverageReturnMetric(buffer_size=10),
                    py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
                ]
                # The _train_phase_metrics average over an entire train iteration,
                # rather than the rolling average of the last 10 episodes.
                self._train_phase_metrics = [
                    py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                                   buffer_size=np.inf),
                    py_metrics.AverageEpisodeLengthMetric(
                        name='PhaseAverageEpisodeLength', buffer_size=np.inf),
                ]
                self._iteration_metric = py_metrics.CounterMetric(
                    name='Iteration')

                # Summaries written from python should run every time they are
                # generated.
                with tf.compat.v2.summary.record_if(True):
                    # Scalar placeholder that the caller feeds with a
                    # measured steps/sec value when running this summary op.
                    self._steps_per_second_ph = tf.compat.v1.placeholder(
                        tf.float32, shape=(), name='steps_per_sec_ph')
                    self._steps_per_second_summary = tf.compat.v2.summary.scalar(
                        name='global_steps_per_sec',
                        data=self._steps_per_second_ph,
                        step=self._global_step)

                    for metric in self._train_metrics:
                        metric.tf_summaries(train_step=self._global_step,
                                            step_metrics=self._step_metrics)

                    # Per-iteration metrics are plotted against the iteration
                    # counter rather than env steps / episodes.
                    for metric in self._train_phase_metrics:
                        metric.tf_summaries(
                            train_step=self._global_step,
                            step_metrics=(self._iteration_metric, ))
                    self._iteration_metric.tf_summaries(
                        train_step=self._global_step)

                    if self._do_eval:
                        with self._eval_summary_writer.as_default():
                            for metric in self._eval_metrics:
                                metric.tf_summaries(
                                    train_step=self._global_step,
                                    step_metrics=(self._iteration_metric, ))

                self._train_checkpointer = common.Checkpointer(
                    ckpt_dir=train_dir,
                    agent=agent,
                    global_step=self._global_step,
                    optimizer=optimizer,
                    metrics=metric_utils.MetricsGroup(
                        self._train_metrics + self._train_phase_metrics +
                        [self._iteration_metric], 'train_metrics'))
                self._policy_checkpointer = common.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'policy'),
                    policy=agent.policy,
                    global_step=self._global_step)
                self._rb_checkpointer = common.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
                    max_to_keep=1,
                    replay_buffer=self._replay_buffer)

                # Agent initialization op; built here, run by the caller.
                self._init_agent_op = agent.initialize()
Example #4
0
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=None,
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=None,
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):
    """Trains a fresh safety critic offline on data from a previously saved run.

    Restores an agent checkpoint from `<load_root_dir>/train` and its replay
    buffer from `<load_root_dir>/train/replay_buffer`. It then builds a
    secondary buffer holding only failure steps and the steps right before
    them. A new safety critic network is trained on batches that mix samples
    from the full buffer with samples from that failure buffer. The global
    step is reset to 0 after loading. Summaries, the gin config (via
    `GinConfigSaverHook`) and safety-critic checkpoints are written under
    `<load_root_dir>/train/offline/<UTC timestamp>`.

    Args:
      load_root_dir: Root directory of the run to load from.
      env_load_fn: Callable invoked as
        `env_load_fn(env_name, gym_env_wrappers=...)` that returns a py
        environment.
      gym_env_wrappers: Sequence of gym wrappers passed to `env_load_fn` and
        applied to the monitor env. Defaults to no wrappers.
      monitor: If True, every `monitor_interval` steps roll out `eval_policy`
        once in a gym env wrapped with `misc.monitor_freq(1, ...)`, using
        `<load_root_dir>/rollouts` as its path.
      env_name: Environment name for `env_load_fn` (and `gym.make` when
        monitoring).
      agent_class: Agent class, or a string key into `ALGOS`.
      train_metrics_callback: Optional callable `(results, step)` invoked with
        the eval results after each evaluation (only when `run_eval`).
      actor_fc_layers: Hidden layer sizes of the actor network.
      critic_joint_fc_layers: Joint hidden layer sizes of the critic network
        (also used for the loaded agent's safety critic).
      safety_critic_joint_fc_layers: Joint hidden layer sizes of the new
        offline safety critic.
      safety_critic_lr: Adam learning rate for the offline safety critic.
      safety_critic_bias_init_val: If not None, constant initial value for the
        safety critic's value-layer bias.
      safety_critic_kernel_scale: If not None, scale of a fan-in,
        truncated-normal variance-scaling kernel initializer. Otherwise a
        fan-in uniform initializer with scale 1/3 is used.
      n_envs: Number of parallel eval environments. Defaults to
        `num_eval_episodes`.
      target_safety: Safety threshold passed to the eval policy and
        `train_step`, and used as one of the TP/FP/TN/FN metric thresholds.
        For safety agents, a falsy value is replaced by the loaded agent's
        `_target_safety`.
      fail_weight: If set, per-sample weights passed to `train_step` are
        `fail_weight / 0.5` for failure samples and
        `(1 - fail_weight) / 0.5` otherwise.
      num_global_steps: Number of offline training steps.
      batch_size: Total sampled batch size (half from the full buffer, half
        from the failure buffer) before boundary pairs are filtered out.
      run_eval: If True, evaluate `eval_policy` every `eval_interval` steps.
      eval_metrics: Extra py metrics, wrapped in `TFPyMetric`, for
        evaluation.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Steps between safety-critic evaluations (and policy
        evaluations when `run_eval`).
      train_checkpoint_interval: Steps between safety-critic checkpoints.
      summary_interval: Steps between recorded summaries.
      monitor_interval: Steps between monitored rollouts.
      summaries_flush_secs: Flush interval of the summary writers.
      debug_summaries: Forwarded to `train_step`.
      seed: Optional random seed for TF and the eval environments.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    if gym_env_wrappers is None:
        gym_env_wrappers = []
    if eval_metrics is None:
        eval_metrics = []

    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        # The first two entries (return / episode length) are used as step
        # metrics when writing the remaining eval summaries below.
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed is not None:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception as e:  # pylint: disable=broad-except
                # Seeding is best-effort: a failure is logged, not raised.
                logging.warning('Could not seed eval environments: %s', e)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        # NOTE(review): this network exists only so the saved agent checkpoint
        # can be restored; its layers must match the trained run's safety
        # critic, which this code assumes used `critic_joint_fc_layers`.
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    # Offline safety-critic training counts its own steps from zero.
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    # Thresholded metrics below report one value per entry of `thresholds`.
    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed is not None:
        tf.compat.v1.set_random_seed(seed)

    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

    # Half of each training batch: 2-step samples from the whole buffer.
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       num_steps=2,
                                       sample_batch_size=batch_size //
                                       2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    # Masks along the time axis (axis 1): failure steps, first steps, and
    # their shifted neighbours (the step before a failure / after a reset).
    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    # Keep failure steps and the step right before each one, padded to the
    # replay buffer's full capacity for copy_rb.
    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    # Other half of each training batch: 2-step samples around failures.
    sc_dataset_neg = failure_buffer.as_dataset(num_parallel_calls=3,
                                               sample_batch_size=batch_size //
                                               2,
                                               num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    losses = []
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    # NOTE(review): this record_if also gates the eval summaries written in
    # the loop, so they are only recorded when eval_interval steps coincide
    # with multiples of summary_interval.
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            # Drop 2-step samples whose first step is an episode boundary.
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            # Label: failure indicator of the second step of each sample.
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5, (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            losses.append(
                (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        result = metric.result()
                        if len(tf.squeeze(result).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=result,
                                                        step=global_step_val)
                        else:
                            # One summary per threshold, suffixed with it.
                            for i, threshold in enumerate(thresholds):
                                tf.compat.v2.summary.scalar(
                                    name='{}_{}'.format(metric.name,
                                                        threshold),
                                    data=result[i],
                                    step=global_step_val)
                        metric.reset_states()
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)