Example 1
    def testReset(self):
        batched_avg_return_metric = batched_py_metric.BatchedPyMetric(
            py_metrics.AverageReturnMetric)
        tf_avg_return_metric = tf_py_metric.TFPyMetric(
            batched_avg_return_metric)

        deps = []
        # run one episode
        for i in range(3):
            with tf.control_dependencies(deps):
                traj = tf_avg_return_metric(self._ts[i])
                deps = tf.nest.flatten(traj)

        # reset
        with tf.control_dependencies(deps):
            reset_op = tf_avg_return_metric.reset()
            deps = [reset_op]

        # run second episode
        for i in range(3, 6):
            with tf.control_dependencies(deps):
                traj = tf_avg_return_metric(self._ts[i])
                deps = tf.nest.flatten(traj)

        # Test result is the reward for the second episode.
        with tf.control_dependencies(deps):
            result = tf_avg_return_metric.result()

        result_ = self.evaluate(result)
        self.assertEqual(result_, 13)
Example 2
    def testMetricPrefix(self):
        batched_avg_return_metric = batched_py_metric.BatchedPyMetric(
            py_metrics.AverageReturnMetric, prefix='CustomPrefix')
        self.assertEqual(batched_avg_return_metric.prefix, 'CustomPrefix')

        tf_avg_return_metric = tf_py_metric.TFPyMetric(
            batched_avg_return_metric)
        self.assertEqual(tf_avg_return_metric._prefix, 'CustomPrefix')
Example 3
    def _build_metrics(self, buffer_size=10, batch_size=None):
        python_metrics = [
            tf_py_metric.TFPyMetric(
                py_metrics.AverageReturnMetric(buffer_size=buffer_size,
                                               batch_size=batch_size)),
            tf_py_metric.TFPyMetric(
                py_metrics.AverageEpisodeLengthMetric(buffer_size=buffer_size,
                                                      batch_size=batch_size)),
        ]
        if batch_size is None:
            batch_size = 1
        tensorflow_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=buffer_size,
                                           batch_size=batch_size),
            tf_metrics.AverageEpisodeLengthMetric(buffer_size=buffer_size,
                                                  batch_size=batch_size),
        ]

        return python_metrics, tensorflow_metrics
Example 4
 def testMetricIsComputedCorrectly(self, num_time_steps, expected_reward):
   batched_avg_return_metric = batched_py_metric.BatchedPyMetric(
       py_metrics.AverageReturnMetric)
   tf_avg_return_metric = tf_py_metric.TFPyMetric(batched_avg_return_metric)
   deps = []
   for i in range(num_time_steps):
     with tf.control_dependencies(deps):
       traj = tf_avg_return_metric(self._ts[i])
       deps = nest.flatten(traj)
   with tf.control_dependencies(deps):
     result = tf_avg_return_metric.result()
   result_ = self.evaluate(result)
   self.assertEqual(result_, expected_reward)
Example 5
def train(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        env_name=None,
        num_parallel_environments=1,  # pylint: disable=unused-argument
        agent_class=None,
        initial_collect_random=True,  # pylint: disable=unused-argument
        initial_collect_driver_class=None,
        collect_driver_class=None,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=300,
        online_critic=False,
        # Params for eval
        run_eval=False,
        num_eval_episodes=30,
        eval_interval=1000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval for SC-SAC."""

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    train_metrics = train_metrics or []

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = env_load_fn(env_name)
        if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
            tf_env = tf_py_environment.TFPyEnvironment(tf_env)

        if run_eval:
            eval_py_env = env_load_fn(env_name)
            eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        print('obs spec:', observation_spec)
        print('action spec:', action_spec)

        if online_critic:
            resample_metric = tf_py_metric.TFPyMetric(
                py_metrics.CounterMetric('unsafe_ac_samples'))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step,
                                   resample_metric=resample_metric)
        else:
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.info('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, max_length=1000000)
        logging.info('RB capacity: %i', replay_buffer.capacity)
        logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                collect_data_spec, max_length=10000)

            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = common.function(online_replay_buffer.clear)
            agent_observers.append(online_replay_buffer.add_batch)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        if not online_critic:
            eval_policy = tf_agent.policy
        else:
            eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        if not online_critic:
            collect_policy = tf_agent.collect_policy
        else:
            collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        safety_critic_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'safety_critic'),
            safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
            global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        safety_critic_checkpointer.initialize_or_restore()

        collect_driver = collect_driver_class(tf_env,
                                              collect_policy,
                                              observers=agent_observers +
                                              train_metrics)

        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            common.function(
                initial_collect_driver_class(tf_env,
                                             initial_collect_policy,
                                             observers=agent_observers +
                                             train_metrics).run)()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            tf.print(
                replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
                output_stream=logging.info)

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)
            if FLAGS.viz_pm:
                eval_fig_dir = osp.join(eval_dir, 'figs')
                if not tf.io.gfile.isdir(eval_fig_dir):
                    tf.io.gfile.makedirs(eval_fig_dir)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3, num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                experience, buf_info = next(online_iterator)
                if env_name in [
                        'IndianWell', 'IndianWell2', 'IndianWell3',
                        'DrunkSpider', 'DrunkSpiderShort'
                ]:
                    safe_rew = experience.observation['task_agn_rew']
                else:
                    safe_rew = agents.process_replay_buffer(
                        online_replay_buffer, as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience, safe_rew)
                clear_rb()
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
        if online_critic:
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_samples')
            resample_metric.reset()

        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            if online_critic:
                mean_resample_ac(resample_metric.result())
                resample_metric.reset()
                if time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()
                    tf.compat.v2.summary.scalar(name='unsafe_ac_samples',
                                                data=resample_ac_freq,
                                                step=global_step)

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)

            if online_critic:
                if global_step.numpy() % train_sc_interval == 0:
                    for _ in range(train_sc_steps):
                        sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    break
            else:
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                safety_critic_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)

            if run_eval and global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
                if FLAGS.viz_pm:
                    savepath = 'step{}.png'.format(global_step_val)
                    savepath = osp.join(eval_fig_dir, savepath)
                    misc.record_episode_vis_summary(eval_tf_env, eval_policy,
                                                    savepath)

    if not keep_rb_checkpoint:
        misc.cleanup_checkpoints(rb_ckpt_dir)

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))

    return total_loss
Example 6
    def __init__(self,
                 root_dir,
                 env_load_fn=suite_gym.load,
                 env_name='CartPole-v0',
                 num_parallel_environments=1,
                 agent_class=None,
                 num_eval_episodes=30,
                 write_summaries=True,
                 summaries_flush_secs=10,
                 eval_metrics_callback=None,
                 env_metric_factories=None):
        """Evaluate policy checkpoints as they are produced.

    Args:
      root_dir: Main directory for experiment files.
      env_load_fn: Function to load the environment specified by env_name.
      env_name: Name of environment to evaluate in.
      num_parallel_environments: Number of environments to evaluate on in
        parallel.
      agent_class: TFAgent class to instantiate for evaluation.
      num_eval_episodes: Number of episodes to average evaluation over.
      write_summaries: Whether to write summaries to the file system.
      summaries_flush_secs: How frequently to flush summaries (in seconds).
      eval_metrics_callback: A function that will be called with evaluation
        results for every checkpoint.
      env_metric_factories: An iterable of metric factories. Use this for eval
        metrics that need access to the evaluated environment. A metric
        factory is a function that takes an environment and buffer_size as
        keyword arguments and returns an instance of py_metric.

    Raises:
      ValueError: when num_parallel_environments > num_eval_episodes or
        agent_class is not set
    """
        if not agent_class:
            raise ValueError(
                'The `agent_class` parameter of Evaluator must be set.')
        if num_parallel_environments > num_eval_episodes:
            raise ValueError(
                'num_parallel_environments should not be greater than '
                'num_eval_episodes')

        self._num_eval_episodes = num_eval_episodes
        self._eval_metrics_callback = eval_metrics_callback
        # Flag that controls eval cycle. If set, evaluation will exit eval loop
        # before the max checkpoint number is reached.
        self._terminate_early = False

        # Save root dir to self so derived classes have access to it.
        self._root_dir = os.path.expanduser(root_dir)
        train_dir = os.path.join(self._root_dir, 'train')
        self._eval_dir = os.path.join(self._root_dir, 'eval')

        self._global_step = tf.compat.v1.train.get_or_create_global_step()

        self._env_name = env_name
        if num_parallel_environments == 1:
            eval_env = env_load_fn(env_name)
        else:
            eval_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)

        if isinstance(eval_env, py_environment.PyEnvironment):
            self._eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)
            self._eval_py_env = eval_env
        else:
            self._eval_tf_env = eval_env
            self._eval_py_env = None  # Can't generically convert to PyEnvironment.

        self._eval_metrics = [
            tf_metrics.AverageReturnMetric(
                buffer_size=self._num_eval_episodes,
                batch_size=self._eval_tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=self._num_eval_episodes,
                batch_size=self._eval_tf_env.batch_size),
        ]
        if env_metric_factories:
            if not self._eval_py_env:
                raise ValueError(
                    'The `env_metric_factories` parameter of Evaluator '
                    'can only be used with a PyEnvironment environment.')
            for metric_factory in env_metric_factories:
                py_metric = metric_factory(environment=self._eval_py_env,
                                           buffer_size=self._num_eval_episodes)
                self._eval_metrics.append(tf_py_metric.TFPyMetric(py_metric))

        if write_summaries:
            self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
                self._eval_dir, flush_millis=summaries_flush_secs * 1000)
            self._eval_summary_writer.set_as_default()
        else:
            self._eval_summary_writer = None

        environment_specs.set_observation_spec(
            self._eval_tf_env.observation_spec())
        environment_specs.set_action_spec(self._eval_tf_env.action_spec())

        # Agent params configured with gin.
        self._agent = agent_class(self._eval_tf_env.time_step_spec(),
                                  self._eval_tf_env.action_spec())

        self._eval_policy = greedy_policy.GreedyPolicy(self._agent.policy)
        self._eval_policy.action = common.function(self._eval_policy.action)

        # Run the agent on dummy data to force instantiation of the network. Keras
        # doesn't create variables until you first use the layer. This is needed
        # for checkpoint restoration to work.
        dummy_obs = tensor_spec.sample_spec_nest(
            self._eval_tf_env.observation_spec(),
            outer_dims=(self._eval_tf_env.batch_size, ))
        self._eval_policy.action(
            ts.restart(dummy_obs, batch_size=self._eval_tf_env.batch_size),
            self._eval_policy.get_initial_state(self._eval_tf_env.batch_size))

        self._policy_checkpoint = tf.train.Checkpoint(
            policy=self._agent.policy, global_step=self._global_step)
        self._policy_checkpoint_dir = os.path.join(train_dir, 'policy')
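
The `env_metric_factories` argument documented in this constructor expects callables that take `environment` and `buffer_size` as keyword arguments and return a py_metric instance. A minimal sketch of such a factory (it ignores the environment, whereas a real factory would typically read task-specific information from it; `SomeAgentClass` below is hypothetical):

from tf_agents.metrics import py_metrics

def average_return_metric_factory(environment=None, buffer_size=10):
    # Sketch only: a real env metric would usually inspect `environment`.
    del environment  # unused in this minimal example
    return py_metrics.AverageReturnMetric(buffer_size=buffer_size)

# It could then be passed to the evaluator roughly as:
#   Evaluator(root_dir, agent_class=SomeAgentClass,  # SomeAgentClass is hypothetical
#             env_metric_factories=[average_return_metric_factory])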
Example 7
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_py_env = suite_mujoco.load(env_name)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(
                    batch_size * 5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)

            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    total_loss, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_flush_op)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
Example 8
def train_eval(

        ##############################################
        # types of params:
        # 0: specific to algorithm (gin file 0)
        # 1: specific to environment (gin file 1)
        # 2: specific to experiment (gin file 2 + command line)

        # Note: there are other important params
        # in eg ModelDistributionNetwork that the gin files specify
        # like sparse vs dense rewards, latent dimensions, etc.
        ##############################################

        # basic params for running/logging experiment
        root_dir,  # 2
        experiment_name,  # 2
        num_iterations=int(1e7),  # 2
        seed=1,  # 2
        gpu_allow_growth=False,  # 2
        gpu_memory_limit=None,  # 2
        verbose=True,  # 2
        policy_checkpoint_freq_in_iter=100,  # policies needed for future eval                             # 2
        train_checkpoint_freq_in_iter=0,  #default don't save                                              # 2
        rb_checkpoint_freq_in_iter=0,  #default don't save                                                 # 2
        logging_freq_in_iter=10,  # printing to terminal                                                   # 2
        summary_freq_in_iter=10,  # saving to tb                                                           # 2
        num_images_per_summary=2,  # 2
        summaries_flush_secs=10,  # 2
        max_episode_len_override=None,  # 2
        num_trials_to_render=1,  # 2

        # environment, action mode, etc.
        env_name='HalfCheetah-v2',  # 1
        action_repeat=1,  # 1
        action_mode='joint_position',  # joint_position or joint_delta_position                           # 1
        double_camera=False,  # camera input                                                               # 1
        universe='gym',  # default
        task_reward_dim=1,  # default

        # dims for all networks
        actor_fc_layers=(256, 256),  # 1
        critic_obs_fc_layers=None,  # 1
        critic_action_fc_layers=None,  # 1
        critic_joint_fc_layers=(256, 256),  # 1
        num_repeat_when_concatenate=None,  # 1

        # networks
        critic_input='state',  # 0
        actor_input='state',  # 0

        # specifying tasks and eval
        episodes_per_trial=1,  # 2
        num_train_tasks=10,  # 2
        num_eval_tasks=10,  # 2
        num_eval_trials=10,  # 2
        eval_interval=10,  # 2
        eval_on_holdout_tasks=True,  # 2

        # data collection/buffer
        init_collect_trials_per_task=None,  # 2
        collect_trials_per_task=None,  # 2
        num_tasks_to_collect_per_iter=5,  # 2
        replay_buffer_capacity=int(1e5),  # 2

        # training
        init_model_train_ratio=0.8,  # 2
        model_train_ratio=1,  # 2
        model_train_freq=1,  # 2
        ac_train_ratio=1,  # 2
        ac_train_freq=1,  # 2
        num_tasks_per_train=5,  # 2
        train_trials_per_task=5,  # 2
        model_bs_in_steps=256,  # 2
        ac_bs_in_steps=128,  # 2

        # default AC learning rates, gamma, etc.
        target_update_tau=0.005,
        target_update_period=1,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        model_learning_rate=1e-4,
        td_errors_loss_fn=functools.partial(
            tf.compat.v1.losses.mean_squared_error, weights=0.5),
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        log_image_strips=False,
        stop_model_training=1E10,
        eval_only=False,  # evaluate checkpoints ONLY
        log_image_observations=False,
        load_offline_data=False,  # whether to use offline data
        offline_data_dir=None,  # replay buffer's dir
        offline_episode_len=None,  # episode len of episodes stored in rb
        offline_ratio=0,  # ratio of data that is from offline buffer
):
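    # The "gin file 0/1/2" tags in the comment block and signature above refer
    # to gin-config bindings. A hypothetical experiment-level (type 2) gin
    # file might contain lines such as:
    #   train_eval.num_iterations = 2000000
    #   train_eval.seed = 3
    #   train_eval.eval_interval = 20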

    g = tf.Graph()

    # register all gym envs
    max_steps_dict = {
        "HalfCheetahVel-v0": 50,
        "SawyerReach-v0": 40,
        "SawyerReachMT-v0": 40,
        "SawyerPeg-v0": 40,
        "SawyerPegMT-v0": 40,
        "SawyerPegMT4box-v0": 40,
        "SawyerShelfMT-v0": 40,
        "SawyerKitchenMT-v0": 40,
        "SawyerShelfMT-v2": 40,
        "SawyerButtons-v0": 40,
    }
    if max_episode_len_override:
        max_steps_dict[env_name] = max_episode_len_override
    register_all_gym_envs(max_steps_dict)

    # set max_episode_len based on our env
    max_episode_len = max_steps_dict[env_name]

    ######################################################
    # Calculate additional params
    ######################################################

    # convert to number of steps
    env_steps_per_trial = episodes_per_trial * max_episode_len
    real_env_steps_per_trial = episodes_per_trial * (max_episode_len + 1)
    env_steps_per_iter = num_tasks_to_collect_per_iter * collect_trials_per_task * env_steps_per_trial
    per_task_collect_steps = collect_trials_per_task * env_steps_per_trial

    # initial collect + train
    init_collect_env_steps = num_train_tasks * init_collect_trials_per_task * env_steps_per_trial
    init_model_train_steps = int(init_collect_env_steps *
                                 init_model_train_ratio)

    # collect + train
    collect_env_steps_per_iter = num_tasks_to_collect_per_iter * per_task_collect_steps
    model_train_steps_per_iter = int(env_steps_per_iter * model_train_ratio)
    ac_train_steps_per_iter = int(env_steps_per_iter * ac_train_ratio)

    # other
    global_steps_per_iter = collect_env_steps_per_iter + model_train_steps_per_iter + ac_train_steps_per_iter
    sample_episodes_per_task = train_trials_per_task * episodes_per_trial  # number of episodes to sample from each replay
    model_bs_in_trials = model_bs_in_steps // real_env_steps_per_trial

    # assertions that make sure parameters make sense
    assert model_bs_in_trials > 0, "model batch size needs to be at least as big as one full real trial"
    assert num_tasks_to_collect_per_iter <= num_train_tasks, "when sampling replace=False"
    assert num_tasks_per_train * train_trials_per_task >= model_bs_in_trials, "not enough data for one batch model train"
    assert num_tasks_per_train * train_trials_per_task * env_steps_per_trial >= ac_bs_in_steps, "not enough data for one batch ac train"
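
    # Worked example of the bookkeeping above (hypothetical values, for
    # illustration only): with episodes_per_trial=2, max_episode_len=40,
    # num_tasks_to_collect_per_iter=5, collect_trials_per_task=2,
    # model_train_ratio=1, ac_train_ratio=1 and model_bs_in_steps=256:
    #   env_steps_per_trial        = 2 * 40          = 80
    #   real_env_steps_per_trial   = 2 * (40 + 1)    = 82
    #   per_task_collect_steps     = 2 * 80          = 160
    #   env_steps_per_iter         = 5 * 2 * 80      = 800
    #   collect_env_steps_per_iter = 5 * 160         = 800
    #   global_steps_per_iter      = 800 + 800 + 800 = 2400
    #   model_bs_in_trials         = 256 // 82       = 3 trials per model batch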

    ######################################################
    # Print a summary of params
    ######################################################
    MELD_summary_string = f"""\n\n\n
==============================================================
==============================================================
  \n
  MELD algorithm summary:

  * each trial consists of {episodes_per_trial} episodes
  * episode length: {max_episode_len}, trial length: {env_steps_per_trial}
  * {num_train_tasks} train tasks, {num_eval_tasks} eval tasks, hold-out: {eval_on_holdout_tasks}
  * environment: {env_name}
  
  For each of {num_train_tasks} tasks:
    Do {init_collect_trials_per_task} trials of initial collect
  (total {init_collect_env_steps} env steps)
  
  Do {init_model_train_steps} steps of initial model training
    
  For i in range(inf):
    For each of {num_tasks_to_collect_per_iter} randomly selected tasks:
      Do {collect_trials_per_task} trials of collect
    (which is {collect_trials_per_task*env_steps_per_trial} env steps per task)
    (for a total of {num_tasks_to_collect_per_iter*collect_trials_per_task*env_steps_per_trial} env steps in the iteration)
    
    if i % model_train_freq(={model_train_freq}):
      Do {model_train_steps_per_iter} steps of model training
        - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials.
        - pick randomly {model_bs_in_trials} trials, train model on whole trials.
    
    if i % ac_train_freq(={ac_train_freq}):
      Do {ac_train_steps_per_iter} steps of ac training
        - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials.
        - pick randomly {ac_bs_in_steps} transitions, not including between trial transitions, 
          to train ac.
  
  
  * Other important params:
  Evaluate policy every {eval_interval} iters, equivalent to {global_steps_per_iter*eval_interval/1000:.1f}k global steps
  Average evaluation across {num_eval_trials} trials
  Save summary to tensorboard every {summary_freq_in_iter} iters, equivalent to {global_steps_per_iter*summary_freq_in_iter/1000:.1f}k global steps
  Checkpoint:
   - training checkpoint every {train_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*train_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint
   - policy checkpoint every {policy_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*policy_checkpoint_freq_in_iter//1000}k global steps, keep all checkpoints
   - replay buffer checkpoint every {rb_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*rb_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint
    
  \n
=============================================================
=============================================================
  """

    print(MELD_summary_string)
    time.sleep(1)

    ######################################################
    # Seed + name + GPU configs + directories for saving
    ######################################################
    np.random.seed(int(seed))
    experiment_name += "_seed" + str(seed)

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    train_eval_dir = get_train_eval_dir(root_dir, universe, env_name,
                                        experiment_name)
    train_dir = os.path.join(train_eval_dir, 'train')
    eval_dir = os.path.join(train_eval_dir, 'eval')
    eval_dir_2 = os.path.join(train_eval_dir, 'eval2')

    ######################################################
    # Train and Eval Summary Writers
    ######################################################
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_summary_flush_op = eval_summary_writer.flush()

    eval_logger = Logger(eval_dir_2)

    ######################################################
    # Train and Eval metrics
    ######################################################
    eval_buffer_size = num_eval_trials * episodes_per_trial * max_episode_len  # across all eval trials in each evaluation
    eval_metrics = []
    # have metrics for each episode position, to track whether it is learning
    for position in range(episodes_per_trial):
        eval_metrics_pos = [
            py_metrics.AverageReturnMetric(name='c_AverageReturnEval_' +
                                           str(position),
                                           buffer_size=eval_buffer_size),
            py_metrics.AverageEpisodeLengthMetric(
                name='f_AverageEpisodeLengthEval_' + str(position),
                buffer_size=eval_buffer_size),
            custom_metrics.AverageScoreMetric(
                name="d_AverageScoreMetricEval_" + str(position),
                buffer_size=eval_buffer_size),
        ]
        eval_metrics.extend(eval_metrics_pos)

    train_buffer_size = num_train_tasks * episodes_per_trial
    train_metrics = [
        tf_metrics.NumberOfEpisodes(name='NumberOfEpisodes'),
        tf_metrics.EnvironmentSteps(name='EnvironmentSteps'),
        tf_py_metric.TFPyMetric(
            py_metrics.AverageReturnMetric(name="a_AverageReturnTrain",
                                           buffer_size=train_buffer_size)),
        tf_py_metric.TFPyMetric(
            py_metrics.AverageEpisodeLengthMetric(
                name="e_AverageEpisodeLengthTrain",
                buffer_size=train_buffer_size)),
        tf_py_metric.TFPyMetric(
            custom_metrics.AverageScoreMetric(name="b_AverageScoreTrain",
                                              buffer_size=train_buffer_size)),
    ]

    # will be used to record number of model grad steps + ac grad steps + env steps
    global_step = tf.compat.v1.train.get_or_create_global_step()

    log_cond = get_log_condition_tensor(
        global_step, init_collect_trials_per_task, env_steps_per_trial,
        num_train_tasks, init_model_train_steps, collect_trials_per_task,
        num_tasks_to_collect_per_iter, model_train_steps_per_iter,
        ac_train_steps_per_iter, summary_freq_in_iter, eval_interval)

    with tf.compat.v2.summary.record_if(log_cond):

        ######################################################
        # Create env
        ######################################################
        py_env, eval_py_env, train_tasks, eval_tasks = load_environments(
            universe,
            action_mode,
            env_name=env_name,
            observations_whitelist=['state', 'pixels', "env_info"],
            action_repeat=action_repeat,
            num_train_tasks=num_train_tasks,
            num_eval_tasks=num_eval_tasks,
            eval_on_holdout_tasks=eval_on_holdout_tasks,
            return_multiple_tasks=True,
        )
        override_reward_func = None
        if load_offline_data:
            py_env.set_task_dict(train_tasks)
            override_reward_func = py_env.override_reward_func

        tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True)

        # Get data specs from env
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        original_control_timestep = get_control_timestep(eval_py_env)

        # fps
        control_timestep = original_control_timestep * float(action_repeat)
        render_fps = int(np.round(1.0 / original_control_timestep))

        ######################################################
        # Latent variable model
        ######################################################
        if verbose:
            print("-- start constructing model networks --")

        model_net = ModelDistributionNetwork(
            double_camera=double_camera,
            observation_spec=observation_spec,
            num_repeat_when_concatenate=num_repeat_when_concatenate,
            task_reward_dim=task_reward_dim,
            episodes_per_trial=episodes_per_trial,
            max_episode_len=max_episode_len
        )  # rest of arguments provided via gin

        if verbose:
            print("-- finish constructing AC networks --")

        ######################################################
        # Compressor Network for Actor/Critic
        # The model's compressor is also used by the AC
        # compressor function: images --> features
        ######################################################

        compressor_net = model_net.compressor

        ######################################################
        # Specs for Actor and Critic
        ######################################################
        if actor_input == 'state':
            actor_state_size = observation_spec['state'].shape[0]
        elif actor_input == 'latentSample':
            actor_state_size = model_net.state_size
        elif actor_input == "latentDistribution":
            actor_state_size = 2 * model_net.state_size  # mean and (diagonal) variance of gaussian, of two latents
        else:
            raise NotImplementedError
        actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ),
                                                  dtype=tf.float32)

        if critic_input == 'state':
            critic_state_size = observation_spec['state'].shape[0]
        elif critic_input == 'latentSample':
            critic_state_size = model_net.state_size
        elif critic_input == "latentDistribution":
            critic_state_size = 2 * model_net.state_size  # mean and (diagonal) variance of gaussian, of two latents
        else:
            raise NotImplementedError
        critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ),
                                                   dtype=tf.float32)

        ######################################################
        # Actor and Critic Networks
        ######################################################
        if verbose:
            print("-- start constructing Actor and Critic networks --")

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            actor_input_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
        )

        critic_net = critic_network.CriticNetwork(
            (critic_input_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        if verbose:
            print("-- finish constructing AC networks --")
            print("-- start constructing agent --")

        ######################################################
        # Create the agent
        ######################################################

        which_posterior_overwrite = None
        which_reward_overwrite = None

        meld_agent = MeldAgent(
            # specs
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            # step counter
            train_step_counter=global_step,  # will count number of model training steps
            # networks
            actor_network=actor_net,
            critic_network=critic_net,
            model_network=model_net,
            compressor_network=compressor_net,
            # optimizers
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            model_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=model_learning_rate),
            # target update
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            # inputs
            critic_input=critic_input,
            actor_input=actor_input,
            # bs stuff
            model_batch_size=model_bs_in_steps,
            ac_batch_size=ac_bs_in_steps,
            # other
            num_tasks_per_train=num_tasks_per_train,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            control_timestep=control_timestep,
            num_images_per_summary=num_images_per_summary,
            task_reward_dim=task_reward_dim,
            episodes_per_trial=episodes_per_trial,
            # offline data
            override_reward_func=override_reward_func,
            offline_ratio=offline_ratio,
        )

        if verbose:
            print("-- finish constructing agent --")

        ######################################################
        # Replay buffers + observers to add data to them
        ######################################################
        replay_buffers = []
        replay_observers = []
        for _ in range(num_train_tasks):
            replay_buffer_episodic = episodic_replay_buffer.EpisodicReplayBuffer(
                # spec of each point stored in here (i.e. Trajectory)
                meld_agent.collect_policy.trajectory_spec,
                capacity=replay_buffer_capacity,
                # in as_dataset, if num_steps is None, this means return full episodes
                completed_only=True,
                # device='GPU:0', # gpu not supported for some reason
                # first step of seq we add should be is_first
                begin_episode_fn=lambda traj: traj.is_first()[0],
                # last step of seq we add should be is_last
                end_episode_fn=lambda traj: traj.is_last()[0],
                # `as_dataset` drops the final batch if it does not contain
                # exactly `sample_batch_size` items
                dataset_drop_remainder=True,
            )
            replay_buffer = StatefulEpisodicReplayBuffer(
                replay_buffer_episodic)  # adding num_episodes here is bad
            replay_buffers.append(replay_buffer)
            replay_observers.append([replay_buffer.add_sequence])

        if load_offline_data:
            # for each task, has a separate replay buffer for relabeled data
            replay_buffers_withRelabel = []
            replay_observers_withRelabel = []
            for _ in range(num_train_tasks):
                replay_buffer_episodic_withRelabel = episodic_replay_buffer.EpisodicReplayBuffer(
                    # spec of each point stored in here (i.e. Trajectory)
                    meld_agent.collect_policy.trajectory_spec,
                    capacity=replay_buffer_capacity,
                    # in as_dataset, if num_steps is None, this means return full episodes
                    completed_only=True,
                    # device='GPU:0', # gpu not supported for some reason
                    # first step of seq we add should be is_first
                    begin_episode_fn=lambda traj: traj.is_first()[0],
                    # last step of seq we add should be is_last
                    end_episode_fn=lambda traj: traj.is_last()[0],
                    # `as_dataset` drops the final batch if it does not contain
                    # exactly `sample_batch_size` items
                    dataset_drop_remainder=True,
                )
                replay_buffer_withRelabel = StatefulEpisodicReplayBuffer(
                    replay_buffer_episodic_withRelabel
                )  # adding num_episodes here is bad
                replay_buffers_withRelabel.append(replay_buffer_withRelabel)
                replay_observers_withRelabel.append(
                    [replay_buffer_withRelabel.add_sequence])

        if verbose:
            print("-- finish constructing replay buffers --")
            print("-- start constructing policies and collect ops --")

        ######################################################
        # Policies
        #####################################################

        # init collect policy (random)
        init_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        # eval
        eval_py_policy = py_tf_policy.PyTFPolicy(meld_agent.policy)

        ################################################################################
        # Collect ops : use policies to get data + have the observer put data into corresponding RB
        ################################################################################

        #init collection (with random policy)
        init_collect_ops = []
        for task_idx in range(num_train_tasks):
            # put init data into the rb + track with the train metric
            observers = replay_observers[task_idx] + train_metrics

            # initial collect op
            init_collect_op = DynamicTrialDriver(
                tf_env,
                init_collect_policy,
                num_trials_to_collect=init_collect_trials_per_task,
                observers=observers,
                # policy state is not reset within these episodes
                episodes_per_trial=episodes_per_trial,
                max_episode_len=max_episode_len,
            ).run()  # collect one trial
            init_collect_ops.append(init_collect_op)

        # data collection for training (with collect policy)
        collect_ops = []
        for task_idx in range(num_train_tasks):
            collect_op = DynamicTrialDriver(
                tf_env,
                meld_agent.collect_policy,
                num_trials_to_collect=collect_trials_per_task,
                # put data into this task's RB + track with the train metrics
                observers=replay_observers[task_idx] + train_metrics,
                # policy state is not reset within these episodes
                episodes_per_trial=episodes_per_trial,
                max_episode_len=max_episode_len,
            ).run()  # collect one trial
            collect_ops.append(collect_op)

        if verbose:
            print("-- finish constructing policies and collect ops --")
            print("-- start constructing replay buffer->training pipeline --")

        ######################################################
        # replay buffer --> dataset --> iterate to get trajecs for training
        ######################################################

        # get some data from all task replay buffers (even though won't actually train on all of them)
        dataset_iterators = []
        all_tasks_trajectories_fromdense = []
        for task_idx in range(num_train_tasks):
            dataset = replay_buffers[task_idx].as_dataset(
                sample_batch_size=sample_episodes_per_task,  # number of episodes to sample
                # +1 to include the last state: a trajectory with n transitions has n+1 states
                num_steps=max_episode_len + 1,
            ).prefetch(3)
            # iterator to go through the data
            dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
                dataset)
            dataset_iterators.append(dataset_iterator)
            # get sample_episodes_per_task sequences, each of length num_steps
            trajectories_task_i, _ = dataset_iterator.get_next()
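            # trajectories_task_i is a batched Trajectory whose fields have outer dims
            # [sample_episodes_per_task, max_episode_len + 1]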
            all_tasks_trajectories_fromdense.append(trajectories_task_i)

        if load_offline_data:
            # have separate dataset for relabel data
            dataset_iterators_withRelabel = []
            all_tasks_trajectories_fromdense_withRelabel = []
            for task_idx in range(num_train_tasks):
                dataset = replay_buffers_withRelabel[task_idx].as_dataset(
                    sample_batch_size=sample_episodes_per_task,  # number of episodes to sample
                    # +1 to include the last state: a trajectory with n transitions has n+1 states
                    num_steps=offline_episode_len + 1,
                ).prefetch(3)
                # iterator to go through the data
                dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
                    dataset)
                dataset_iterators_withRelabel.append(dataset_iterator)
                # get sample_episodes_per_task sequences, each of length num_steps
                trajectories_task_i, _ = dataset_iterator.get_next()
                all_tasks_trajectories_fromdense_withRelabel.append(
                    trajectories_task_i)

        if verbose:
            print("-- finish constructing replay buffer->training pipeline --")
            print("-- start constructing model and AC training ops --")

        ######################################
        # Decoding latent samples into rewards
        ######################################

        latent_samples_1_ph = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=(None, None, meld_agent._model_network.latent1_size))
        latent_samples_2_ph = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=(None, None, meld_agent._model_network.latent2_size))
        decode_rews_op = meld_agent._model_network.decode_latents_into_reward(
            latent_samples_1_ph, latent_samples_2_ph)
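        # decode_rews_op and the two placeholders are only used at eval time (passed to
        # the eval helpers below) to turn sampled latents into predicted rewards for logging.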

        ######################################
        # Model/Actor/Critic train + summary ops
        ######################################

        # train AC on data from replay buffer
        if load_offline_data:
            ac_train_op = meld_agent.train_ac_meld(
                all_tasks_trajectories_fromdense,
                all_tasks_trajectories_fromdense_withRelabel)
        else:
            ac_train_op = meld_agent.train_ac_meld(
                all_tasks_trajectories_fromdense)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))
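        # Keep track of the metric summary ops so that the remaining v2 summary ops
        # (collected below) can be grouped with the model training call instead.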

        if verbose:
            print("-- finish constructing AC training ops --")

        ############################
        # Model train + summary ops
        ############################

        # train model on data from replay buffer
        if load_offline_data:
            model_train_op, check_step_types = meld_agent.train_model_meld(
                all_tasks_trajectories_fromdense,
                all_tasks_trajectories_fromdense_withRelabel)
        else:
            model_train_op, check_step_types = meld_agent.train_model_meld(
                all_tasks_trajectories_fromdense)

        model_summary_ops, model_summary_ops_2 = [], []
        for summary_op in tf.compat.v1.summary.all_v2_summary_ops():
            if summary_op not in summary_ops:
                model_summary_ops.append(summary_op)

        if verbose:
            print("-- finish constructing model training ops --")
            print("-- start constructing checkpointers --")

        ########################
        # Eval metrics
        ########################

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=train_metrics[:2])

        ########################
        # Create savers
        ########################
        train_config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                                       summarize_config=False)
        eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir,
                                                      summarize_config=False)

        ########################
        # Create checkpointers
        ########################

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=meld_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=1)
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=meld_agent.policy,
            global_step=global_step,
            max_to_keep=99999999999
        )  # keep many policy checkpoints, in case of future eval
        rb_checkpointers = []
        for buffer_idx in range(len(replay_buffers)):
            rb_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'replay_buffers/',
                                      "task" + str(buffer_idx)),
                max_to_keep=1,
                replay_buffer=replay_buffers[buffer_idx])
            rb_checkpointers.append(rb_checkpointer)

        if load_offline_data:  # used for LOADING data, not for checkpointing; no new data goes into these buffers
            rb_checkpointers_withRelabel = []
            for buffer_idx in range(len(replay_buffers_withRelabel)):
                ckpt_dir = os.path.join(offline_data_dir,
                                        "task" + str(buffer_idx))
                rb_checkpointer = common.Checkpointer(
                    ckpt_dir=ckpt_dir,
                    max_to_keep=99999999999,
                    replay_buffer=replay_buffers_withRelabel[buffer_idx])
                rb_checkpointers_withRelabel.append(rb_checkpointer)
            # Notice: these replay buffers need to follow the same sequence of tasks as the current one

        if verbose:
            print("-- finish constructing checkpointers --")
            print("-- start main training loop --")

        with tf.compat.v1.Session() as sess:

            ########################
            # Initialize
            ########################

            if eval_only:
                sess.run(eval_summary_writer.init())
                load_eval_log(
                    train_eval_dir=train_eval_dir,
                    meld_agent=meld_agent,
                    global_step=global_step,
                    sess=sess,
                    eval_metrics=eval_metrics,
                    eval_py_env=eval_py_env,
                    eval_py_policy=eval_py_policy,
                    num_eval_trials=num_eval_trials,
                    max_episode_len=max_episode_len,
                    episodes_per_trial=episodes_per_trial,
                    log_image_strips=log_image_strips,
                    num_trials_to_render=num_trials_to_render,
                    train_tasks=train_tasks,  # in case we want to eval on a train task
                    eval_tasks=eval_tasks,
                    model_net=model_net,
                    render_fps=render_fps,
                    decode_rews_op=decode_rews_op,
                    latent_samples_1_ph=latent_samples_1_ph,
                    latent_samples_2_ph=latent_samples_2_ph,
                )
                return

            # Initialize checkpointing
            train_checkpointer.initialize_or_restore(sess)
            for rb_checkpointer in rb_checkpointers:
                rb_checkpointer.initialize_or_restore(sess)

            if load_offline_data:
                for rb_checkpointer in rb_checkpointers_withRelabel:
                    rb_checkpointer.initialize_or_restore(sess)

            # Initialize dataset iterators
            for dataset_iterator in dataset_iterators:
                sess.run(dataset_iterator.initializer)

            if load_offline_data:
                for dataset_iterator in dataset_iterators_withRelabel:
                    sess.run(dataset_iterator.initializer)

            # Initialize variables
            common.initialize_uninitialized_variables(sess)

            # Initialize summary writers
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Initialize savers
            train_config_saver.after_create_session(sess)
            eval_config_saver.after_create_session(sess)
            # Get value of step counter
            global_step_val = sess.run(global_step)

            if verbose:
                print("====== finished initialization ======")

            ################################################################
            # If this is the start of a new experiment (global step 0), rather than a
            # continuation of a loaded one: eval the random policy + do initial data collection
            ################################################################
            fresh_start = (global_step_val == 0)

            if fresh_start:

                ########################
                # Evaluate initial policy
                ########################

                if eval_interval:
                    logging.info(
                        '\n\nDoing evaluation of initial policy on %d trials with randomly sampled tasks',
                        num_eval_trials)
                    perform_eval_and_summaries_meld(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_eval_trials,
                        max_episode_len,
                        episodes_per_trial,
                        log_image_strips=log_image_strips,
                        num_trials_to_render=num_eval_tasks,
                        eval_tasks=eval_tasks,
                        latent1_size=model_net.latent1_size,
                        latent2_size=model_net.latent2_size,
                        logger=eval_logger,
                        global_step_val=global_step_val,
                        render_fps=render_fps,
                        decode_rews_op=decode_rews_op,
                        latent_samples_1_ph=latent_samples_1_ph,
                        latent_samples_2_ph=latent_samples_2_ph,
                        log_image_observations=log_image_observations,
                    )
                    sess.run(eval_summary_flush_op)
                    logging.info(
                        'Done with evaluation of initial (random) policy.\n\n')

                ########################
                # Initial data collection
                ########################

                logging.info(
                    '\n\nGlobal step %d: Beginning init collect op with random policy. Collecting %dx {%d, %d} trials for each task',
                    global_step_val, init_collect_trials_per_task,
                    max_episode_len, episodes_per_trial)

                init_increment_global_step_op = global_step.assign_add(
                    env_steps_per_trial * init_collect_trials_per_task)
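                # The global step counts environment steps here: each task's init collection
                # advances it by env_steps_per_trial * init_collect_trials_per_task.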

                for task_idx in range(num_train_tasks):
                    logging.info('on task %d / %d', task_idx + 1,
                                 num_train_tasks)
                    py_env.set_task_for_env(train_tasks[task_idx])
                    sess.run([
                        init_collect_ops[task_idx],
                        init_increment_global_step_op
                    ])  # global step is incremented at task granularity

                # save every task's replay buffer after the initial collection
                for rb_checkpointer in rb_checkpointers:
                    rb_checkpointer.save(global_step=global_step_val)
                logging.info('Finished init collect.\n\n')

            else:
                logging.info(
                    '\n\nGlobal step %d from loaded experiment: Skipping init collect op.\n\n',
                    global_step_val)

            #########################
            # Create calls
            #########################

            # [1] calls for running the policies to collect training data
            collect_calls = []
            increment_global_step_op = global_step.assign_add(
                env_steps_per_trial * collect_trials_per_task)
            for task_idx in range(num_train_tasks):
                collect_calls.append(
                    sess.make_callable(
                        [collect_ops[task_idx], increment_global_step_op]))

            # [2] call for doing a training step (A + C)
            ac_train_step_call = sess.make_callable([ac_train_op, summary_ops])

            # [3] call for doing a training step (model)
            model_train_step_call = sess.make_callable(
                [model_train_op, check_step_types, model_summary_ops])

            # [4] call for evaluating what global_step number we're on
            global_step_call = sess.make_callable(global_step)

            # reset keeping track of steps/time
            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            with train_summary_writer.as_default(), \
                 tf.compat.v2.summary.record_if(True):
                steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_second_ph,
                    step=global_step)

            #################################
            # init model training
            #################################
            if fresh_start:
                logging.info(
                    '\n\nPerforming %d steps of init model training, each step on %d random tasks',
                    init_model_train_steps, num_tasks_per_train)
                for i in range(init_model_train_steps):

                    temp_start = time.time()
                    if i % 100 == 0:
                        print(".... init model training ", i, "/",
                              init_model_train_steps)

                    # single init model training step
                    total_loss_value_model, check_step_types, _ = model_train_step_call()

                    if PRINT_TIMING:
                        print("single model train step: ",
                              time.time() - temp_start)

            if verbose:
                print("\n\n\n-- start training loop --\n")

            #################################
            # Training Loop
            #################################
            start_time = time.time()
            for iteration in range(num_iterations):

                if iteration > 0:
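                    # `g` is assumed to be the default tf.Graph captured earlier (not shown
                    # in this snippet); finalizing it prevents adding ops inside the loop.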
                    g.finalize()

                # print("\n\n\niter", iteration, sess.run(curr_iter))
                print("global step", global_step_call())

                logging.info("Iteration: %d, Global step: %d\n", iteration,
                             global_step_val)

                ####################
                # collect data
                ####################
                logging.info(
                    '\nStarting batch data collection. Collecting %d {%d, %d} trials for each of %d tasks',
                    collect_trials_per_task, max_episode_len,
                    episodes_per_trial, num_tasks_to_collect_per_iter)

                # randomly select tasks to collect this iteration
                list_of_collect_task_idxs = np.random.choice(
                    len(train_tasks),
                    num_tasks_to_collect_per_iter,
                    replace=False)
                for count, task_idx in enumerate(list_of_collect_task_idxs):
                    logging.info('on randomly selected task %d / %d',
                                 count + 1, num_tasks_to_collect_per_iter)

                    # set task for the env
                    py_env.set_task_for_env(train_tasks[task_idx])

                    # collect data with the collect policy (the second returned value
                    # is the incremented global step, not a policy state)
                    _, _ = collect_calls[task_idx]()

                logging.info('Finish data collection. Global step: %d\n',
                             global_step_call())

                ####################
                # train model
                ####################
                if (iteration == 0) or (
                        iteration % model_train_freq == 0
                        and global_step_val < stop_model_training):
                    logging.info(
                        '\n\nPerforming %d steps of model training, each on %d random tasks',
                        model_train_steps_per_iter, num_tasks_per_train)
                    for model_iter in range(model_train_steps_per_iter):
                        temp_start_2 = time.time()

                        # train model
                        total_loss_value_model, _, _ = model_train_step_call()

                        # print("is logging step", model_iter, sess.run(is_logging_step))
                        if PRINT_TIMING:
                            print("2: single model train step: ",
                                  time.time() - temp_start_2)
                    logging.info('Finish model training. Global step: %d\n',
                                 global_step_call())
                else:
                    print("SKIPPING MODEL TRAINING")

                ####################
                # train actor critic
                ####################
                if iteration % ac_train_freq == 0:
                    logging.info(
                        '\n\nPerforming %d steps of AC training, each on %d random tasks \n\n',
                        ac_train_steps_per_iter, num_tasks_per_train)
                    for ac_iter in range(ac_train_steps_per_iter):
                        temp_start_2_ac = time.time()

                        # train ac
                        total_loss_value_ac, _ = ac_train_step_call()
                        if PRINT_TIMING:
                            print("2: single AC train step: ",
                                  time.time() - temp_start_2_ac)
                logging.info('Finish AC training. Global step: %d\n',
                             global_step_call())

                # add up time
                time_acc += time.time() - start_time

                ####################
                # logging/summaries
                ####################

                ### Eval
                if eval_interval and (iteration % eval_interval == 0):
                    logging.info(
                        '\n\nDoing evaluation of trained policy on %d trials with randomly sampled tasks',
                        num_eval_trials)

                    perform_eval_and_summaries_meld(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_eval_trials,
                        max_episode_len,
                        episodes_per_trial,
                        log_image_strips=log_image_strips,
                        # keep this small, or the rendered gif gets too long
                        num_trials_to_render=num_trials_to_render,
                        eval_tasks=eval_tasks,
                        latent1_size=model_net.latent1_size,
                        latent2_size=model_net.latent2_size,
                        logger=eval_logger,
                        global_step_val=global_step_call(),
                        render_fps=render_fps,
                        decode_rews_op=decode_rews_op,
                        latent_samples_1_ph=latent_samples_1_ph,
                        latent_samples_2_ph=latent_samples_2_ph,
                        log_image_observations=log_image_observations,
                    )

                ### steps_per_second_summary
                global_step_val = global_step_call()
                if logging_freq_in_iter and (
                        iteration % logging_freq_in_iter == 0):
                    # log step number + speed (steps/sec)
                    logging.info(
                        'step = %d, loss = %f', global_step_val,
                        total_loss_value_ac.loss + total_loss_value_model.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f env_steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})

                    # reset keeping track of steps/time
                    timed_at_step = global_step_val
                    time_acc = 0

                ### train_checkpoint
                if train_checkpoint_freq_in_iter and (
                        iteration % train_checkpoint_freq_in_iter == 0):
                    train_checkpointer.save(global_step=global_step_val)

                ### policy_checkpointer
                if policy_checkpoint_freq_in_iter and (
                        iteration % policy_checkpoint_freq_in_iter == 0):
                    policy_checkpointer.save(global_step=global_step_val)

                ### rb_checkpointer
                if rb_checkpoint_freq_in_iter and (
                        iteration % rb_checkpoint_freq_in_iter == 0):
                    for rb_checkpointer in rb_checkpointers:
                        rb_checkpointer.save(global_step=global_step_val)
Ejemplo n.º 9
0
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=[],
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=[],
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):

    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception:
                # some environments do not support seeding; ignore if so
                pass

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
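    # Reset the step counter so the offline safety-critic training below starts from 0,
    # independent of how long the loaded agent was trained.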
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed:
        tf.compat.v1.set_random_seed(seed)

    summaries_flush_secs = 10
    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)
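    # This evaluation policy filters the actor's actions through the offline-trained
    # safety critic; resample_counter presumably tracks how often actions are resampled.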

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        num_steps=2,
        sample_batch_size=batch_size // 2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
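    # Rolling the masks by one step along the time axis selects the transitions just
    # before a failure and just after an episode reset, respectively.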
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
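    # The mask is padded up to the buffer's max_length so it can act as a per-slot
    # selection mask when copying failure-adjacent transitions into failure_buffer.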
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    sc_dataset_neg = failure_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size // 2,
        num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)
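    # Each training batch below mixes batch_size // 2 transitions sampled from the full
    # buffer with batch_size // 2 transitions drawn from the failure buffer.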

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    losses = []
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5, (1 - fail_weight) / 0.5)
            else:
                weights = None
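            # When set, fail_weight assigns weight fail_weight / 0.5 to failure transitions
            # and (1 - fail_weight) / 0.5 to safe ones; dividing by 0.5 presumably
            # rebalances the 50/50 pos/neg batch mixture.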
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            losses.append(
                (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=metric.result(),
                                                        step=global_step_val)
                        else:
                            fmt_str = '_{}'.format(thresholds[0])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[0],
                                step=global_step_val)
                            fmt_str = '_{}'.format(thresholds[1])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[1],
                                step=global_step_val)
                        metric.reset_states()
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)
Ejemplo n.º 10
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
Ejemplo n.º 11
0
def train_eval(
        root_dir,
        experiment_name,
        train_eval_dir=None,
        universe='gym',
        env_name='HalfCheetah-v2',
        domain_name='cheetah',
        task_name='run',
        action_repeat=1,
        num_iterations=int(1e7),
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        model_network_ctor=model_distribution_network.ModelDistributionNetwork,
        critic_input='state',
        actor_input='state',
        compressor_descriptor='preprocessor_32_3',
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=int(1e5),
        # increase if necessary since buffers with images are huge
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        model_train_steps_per_iteration=1,
        initial_model_train_steps=100000,
        batch_size=256,
        model_batch_size=32,
        sequence_length=4,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        model_learning_rate=1e-4,
        td_errors_loss_fn=functools.partial(
            tf.compat.v1.losses.mean_squared_error, weights=0.5),
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for summaries and logging
        num_images_per_summary=1,
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=0,  # enable if necessary since buffers with images are huge
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        gpu_allow_growth=False,
        gpu_memory_limit=None):
    """A simple train and eval for SLAC."""
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    if train_eval_dir is None:
        train_eval_dir = get_train_eval_dir(root_dir, universe, env_name,
                                            domain_name, task_name,
                                            experiment_name)
    train_dir = os.path.join(train_eval_dir, 'train')
    eval_dir = os.path.join(train_eval_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]
    eval_greedy_metrics = [
        py_metrics.AverageReturnMetric(name='AverageReturnEvalGreedyPolicy',
                                       buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalGreedyPolicy',
            buffer_size=num_eval_episodes),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        trainable_model = model_train_steps_per_iteration != 0
        state_only = (actor_input == 'state' and critic_input == 'state'
                      and not trainable_model
                      and initial_model_train_steps == 0)
        # Save time from unnecessarily rendering observations.
        observations_whitelist = ['state'] if state_only else None
        py_env, eval_py_env = load_environments(
            universe,
            env_name=env_name,
            domain_name=domain_name,
            task_name=task_name,
            observations_whitelist=observations_whitelist,
            action_repeat=action_repeat)
        tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True)
        original_control_timestep = get_control_timestep(eval_py_env)
        control_timestep = original_control_timestep * float(action_repeat)
        fps = int(np.round(1.0 / control_timestep))
        render_fps = int(np.round(1.0 / original_control_timestep))

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        if model_train_steps_per_iteration not in (0,
                                                   train_steps_per_iteration):
            raise NotImplementedError
        model_net = model_network_ctor(observation_spec, action_spec)
        if compressor_descriptor == 'model':
            compressor_net = model_net.compressor
        elif re.match(r'preprocessor_(\d+)_(\d+)', compressor_descriptor):
            m = re.match(r'preprocessor_(\d+)_(\d+)', compressor_descriptor)
            filters, n_layers = m.groups()
            filters = int(filters)
            n_layers = int(n_layers)
            compressor_net = compressor_network.Preprocessor(filters,
                                                             n_layers=n_layers)
        elif re.match(r'compressor_(\d+)', compressor_descriptor):
            m = re.match(r'compressor_(\d+)', compressor_descriptor)
            filters, = m.groups()
            filters = int(filters)
            compressor_net = compressor_network.Compressor(filters)
        elif re.match(r'softlearning_(\d+)_(\d+)', compressor_descriptor):
            m = re.match(r'softlearning_(\d+)_(\d+)', compressor_descriptor)
            filters, n_layers = m.groups()
            filters = int(filters)
            n_layers = int(n_layers)
            compressor_net = compressor_network.SoftlearningPreprocessor(
                filters, n_layers=n_layers)
        elif compressor_descriptor == 'd4pg':
            compressor_net = compressor_network.D4pgPreprocessor()
        else:
            raise NotImplementedError(compressor_descriptor)
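        # Illustrative descriptor strings (assumed, derived from the parsing
        # above) and the compressor each would select:
        #   'preprocessor_32_2' -> compressor_network.Preprocessor(32, n_layers=2)
        #   'compressor_64'     -> compressor_network.Compressor(64)
        #   'softlearning_32_4' -> compressor_network.SoftlearningPreprocessor(
        #                              32, n_layers=4)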

        actor_state_size = 0
        for _actor_input in actor_input.split('__'):
            if _actor_input == 'state':
                state_size, = observation_spec['state'].shape
                actor_state_size += state_size
            elif _actor_input == 'latent':
                actor_state_size += model_net.state_size
            elif _actor_input == 'feature':
                actor_state_size += compressor_net.feature_size
            elif _actor_input in ('sequence_feature',
                                  'sequence_action_feature'):
                actor_state_size += compressor_net.feature_size * sequence_length
                if _actor_input == 'sequence_action_feature':
                    actor_state_size += tf.compat.dimension_value(
                        action_spec.shape[0]) * (sequence_length - 1)
            else:
                raise NotImplementedError(_actor_input)
        actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ),
                                                  dtype=tf.float32)
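        # Illustration (assumed composite input): with actor_input set to
        # 'latent__feature', actor_state_size equals
        # model_net.state_size + compressor_net.feature_size, and the actor
        # receives a flat float32 vector of that length.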

        critic_state_size = 0
        for _critic_input in critic_input.split('__'):
            if _critic_input == 'state':
                state_size, = observation_spec['state'].shape
                critic_state_size += state_size
            elif _critic_input == 'latent':
                critic_state_size += model_net.state_size
            elif _critic_input == 'feature':
                critic_state_size += compressor_net.feature_size
            elif _critic_input in ('sequence_feature',
                                   'sequence_action_feature'):
                critic_state_size += compressor_net.feature_size * sequence_length
                if _critic_input == 'sequence_action_feature':
                    critic_state_size += tf.compat.dimension_value(
                        action_spec.shape[0]) * (sequence_length - 1)
            else:
                raise NotImplementedError(_critic_input)
        critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ),
                                                   dtype=tf.float32)

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            actor_input_spec, action_spec, fc_layer_params=actor_fc_layers)
        critic_net = critic_network.CriticNetwork(
            (critic_input_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = slac_agent.SlacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            model_network=model_net,
            compressor_network=compressor_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            model_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=model_learning_rate),
            sequence_length=sequence_length,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            trainable_model=trainable_model,
            critic_input=critic_input,
            actor_input=actor_input,
            model_batch_size=model_batch_size,
            control_timestep=control_timestep,
            num_images_per_summary=num_images_per_summary,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)
        eval_greedy_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(
                py_metrics.AverageReturnMetric(buffer_size=1)),
            tf_py_metric.TFPyMetric(
                py_metrics.AverageEpisodeLengthMetric(buffer_size=1)),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        initial_policy_state = initial_collect_policy.get_initial_state(1)
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run(
                policy_state=initial_policy_state)

        policy_state = collect_policy.get_initial_state(1)
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run(
                policy_state=policy_state)

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
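            # Discards sampled windows of sequence_length + 1 steps whose
            # second-to-last step is an episode boundary, i.e. windows whose
            # final transition would cross an episode reset.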
            return ~trajectories.is_boundary()[-2]

        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=sequence_length +
            1).unbatch().filter(_filter_invalid_transition).batch(
                batch_size, drop_remainder=True).prefetch(3)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()

        train_op = tf_agent.train(trajectories)
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        if initial_model_train_steps:
            with tf.name_scope('initial'):
                model_train_op = tf_agent.train_model(trajectories)
                model_summary_ops = []
                for summary_op in tf.compat.v1.summary.all_v2_summary_ops():
                    if summary_op not in summary_ops:
                        model_summary_ops.append(summary_op)

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics + eval_greedy_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=train_metrics[:2])
            if eval_interval:
                eval_images_ph = tf.compat.v1.placeholder(dtype=tf.uint8,
                                                          shape=[None] * 5)
                eval_images_summary = gif_utils.gif_summary_v2(
                    'ObservationVideoEvalPolicy', eval_images_ph, 1, fps)
                eval_render_images_summary = gif_utils.gif_summary_v2(
                    'VideoEvalPolicy', eval_images_ph, 1, render_fps)
                eval_greedy_images_summary = gif_utils.gif_summary_v2(
                    'ObservationVideoEvalGreedyPolicy', eval_images_ph, 1, fps)
                eval_greedy_render_images_summary = gif_utils.gif_summary_v2(
                    'VideoEvalGreedyPolicy', eval_images_ph, 1, render_fps)

        train_config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                                       summarize_config=False)
        eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir,
                                                      summarize_config=False)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step,
                                                  max_to_keep=2)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            train_config_saver.after_create_session(sess)
            eval_config_saver.after_create_session(sess)

            global_step_val = sess.run(global_step)

            if global_step_val == 0:
                if eval_interval:
                    # Initial eval of randomly initialized policy
                    for _eval_metrics, _eval_py_policy, \
                        _eval_render_images_summary, _eval_images_summary in (
                        (eval_metrics, eval_py_policy,
                         eval_render_images_summary, eval_images_summary),
                        (eval_greedy_metrics, eval_greedy_py_policy,
                         eval_greedy_render_images_summary, eval_greedy_images_summary)):
                        compute_summaries(
                            _eval_metrics,
                            eval_py_env,
                            _eval_py_policy,
                            num_episodes=num_eval_episodes,
                            num_episodes_to_render=num_images_per_summary,
                            images_ph=eval_images_ph,
                            render_images_summary=_eval_render_images_summary,
                            images_summary=_eval_images_summary)
                    sess.run(eval_summary_flush_op)

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            policy_state_val = sess.run(policy_state)
            collect_call = sess.make_callable(collect_op,
                                              feed_list=[policy_state])
            train_step_call = sess.make_callable([train_op, summary_ops])
            if initial_model_train_steps:
                model_train_step_call = sess.make_callable(
                    [model_train_op, model_summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # The steps_per_second summary should always be recorded, since it
            # is only run every log_interval steps.
            with tf.compat.v2.summary.record_if(True):
                steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_second_ph,
                    step=global_step)

            for iteration in range(global_step_val,
                                   initial_model_train_steps + num_iterations):
                start_time = time.time()
                if iteration < initial_model_train_steps:
                    total_loss_val, _ = model_train_step_call()
                else:
                    time_step_val, policy_state_val = collect_call(
                        policy_state_val)
                    for _ in range(train_steps_per_iteration):
                        total_loss_val, _ = train_step_call()

                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if log_interval and global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss_val.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if (train_checkpoint_interval
                        and global_step_val % train_checkpoint_interval == 0):
                    train_checkpointer.save(global_step=global_step_val)

                if iteration < initial_model_train_steps:
                    continue

                if eval_interval and global_step_val % eval_interval == 0:
                    for _eval_metrics, _eval_py_policy, \
                        _eval_render_images_summary, _eval_images_summary in (
                        (eval_metrics, eval_py_policy,
                         eval_render_images_summary, eval_images_summary),
                        (eval_greedy_metrics, eval_greedy_py_policy,
                         eval_greedy_render_images_summary, eval_greedy_images_summary)):
                        compute_summaries(
                            _eval_metrics,
                            eval_py_env,
                            _eval_py_policy,
                            num_episodes=num_eval_episodes,
                            num_episodes_to_render=num_images_per_summary,
                            images_ph=eval_images_ph,
                            render_images_summary=_eval_render_images_summary,
                            images_summary=_eval_images_summary)
                    sess.run(eval_summary_flush_op)

                if (policy_checkpoint_interval
                        and global_step_val % policy_checkpoint_interval == 0):
                    policy_checkpointer.save(global_step=global_step_val)

                if (rb_checkpoint_interval
                        and global_step_val % rb_checkpoint_interval == 0):
                    rb_checkpointer.save(global_step=global_step_val)
Ejemplo n.º 12
def train(
        root_dir,
        env_load_fn=suite_gym.load,
        env_name='CartPole-v0',
        env_name_eval=None,
        num_parallel_environments=1,
        agent_class=None,
        initial_collect_random=True,
        initial_collect_driver_class=None,
        collect_driver_class=None,
        num_global_steps=100000,
        train_steps_per_iteration=1,
        clear_rb_after_train_steps=None,  # Defaults to True for ON_POLICY_AGENTS
        train_metrics=None,
        # Params for eval
        run_eval=False,
        num_eval_episodes=30,
        eval_interval=1000,
        eval_metrics_callback=None,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        keep_rb_checkpoint=False,
        train_sequence_length=1,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        env_metric_factories=None):

    eval_interval_counter = IntervalCounter(eval_interval)
    train_checkpoint_interval_counter = IntervalCounter(
        train_checkpoint_interval)
    policy_checkpoint_interval_counter = IntervalCounter(
        policy_checkpoint_interval)
    rb_checkpoint_interval_counter = IntervalCounter(rb_checkpoint_interval)
    log_interval_counter = IntervalCounter(log_interval)
    summary_interval_counter = IntervalCounterTf(summary_interval)
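    # Sketch only (assumption; IntervalCounter and IntervalCounterTf are
    # defined elsewhere in this module): the loop below relies solely on a
    # should_trigger(step) method that returns True when `step` lands on a
    # multiple of the configured interval, and never triggers when the
    # interval is falsy.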

    if not agent_class:
        raise ValueError(
            'The `agent_class` parameter of trainer.train must be set.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')
    if not tf.io.gfile.exists(saved_model_dir):
        tf.io.gfile.makedirs(saved_model_dir)

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    def make_possibly_parallel_environment(env_name_):
        """Returns a function creating env_name_, possibly a parallel one."""
        if num_parallel_environments == 1:
            return env_load_fn(env_name_)
        else:
            return parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name_)] * num_parallel_environments)

    def make_tf_py_envs(env):
        """Convert env to tf if needed."""
        if isinstance(env, py_environment.PyEnvironment):
            tf_env = tf_py_environment.TFPyEnvironment(env)
            py_env = env
        else:
            tf_env = env
            py_env = None  # Can't generically convert to PyEnvironment.
        return tf_env, py_env
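    # Note (based on the helpers above): ParallelPyEnvironment is itself a
    # PyEnvironment, so parallel environments also take the TFPyEnvironment
    # branch; the else branch only applies when env_load_fn already returns a
    # TF environment, in which case py_env stays None.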

    eval_py_env = None
    if run_eval:
        if env_name_eval is None:
            env_name_eval = env_name
        eval_env = make_possibly_parallel_environment(env_name_eval)
        eval_tf_env, eval_py_env = make_tf_py_envs(eval_env)

        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(batch_size=eval_tf_env.batch_size,
                                           buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=eval_tf_env.batch_size,
                buffer_size=num_eval_episodes),
        ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: summary_interval_counter.should_trigger(global_step)):
        env = make_possibly_parallel_environment(env_name)
        tf_env, py_env = make_tf_py_envs(env)

        environment_specs.set_observation_spec(tf_env.observation_spec())
        environment_specs.set_action_spec(tf_env.action_spec())

        # Agent params configured with gin.
        agent = agent_class(tf_env.time_step_spec(),
                            tf_env.action_spec(),
                            train_step_counter=global_step)
        agent.initialize()

        if clear_rb_after_train_steps is None:
            # Default is to clear the replay buffer only for ON_POLICY_AGENTS.
            clear_rb_after_train_steps = isinstance(agent, ON_POLICY_AGENTS)

        if run_eval:
            eval_policy = greedy_policy.GreedyPolicy(agent.policy)

        if not train_metrics:
            train_metrics = [
                tf_metrics.NumberOfEpisodes(),
                tf_metrics.EnvironmentSteps(),
                tf_metrics.AverageReturnMetric(batch_size=tf_env.batch_size,
                                               buffer_size=log_interval *
                                               tf_env.batch_size),
                tf_metrics.AverageEpisodeLengthMetric(
                    batch_size=tf_env.batch_size,
                    buffer_size=log_interval * tf_env.batch_size),
            ]
        else:
            train_metrics = list(train_metrics)

        if env_metric_factories:
            for metric_factory in env_metric_factories:
                py_metric = metric_factory(environment=py_env)
                train_metrics.append(tf_py_metric.TFPyMetric(py_metric))

        logging.info('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            agent.collect_data_spec)
        logging.info('RB capacity: %i', replay_buffer.capacity)
        agent_observers = [replay_buffer.add_batch]
        initial_collect_policy = agent.collect_policy
        if initial_collect_random:
            initial_collect_policy = random_tf_policy.RandomTFPolicy(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                info_spec=agent.collect_policy.info_spec)

        collect_policy = agent.collect_policy

        collect_driver = collect_driver_class(tf_env,
                                              collect_policy,
                                              observers=agent_observers +
                                              train_metrics)

        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            max_to_keep=1,
            agent=agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  max_to_keep=None,
                                                  policy=agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        saved_model = policy_saver.PolicySaver(greedy_policy.GreedyPolicy(
            agent.policy),
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        collect_driver.run = common.function(collect_driver.run)
        agent.train = common.function(agent.train)

        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            common.function(
                initial_collect_driver_class(tf_env,
                                             initial_collect_policy,
                                             observers=agent_observers +
                                             train_metrics).run)()

            if run_eval:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

        # This is only used for PPO agents.
        # The dataset is repeated `train_steps_per_iteration` times, which is
        # the number of epochs we loop through during training.
        def get_data_iter_repeated(replay_buffer):
            dataset = replay_buffer.as_dataset(
                sample_batch_size=num_parallel_environments,
                num_steps=train_sequence_length + 1,
                num_parallel_calls=3,
                single_deterministic_pass=True).repeat(
                    train_steps_per_iteration)
            if not any(True for _ in dataset):
                logging.warning('PPO Agent replay buffer as dataset is empty')
            return iter(dataset)
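        # Illustrative usage (mirrors the training loop below): a PPO-style
        # agent rebuilds this iterator after every collect step and sweeps the
        # freshly collected data train_steps_per_iteration times before the
        # replay buffer is cleared.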

        # For off policy agents, one iterator is created for the entire training
        # process. This is different from PPO agents whose iterators are reset
        # in the training loop.
        if not isinstance(agent, ON_POLICY_AGENTS):
            dataset = replay_buffer.as_dataset(
                num_parallel_calls=3,
                num_steps=train_sequence_length + 1).prefetch(3)
            iterator = iter(dataset)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        timed_at_step = global_step.numpy()
        time_acc = 0

        def save_policy(global_step_value):
            """Saves policy using both checkpoint saver and saved model."""
            policy_checkpointer.save(global_step=global_step_value)
            saved_model_path = os.path.join(
                saved_model_dir,
                'policy_' + ('%d' % global_step_value).zfill(8))
            saved_model.save(saved_model_path)

        if global_step.numpy() == 0:
            # Save an initial checkpoint so the evaluator runs for global_step=0.
            save_policy(global_step.numpy())

        @common.function
        def train_step(data_iterator):
            experience = next(data_iterator)[0]
            return agent.train(experience)

        @common.function
        def train_with_gather_all():
            return agent.train(replay_buffer.gather_all())
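        # gather_all() reads the entire replay buffer in one batch; it is only
        # used for REINFORCE_AGENTS in the loop below, while other agents train
        # from the sampled `iterator`.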

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # Number of consecutive steps for which the loss has diverged.
        loss_divergence_counter = 0

        # Save operative config as late as possible to include used configurables.
        if global_step.numpy() == 0:
            config_filename = os.path.join(
                train_dir,
                'operative_config-{}.gin'.format(global_step.numpy()))
            with tf.io.gfile.GFile(config_filename, 'wb') as f:
                f.write(gin.operative_config_str())

        logging.info('Training ...')
        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step, policy_state=policy_state)
            if isinstance(agent, PPO_AGENTS):
                iterator = get_data_iter_repeated(replay_buffer)
            for _ in range(train_steps_per_iteration):
                if isinstance(agent, REINFORCE_AGENTS):
                    total_loss = train_with_gather_all()
                else:
                    total_loss = train_step(iterator)
            total_loss = total_loss.loss

            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    break
            else:
                loss_divergence_counter = 0

            if clear_rb_after_train_steps:
                replay_buffer.clear()
            time_acc += time.time() - start_time

            should_log = log_interval_counter.should_trigger(
                global_step.numpy())
            if should_log:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

                if should_log:
                    hpt = hypertune.HyperTune()
                    hpt.report_hyperparameter_tuning_metric(
                        hyperparameter_metric_tag=train_metric.name,
                        metric_value=float(train_metric.result()),
                        global_step=int(global_step.numpy()))
                    logging.info('Reported %s at step %d', train_metric.name,
                                 global_step.numpy())

            if train_checkpoint_interval_counter.should_trigger(
                    global_step.numpy()):
                train_checkpointer.save(global_step=global_step.numpy())
                print("train_checkpoint", global_step.numpy())

            if policy_checkpoint_interval_counter.should_trigger(
                    global_step.numpy()):
                save_policy(global_step.numpy())
                print("policy_checkpoint", global_step.numpy())

            if rb_checkpoint_interval_counter.should_trigger(
                    global_step.numpy()):
                rb_checkpointer.save(global_step=global_step.numpy())
                print("rb_checkpoint", global_step.numpy())

            if run_eval and eval_interval_counter.should_trigger(
                    global_step.numpy()):
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

    if not keep_rb_checkpoint:
        cleanup_checkpoints(rb_ckpt_dir)

    if py_env:
        py_env.close()
    if eval_py_env:
        eval_py_env.close()

    # Save final operative config that will also have all configurables used in
    # the training loop for the first time.
    config_filename = os.path.join(train_dir, 'operative_config-final.gin')
    with tf.io.gfile.GFile(config_filename, 'wb') as f:
        f.write(gin.operative_config_str())

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))
Ejemplo n.º 13
def train_eval(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        gym_env_wrappers=(),
        monitor=False,
        env_name=None,
        agent_class=None,
        initial_collect_driver_class=None,
        collect_driver_class=None,
        online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        eval_metrics=None,
        train_metrics_callback=None,
        # Params for SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=1000,
        online_critic=False,
        n_envs=None,
        finetune_sc=False,
        # Ensemble Critic training args
        n_critics=30,
        critic_learning_rate=3e-4,
        # Wcpg Critic args
        critic_preprocessing_layer_size=256,
        actor_preprocessing_layer_size=256,
        # Params for train
        batch_size=256,
        # Params for eval
        run_eval=False,
        num_eval_episodes=1,
        max_episode_len=500,
        eval_interval=10000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        monitor_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        debug_summaries=False,
        seed=None,
        eager_debug=False,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval for SC-SAC."""

    n_envs = n_envs or num_eval_episodes
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    train_metrics = train_metrics or []
    eval_metrics = eval_metrics or []
    sc_metrics = eval_metrics or []

    if online_critic:
        sc_dir = os.path.join(root_dir, 'sc')
        sc_summary_writer = tf.compat.v2.summary.create_file_writer(
            sc_dir, flush_millis=summaries_flush_secs * 1000)
        sc_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=n_envs,
                                           name='SafeAverageReturn'),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=n_envs,
                name='SafeAverageEpisodeLength')
        ] + [tf_py_metric.TFPyMetric(m) for m in sc_metrics]
        sc_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            sc_tf_env.seed([seed + i for i in range(n_envs)])

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=n_envs),
        ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            eval_tf_env.seed([seed + n_envs + i for i in range(n_envs)])

    if monitor:
        vid_path = os.path.join(root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        if seed:
            tf_env.seed([seed + 2 * n_envs + i for i in range(n_envs)])
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        logging.debug('obs spec: %s', observation_spec)
        logging.debug('action spec: %s', action_spec)

        if agent_class:  # i.e. not wcpg_agent.WcpgAgent
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=agents.normal_projection_net)
            critic_net = agents.CriticNetwork(
                (observation_spec, action_spec),
                joint_fc_layer_params=critic_joint_fc_layers)
        else:
            alpha_spec = tensor_spec.BoundedTensorSpec(shape=(),
                                                       dtype=tf.float32,
                                                       minimum=0.,
                                                       maximum=1.,
                                                       name='alpha')
            input_tensor_spec = (observation_spec, action_spec, alpha_spec)
            critic_preprocessing_layers = (
                tf.keras.layers.Dense(critic_preprocessing_layer_size),
                tf.keras.layers.Dense(critic_preprocessing_layer_size),
                tf.keras.layers.Lambda(lambda x: x))
            critic_net = agents.DistributionalCriticNetwork(
                input_tensor_spec,
                joint_fc_layer_params=critic_joint_fc_layers)
            actor_preprocessing_layers = (
                tf.keras.layers.Dense(actor_preprocessing_layer_size),
                tf.keras.layers.Dense(actor_preprocessing_layer_size),
                tf.keras.layers.Lambda(lambda x: x))
            actor_net = agents.WcpgActorNetwork(
                input_tensor_spec,
                preprocessing_layers=actor_preprocessing_layers)

        if agent_class in SAFETY_AGENTS:
            safety_critic_net = agents.CriticNetwork(
                (observation_spec, action_spec),
                joint_fc_layer_params=critic_joint_fc_layers)
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   actor_network=actor_net,
                                   critic_network=critic_net,
                                   safety_critic_network=safety_critic_net,
                                   train_step_counter=global_step,
                                   debug_summaries=debug_summaries)
        elif agent_class is ensemble_sac_agent.EnsembleSacAgent:
            critic_nets, critic_optimizers = [critic_net], [
                tf.keras.optimizers.Adam(critic_learning_rate)
            ]
            for _ in range(n_critics - 1):
                critic_nets.append(
                    agents.CriticNetwork(
                        (observation_spec, action_spec),
                        joint_fc_layer_params=critic_joint_fc_layers))
                critic_optimizers.append(
                    tf.keras.optimizers.Adam(critic_learning_rate))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   actor_network=actor_net,
                                   critic_network=critic_nets,
                                   critic_optimizers=critic_optimizers,
                                   debug_summaries=debug_summaries)
        else:  # assume a SacAgent is being used
            logging.debug(critic_net.input_tensor_spec)
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   actor_network=actor_net,
                                   critic_network=critic_net,
                                   train_step_counter=global_step,
                                   debug_summaries=debug_summaries)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.debug('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, batch_size=1, max_length=1000000)
        logging.debug('RB capacity: %i', replay_buffer.capacity)
        logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec)

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                collect_data_spec,
                batch_size=1,
                max_length=max_episode_len * num_eval_episodes)
            agent_observers.append(online_replay_buffer.add_batch)

            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = online_replay_buffer.clear

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        if not online_critic:
            eval_policy = tf_agent.policy
            collect_policy = tf_agent.collect_policy
        else:
            eval_policy = tf_agent.policy  # pylint: disable=protected-access
            collect_policy = tf_agent.collect_policy  # pylint: disable=protected-access
            online_collect_policy = tf_agent._safe_policy

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        if agent_class in SAFETY_AGENTS:
            safety_critic_checkpointer = common.Checkpointer(
                ckpt_dir=sc_dir,
                safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
                global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        if agent_class in SAFETY_AGENTS:
            safety_critic_checkpointer.initialize_or_restore()

        env_metrics = []
        if env_metric_factories:
            for env_metric in env_metric_factories:
                env_metrics.append(
                    tf_py_metric.TFPyMetric(env_metric([py_env.gym])))
                # TODO: get env factory with parallel py envs
                # if run_eval:
                #   eval_metrics.append(env_metric([env.gym for env in eval_tf_env.pyenv._envs]))
                # if online_critic:
                #   sc_metrics.append(env_metric([env.gym for env in sc_tf_env.pyenv._envs]))

        collect_driver = collect_driver_class(tf_env,
                                              collect_policy,
                                              observers=agent_observers +
                                              train_metrics + env_metrics)
        if online_critic:
            logging.debug('online driver class: %s', online_driver_class)
            if online_driver_class is safe_dynamic_episode_driver.SafeDynamicEpisodeDriver:
                online_temp_buffer = episodic_replay_buffer.EpisodicReplayBuffer(
                    collect_data_spec)
                online_temp_buffer_stateful = episodic_replay_buffer.StatefulEpisodicReplayBuffer(
                    online_temp_buffer, num_episodes=num_eval_episodes)
                online_driver = safe_dynamic_episode_driver.SafeDynamicEpisodeDriver(
                    sc_tf_env,
                    online_collect_policy,
                    online_temp_buffer,
                    online_replay_buffer,
                    observers=[online_temp_buffer_stateful.add_batch] +
                    sc_metrics,
                    num_episodes=num_eval_episodes)
            else:
                online_driver = online_driver_class(
                    sc_tf_env,
                    online_collect_policy,
                    observers=[online_replay_buffer.add_batch] + sc_metrics,
                    num_episodes=num_eval_episodes)
            online_driver.run = common.function(online_driver.run)

        if not eager_debug:
            config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                                     summarize_config=True)
            tf.function(config_saver.after_create_session)()

        if agent_class is sac_agent.SacAgent:
            collect_driver.run = common.function(collect_driver.run)
        if eager_debug:
            tf.config.experimental_run_functions_eagerly(True)

        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            initial_collect_driver_class(tf_env,
                                         initial_collect_policy,
                                         observers=agent_observers +
                                         train_metrics + env_metrics).run()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            if online_critic:
                last_id = online_replay_buffer._get_last_id()  # pylint: disable=protected-access
                logging.debug(
                    'Data saved in online buffer after initial collection: %d steps',
                    last_id)

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='EvalMetrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3,
                sample_batch_size=batch_size,
                num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)
            critic_metrics = [
                tf.keras.metrics.AUC(name='safety_critic_auc'),
                tf.keras.metrics.TruePositives(name='safety_critic_tp'),
                tf.keras.metrics.FalsePositives(name='safety_critic_fp'),
                tf.keras.metrics.TrueNegatives(name='safety_critic_tn'),
                tf.keras.metrics.FalseNegatives(name='safety_critic_fn'),
                tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc')
            ]
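            # These Keras metrics score the safety critic as a binary
            # classifier of unsafe transitions; their results are appended to
            # train_results and reset in the logging section below.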

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                start_time = time.time()
                experience, buf_info = next(online_iterator)
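                # SAFETY_ENVS expose a task-agnostic (safety) reward directly
                # in the observation; for other environments it is
                # reconstructed from the online replay buffer and aligned with
                # the sampled ids.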
                if env_name.split('-')[0] in SAFETY_ENVS:
                    safe_rew = experience.observation['task_agn_rew'][:, 1]
                else:
                    safe_rew = misc.process_replay_buffer(online_replay_buffer,
                                                          as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience,
                                        safe_rew,
                                        metrics=critic_metrics,
                                        weights=None)
                logging.debug('critic train step: {} sec'.format(time.time() -
                                                                 start_time))
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # Number of consecutive steps for which the loss has diverged.
        loss_divergence_counter = 0
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')

        if online_critic:
            logging.debug('starting safety critic pretraining')
            safety_eps = tf_agent._safe_policy._safety_threshold
            tf_agent._safe_policy._safety_threshold = 0.6
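            # The safe policy's threshold is temporarily overridden while the
            # safety critic is pretrained below, then restored to its original
            # value (safety_eps).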
            resample_counter = online_collect_policy._resample_counter
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_freq')
            # Only pretrain the safety critic when starting from scratch; a
            # restored safety critic is not fine-tuned here.
            if (global_step.numpy() == 0 and load_root_dir is None):
                for _ in range(train_sc_steps):
                    sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable
            tf_agent._safe_policy._safety_threshold = safety_eps

        logging.debug('starting policy pretraining')
        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            current_step = global_step.numpy()

            if online_critic:
                mean_resample_ac(resample_counter.result())
                resample_counter.reset()
                if time_step is None or time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()

            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            logging.debug('policy eval: {} sec'.format(time.time() -
                                                       start_time))

            train_time = time.time()
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)
            if current_step == 0:
                logging.debug('train policy: {} sec'.format(time.time() -
                                                            train_time))

            if online_critic and current_step % train_sc_interval == 0:
                batch_time_step = sc_tf_env.reset()
                batch_policy_state = online_collect_policy.get_initial_state(
                    sc_tf_env.batch_size)
                online_driver.run(time_step=batch_time_step,
                                  policy_state=batch_policy_state)
                for _ in range(train_sc_steps):
                    sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

                metric_utils.log_metrics(sc_metrics)
                with sc_summary_writer.as_default():
                    for sc_metric in sc_metrics:
                        sc_metric.tf_summaries(train_step=global_step,
                                               step_metrics=sc_metrics[:2])
                    tf.compat.v2.summary.scalar(name='resample_ac_freq',
                                                data=resample_ac_freq,
                                                step=global_step)

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    logging.debug(
                        'Loss diverged, critic_loss: %s, actor_loss: %s, alpha_loss: %s',
                        train_loss.extra.critic_loss,
                        train_loss.extra.actor_loss,
                        train_loss.extra.alpha_loss)
                    break
            else:
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if current_step % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            train_results = []
            for train_metric in train_metrics:
                if isinstance(train_metric, (metrics.AverageEarlyFailureMetric,
                                             metrics.AverageFallenMetric,
                                             metrics.AverageSuccessMetric)):
                    # Plot failure metrics as a function of return.
                    train_metric.tf_summaries(train_step=global_step,
                                              step_metrics=train_metrics[:3])
                else:
                    train_metric.tf_summaries(train_step=global_step,
                                              step_metrics=train_metrics[:2])
                train_results.append(
                    (train_metric.name, train_metric.result().numpy()))
            if env_metrics:
                for env_metric in env_metrics:
                    env_metric.tf_summaries(train_step=global_step,
                                            step_metrics=train_metrics[:2])
                    train_results.append(
                        (env_metric.name, env_metric.result().numpy()))
            if online_critic:
                for critic_metric in critic_metrics:
                    train_results.append(
                        (critic_metric.name, critic_metric.result().numpy()))
                    critic_metric.reset_states()
            if train_metrics_callback is not None:
                train_metrics_callback(collections.OrderedDict(train_results),
                                       global_step.numpy())

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                if agent_class in SAFETY_AGENTS:
                    safety_critic_checkpointer.save(
                        global_step=global_step_val)

            if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)
            elif online_critic:
                clear_rb()

            if run_eval and global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='EvalMetrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

            if monitor and current_step % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                monitor_py_env.reset()
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len, time.time() - monitor_start)

            logging.debug('iteration time: %4.2f sec', time.time() - start_time)

    if not keep_rb_checkpoint:
        misc.cleanup_checkpoints(rb_ckpt_dir)

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))

    return total_loss
Ejemplo n.º 14
0
def train_eval(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    gym_env_wrappers=(),
    monitor=False,
    env_name=None,
    agent_class=None,
    initial_collect_driver_class=None,
    collect_driver_class=None,
    online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver,
    num_global_steps=1000000,
    rb_size=None,
    train_steps_per_iteration=1,
    train_metrics=None,
    eval_metrics=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    sc_rb_size=None,
    target_safety=None,
    train_sc_steps=10,
    train_sc_interval=1000,
    online_critic=False,
    n_envs=None,
    finetune_sc=False,
    pretraining=True,
    lambda_schedule_nsteps=0,
    lambda_initial=0.,
    lambda_final=1.,
    kstep_fail=0,
    # Ensemble Critic training args
    num_critics=None,
    critic_learning_rate=3e-4,
    # Wcpg Critic args
    critic_preprocessing_layer_size=256,
    # Params for train
    batch_size=256,
    # Params for eval
    run_eval=False,
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    debug_summaries=False,
    seed=None,
    eager_debug=False,
    env_metric_factories=None,
    wandb=False):  # pylint: disable=unused-argument
  """Train and eval script for SQRL."""
  if isinstance(agent_class, str):
    assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(agent_class)
    agent_class = ALGOS.get(agent_class)
  n_envs = n_envs or num_eval_episodes
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  # =====================================================================#
  #  Setup summary metrics, file writers, and create env                 #
  # =====================================================================#
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
    train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  train_metrics = train_metrics or []
  eval_metrics = eval_metrics or []

  updating_sc = online_critic and (not load_root_dir or finetune_sc)
  logging.debug('updating safety critic: %s', updating_sc)

  if seed:
    tf.compat.v1.set_random_seed(seed)

  if agent_class in SAFETY_AGENTS:
    if online_critic:
      sc_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
          [lambda: env_load_fn(env_name)] * n_envs
        ))
      if seed:
        seeds = [seed * n_envs + i for i in range(n_envs)]
        try:
          sc_tf_env.pyenv.seed(seeds)
        except Exception:  # pylint: disable=broad-except
          # Seeding may not be supported by this environment; continue without it.
          pass

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
                     tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=n_envs),
                     tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes, batch_size=n_envs),
                   ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
      parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * n_envs
      ))
    if seed:
      try:
        for i, pyenv in enumerate(eval_tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:  # pylint: disable=broad-except
        pass  # Seeding may not be supported by this environment.
  elif 'Drunk' in env_name:
    # Just visualizes trajectories in drunk spider environment
    eval_tf_env = tf_py_environment.TFPyEnvironment(
      env_load_fn(env_name))
  else:
    eval_tf_env = None

  if monitor:
    vid_path = os.path.join(root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  global_step = tf.compat.v1.train.get_or_create_global_step()

  with tf.summary.record_if(
          lambda: tf.math.equal(global_step % summary_interval, 0)):
    py_env = env_load_fn(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    if seed:
      try:
        for i, pyenv in enumerate(tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:  # pylint: disable=broad-except
        pass  # Seeding may not be supported by this environment.
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    logging.debug('obs spec: %s', observation_spec)
    logging.debug('action spec: %s', action_spec)

    # =====================================================================#
    #  Setup agent class                                                   #
    # =====================================================================#

    if agent_class == wcpg_agent.WcpgAgent:
      alpha_spec = tensor_spec.BoundedTensorSpec(shape=(1,), dtype=tf.float32, minimum=0., maximum=1.,
                                                 name='alpha')
      input_tensor_spec = (observation_spec, action_spec, alpha_spec)
      critic_net = agents.DistributionalCriticNetwork(
        input_tensor_spec, preprocessing_layer_size=critic_preprocessing_layer_size,
        joint_fc_layer_params=critic_joint_fc_layers)
      actor_net = agents.WcpgActorNetwork((observation_spec, alpha_spec), action_spec)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)
      critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
      logging.debug('Making SQRL agent')
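      # Optional schedule for the safety Lagrange multiplier: when enabled,
      # lambda is nudged from lambda_initial toward lambda_final by a periodic
      # assign op; otherwise log_lambda is enabled on the agent (see the
      # constructor call below).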
      if lambda_schedule_nsteps > 0:
        lambda_update_every_nsteps = num_global_steps // lambda_schedule_nsteps
        step_size = (lambda_final - lambda_initial) / lambda_update_every_nsteps
        lambda_scheduler = lambda lam: common.periodically(
          body=lambda: tf.group(lam.assign(lam + step_size)),
          period=lambda_update_every_nsteps)
      else:
        lambda_scheduler = None
      safety_critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)
      ts = target_safety
      thresholds = [ts, 0.5]
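      # Classification metrics for the safety critic, evaluated both at the
      # configured target-safety threshold and at 0.5; per-threshold results
      # are unpacked when the critic_results are logged further below.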
      sc_metrics = [tf.keras.metrics.AUC(name='safety_critic_auc'),
                    tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                                   thresholds=thresholds),
                    tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                                    thresholds=thresholds),
                    tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                                   thresholds=thresholds),
                    tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                                    thresholds=thresholds),
                    tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                                    threshold=0.5)]
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        safety_critic_network=safety_critic_net,
        train_step_counter=global_step,
        debug_summaries=debug_summaries,
        safety_pretraining=pretraining,
        train_critic_online=online_critic,
        initial_log_lambda=lambda_initial,
        log_lambda=(lambda_scheduler is None),
        lambda_scheduler=lambda_scheduler)
    elif agent_class is ensemble_sac_agent.EnsembleSacAgent:
      critic_nets, critic_optimizers = [critic_net], [tf.keras.optimizers.Adam(critic_learning_rate)]
      for _ in range(num_critics - 1):
        critic_nets.append(agents.CriticNetwork((observation_spec, action_spec),
                                                joint_fc_layer_params=critic_joint_fc_layers))
        critic_optimizers.append(tf.keras.optimizers.Adam(critic_learning_rate))
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_networks=critic_nets,
        critic_optimizers=critic_optimizers,
        debug_summaries=debug_summaries
      )
    else:  # agent is either SacAgent or WcpgAgent
      logging.debug('critic input_tensor_spec: %s', critic_net.input_tensor_spec)
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        train_step_counter=global_step,
        debug_summaries=debug_summaries)

    tf_agent.initialize()

    # =====================================================================#
    #  Setup replay buffer                                                 #
    # =====================================================================#
    collect_data_spec = tf_agent.collect_data_spec

    logging.debug('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
    rb_size = rb_size or 1000000
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec,
      batch_size=1,
      max_length=rb_size)

    logging.debug('RB capacity: %i', replay_buffer.capacity)
    logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec)

    if agent_class in SAFETY_AGENTS:
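      # Safety agents keep a second, smaller buffer of (relabeled) transitions
      # for safety-critic training; dataset_window_shift=1 allows overlapping
      # windows when this buffer is read back as a dataset.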
      sc_rb_size = sc_rb_size or num_eval_episodes * 500
      sc_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=sc_rb_size,
        dataset_window_shift=1)

    num_episodes = tf_metrics.NumberOfEpisodes()
    num_env_steps = tf_metrics.EnvironmentSteps()
    return_metric = tf_metrics.AverageReturnMetric(
      buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
    train_metrics = [
                      num_episodes, num_env_steps,
                      return_metric,
                      tf_metrics.AverageEpisodeLengthMetric(
                        buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
                    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if 'Minitaur' in env_name and not pretraining:
      goal_vel = gin.query_parameter("%GOAL_VELOCITY")
      early_termination_fn = train_utils.MinitaurTerminationFn(
        speed_metric=train_metrics[-2], total_falls_metric=train_metrics[-3],
        env_steps_metric=num_env_steps, goal_speed=goal_vel)

    if env_metric_factories:
      for env_metric in env_metric_factories:
        train_metrics.append(tf_py_metric.TFPyMetric(env_metric(tf_env.pyenv.envs)))
        if run_eval:
          eval_metrics.append(
            env_metric(list(eval_tf_env.pyenv._envs)))  # pylint: disable=protected-access

    # =====================================================================#
    #  Setup collect policies                                              #
    # =====================================================================#
    if not online_critic:
      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy
      if not pretraining and agent_class in SAFETY_AGENTS:
        collect_policy = tf_agent.safe_policy
    else:
      eval_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy
      collect_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy
      online_collect_policy = tf_agent.safe_policy  # if pretraining else tf_agent.collect_policy
      if pretraining:
        online_collect_policy._training = False

    if not load_root_dir:
      initial_collect_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec)
    else:
      initial_collect_policy = collect_policy
    if agent_class == wcpg_agent.WcpgAgent:
      initial_collect_policy = agents.WcpgPolicyWrapper(initial_collect_policy)

    # =====================================================================#
    #  Setup Checkpointing                                                 #
    # =====================================================================#
    train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir,
      agent=tf_agent,
      global_step=global_step,
      metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=eval_policy,
      global_step=global_step)

    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(
      ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer)

    if online_critic:
      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
        ckpt_dir=online_rb_ckpt_dir,
        max_to_keep=1,
        replay_buffer=sc_buffer)

    # loads agent, replay buffer, and online sc/buffer if online_critic
    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_agent_ckpt(load_train_dir, tf_agent)
      if len(os.listdir(os.path.join(load_train_dir, 'replay_buffer'))) > 1:
        load_rb_ckpt_dir = os.path.join(load_train_dir, 'replay_buffer')
        misc.load_rb_ckpt(load_rb_ckpt_dir, replay_buffer)
      if online_critic:
        load_online_sc_ckpt_dir = os.path.join(load_root_dir, 'sc')
        load_online_rb_ckpt_dir = os.path.join(load_train_dir,
                                               'online_replay_buffer')
        if osp.exists(load_online_rb_ckpt_dir):
          misc.load_rb_ckpt(load_online_rb_ckpt_dir, sc_buffer)
        if osp.exists(load_online_sc_ckpt_dir):
          misc.load_safety_critic_ckpt(load_online_sc_ckpt_dir,
                                       safety_critic_net)
      elif agent_class in SAFETY_AGENTS:
        offline_run = sorted(os.listdir(os.path.join(load_train_dir, 'offline')))[-1]
        load_sc_ckpt_dir = os.path.join(load_train_dir, 'offline',
                                        offline_run, 'safety_critic')
        if osp.exists(load_sc_ckpt_dir):
          sc_net_off = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=(512, 512),
            name='SafetyCriticOffline')
          sc_net_off.create_variables()
          target_sc_net_off = common.maybe_copy_target_network_with_checks(
            sc_net_off, None, 'TargetSafetyCriticNetwork')
          sc_optimizer = tf.keras.optimizers.Adam(critic_learning_rate)
          _ = misc.load_safety_critic_ckpt(
            load_sc_ckpt_dir, safety_critic_net=sc_net_off,
            target_safety_critic=target_sc_net_off,
            optimizer=sc_optimizer)
          tf_agent._safety_critic_network = sc_net_off
          tf_agent._target_safety_critic_network = target_sc_net_off
          tf_agent._safety_critic_optimizer = sc_optimizer
    else:
      train_checkpointer.initialize_or_restore()
      rb_checkpointer.initialize_or_restore()
      if online_critic:
        online_rb_checkpointer.initialize_or_restore()

    if agent_class in SAFETY_AGENTS:
      sc_dir = os.path.join(root_dir, 'sc')
      safety_critic_checkpointer = common.Checkpointer(
        ckpt_dir=sc_dir,
        safety_critic=tf_agent._safety_critic_network,
        # pylint: disable=protected-access
        target_safety_critic=tf_agent._target_safety_critic_network,
        optimizer=tf_agent._safety_critic_optimizer,
        global_step=global_step)

      if not (load_root_dir and not online_critic):
        safety_critic_checkpointer.initialize_or_restore()

    agent_observers = [replay_buffer.add_batch] + train_metrics
    collect_driver = collect_driver_class(
      tf_env, collect_policy, observers=agent_observers)
    collect_driver.run = common.function_in_tf1()(collect_driver.run)

    if online_critic:
      logging.debug('online driver class: %s', online_driver_class)
      online_agent_observers = [num_episodes, num_env_steps,
                                sc_buffer.add_batch]
      online_driver = online_driver_class(
        sc_tf_env, online_collect_policy, observers=online_agent_observers,
        num_episodes=num_eval_episodes)
      online_driver.run = common.function_in_tf1()(online_driver.run)

    if eager_debug:
      tf.config.experimental_run_functions_eagerly(True)
    else:
      config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=True)
      tf.function(config_saver.after_create_session)()

    if global_step == 0:
      logging.info('Performing initial collection ...')
      init_collect_observers = agent_observers
      if agent_class in SAFETY_AGENTS:
        init_collect_observers += [sc_buffer.add_batch]
      initial_collect_driver_class(
        tf_env,
        initial_collect_policy,
        observers=init_collect_observers).run()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      if agent_class in SAFETY_AGENTS:
        last_id = sc_buffer._get_last_id()  # pylint: disable=protected-access
        logging.debug('Data saved in sc_buffer after initial collection: %d steps', last_id)

    if run_eval:
      results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='EvalMetrics',
      )
      if train_metrics_callback is not None:
        train_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    train_step = train_utils.get_train_step(tf_agent, replay_buffer, batch_size)

    if agent_class in SAFETY_AGENTS:
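      # critic_train_step performs one safety-critic (and Lagrange-multiplier)
      # update from the replay/safety buffers and records results in sc_metrics;
      # whether the critic weights actually change is gated by updating_sc.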
      critic_train_step = train_utils.get_critic_train_step(
        tf_agent, replay_buffer, sc_buffer, batch_size=batch_size,
        updating_sc=updating_sc, metrics=sc_metrics)

    if early_termination_fn is None:
      early_termination_fn = lambda: False

    loss_diverged = False
    # How many consecutive steps was loss diverged for.
    loss_divergence_counter = 0
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')

    if agent_class in SAFETY_AGENTS:
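      # The safe collect policy counts how often it resamples actions deemed
      # unsafe; mean_resample_ac averages that count over an episode and is
      # reported as resample_ac_freq in the loop below.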
      resample_counter = collect_policy._resample_counter
      mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_freq')
      sc_metrics.append(mean_resample_ac)

      if online_critic:
        logging.debug('starting safety critic pretraining')
        # don't fine-tune safety critic
        if global_step.numpy() == 0:
          for _ in range(train_sc_steps):
            sc_loss, lambda_loss = critic_train_step()
          critic_results = [('sc_loss', sc_loss.numpy()), ('lambda_loss', lambda_loss.numpy())]
          for critic_metric in sc_metrics:
            res = critic_metric.result().numpy()
            if not res.shape:
              critic_results.append((critic_metric.name, res))
            else:
              for r, thresh in zip(res, thresholds):
                name = '_'.join([critic_metric.name, str(thresh)])
                critic_results.append((name, r))
            critic_metric.reset_states()
          if train_metrics_callback:
            train_metrics_callback(collections.OrderedDict(critic_results),
                                   step=global_step.numpy())

    logging.debug('Starting main train loop...')
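    # Main loop: collect experience, run agent train steps, periodically train
    # and evaluate the safety critic, then handle logging, checkpointing, eval
    # rollouts, and optional video monitoring.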
    curr_ep = []
    global_step_val = global_step.numpy()
    while global_step_val <= num_global_steps and not early_termination_fn():
      start_time = time.time()

      # MEASURE ACTION RESAMPLING FREQUENCY
      if agent_class in SAFETY_AGENTS:
        if pretraining and global_step_val == num_global_steps // 2:
          if online_critic:
            online_collect_policy._training = True
          collect_policy._training = True
        if online_critic or collect_policy._training:
          mean_resample_ac(resample_counter.result())
          resample_counter.reset()
          if time_step is None or time_step.is_last():
            resample_ac_freq = mean_resample_ac.result()
            mean_resample_ac.reset_states()
            tf.compat.v2.summary.scalar(
              name='resample_ac_freq', data=resample_ac_freq, step=global_step)

      # RUN COLLECTION
      time_step, policy_state = collect_driver.run(
        time_step=time_step,
        policy_state=policy_state,
      )

      # get last step taken by step_driver
      traj = replay_buffer._data_table.read(replay_buffer._get_last_id() %
                                            replay_buffer._capacity)
      curr_ep.append(traj)

      if time_step.is_last():
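        # At episode end, failed episodes (task_agn_rew set on the final step)
        # are relabeled into the safety-critic buffer: either only the last
        # kstep_fail steps or, if kstep_fail is 0, the entire episode.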
        if agent_class in SAFETY_AGENTS:
          if time_step.observation['task_agn_rew']:
            if kstep_fail:
              # applies task agn rew. over last k steps
              for fail_traj in curr_ep[-kstep_fail:]:
                fail_traj.observation['task_agn_rew'] = 1.
                sc_buffer.add_batch(fail_traj)
            else:
              for fail_traj in curr_ep:
                sc_buffer.add_batch(fail_traj)
        curr_ep = []
        if agent_class == wcpg_agent.WcpgAgent:
          collect_policy._alpha = None  # reset WCPG alpha

      if (global_step_val + 1) % log_interval == 0:
        logging.debug('policy eval: %4.2f sec', time.time() - start_time)

      # PERFORMS TRAIN STEP ON ALGORITHM (OFF-POLICY)
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)

      current_step = global_step.numpy()
      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(
          collections.OrderedDict([(k, v.numpy()) for k, v in
                                   train_loss.extra._asdict().items()]),
          step=current_step)
        train_metrics_callback(
          {'train_loss': total_loss.numpy()}, step=current_step)

      # TRAIN AND/OR EVAL SAFETY CRITIC
      if agent_class in SAFETY_AGENTS and current_step % train_sc_interval == 0:
        if online_critic:
          batch_time_step = sc_tf_env.reset()

          # run online critic training collect & update
          batch_policy_state = online_collect_policy.get_initial_state(
            sc_tf_env.batch_size)
          online_driver.run(time_step=batch_time_step,
                            policy_state=batch_policy_state)
        for _ in range(train_sc_steps):
          sc_loss, lambda_loss = critic_train_step()
        # log safety_critic loss results
        critic_results = [('sc_loss', sc_loss.numpy()),
                          ('lambda_loss', lambda_loss.numpy())]
        metric_utils.log_metrics(sc_metrics)
        for critic_metric in sc_metrics:
          res = critic_metric.result().numpy()
          if not res.shape:
            critic_results.append((critic_metric.name, res))
          else:
            for r, thresh in zip(res, thresholds):
              name = '_'.join([critic_metric.name, str(thresh)])
              critic_results.append((name, r))
          critic_metric.reset_states()
        if train_metrics_callback and current_step % summary_interval == 0:
          train_metrics_callback(collections.OrderedDict(critic_results),
                                 step=current_step)

      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
              total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          logging.info('Loss diverged, critic_loss: %s, actor_loss: %s',
                       train_loss.extra.critic_loss,
                       train_loss.extra.actor_loss)
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      # LOGGING AND METRICS
      if current_step % log_interval == 0:
        metric_utils.log_metrics(train_metrics)
        logging.info('step = %d, loss = %f', current_step, total_loss)
        steps_per_sec = (current_step - timed_at_step) / time_acc
        logging.info('%4.2f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = current_step
        time_acc = 0

      train_results = []

      for metric in train_metrics[2:]:
        if isinstance(metric, (metrics.AverageEarlyFailureMetric,
                               metrics.AverageFallenMetric,
                               metrics.AverageSuccessMetric)):
          # Plot failure as a fn of return
          metric.tf_summaries(
            train_step=global_step, step_metrics=[num_env_steps, num_episodes,
                                                  return_metric])
        else:
          metric.tf_summaries(
            train_step=global_step, step_metrics=[num_env_steps, num_episodes])
        train_results.append((metric.name, metric.result().numpy()))

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(collections.OrderedDict(train_results),
                               step=global_step.numpy())

      if current_step % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=current_step)

      if current_step % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=current_step)
        if agent_class in SAFETY_AGENTS:
          safety_critic_checkpointer.save(global_step=current_step)
          if online_critic:
            online_rb_checkpointer.save(global_step=current_step)

      if rb_checkpoint_interval and current_step % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=current_step)

      if wandb and current_step % eval_interval == 0 and "Drunk" in env_name:
        misc.record_point_mass_episode(eval_tf_env, eval_policy, current_step)
        if online_critic:
          misc.record_point_mass_episode(eval_tf_env, tf_agent.safe_policy,
                                         current_step, 'safe-trajectory')

      if run_eval and current_step % eval_interval == 0:
        eval_results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='EvalMetrics',
        )
        if train_metrics_callback is not None:
          train_metrics_callback(eval_results, current_step)
        metric_utils.log_metrics(eval_metrics)

        with eval_summary_writer.as_default():
          for eval_metric in eval_metrics[2:]:
            eval_metric.tf_summaries(train_step=global_step,
                                     step_metrics=eval_metrics[:2])

      if monitor and current_step % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step, monitor_policy_state)
          action, monitor_policy_state = monitor_action.action, monitor_action.state
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug('saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                      current_step, ep_len, time.time() - monitor_start)

      global_step_val = current_step

  if early_termination_fn():
    #  Early stopped, save all checkpoints if not saved
    if global_step_val % train_checkpoint_interval != 0:
      train_checkpointer.save(global_step=global_step_val)

    if global_step_val % policy_checkpoint_interval != 0:
      policy_checkpointer.save(global_step=global_step_val)
      if agent_class in SAFETY_AGENTS:
        safety_critic_checkpointer.save(global_step=global_step_val)
        if online_critic:
          online_rb_checkpointer.save(global_step=global_step_val)

    if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval != 0:
      rb_checkpointer.save(global_step=global_step_val)

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end after the cleanup.
    raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
      total_loss, global_step.numpy()))

  return total_loss
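
# --- Usage sketch (not part of the original example) ---
# A minimal, hypothetical way to drive train_eval() above. The env id, the
# 'sqrl' key in ALGOS, env_load_fn, and the driver factories are assumptions
# for illustration only:
#
#   import functools
#   from tf_agents.drivers import dynamic_step_driver
#
#   train_eval(
#       root_dir='/tmp/sqrl_run',
#       env_load_fn=env_load_fn,                # hypothetical env loader
#       env_name='DrunkSpiderPointMassEnv-v0',  # hypothetical registered id
#       agent_class='sqrl',                     # assumed key in ALGOS
#       initial_collect_driver_class=functools.partial(
#           dynamic_step_driver.DynamicStepDriver, num_steps=1000),
#       collect_driver_class=functools.partial(
#           dynamic_step_driver.DynamicStepDriver, num_steps=1),
#       num_global_steps=100000,
#       run_eval=True,
#       num_eval_episodes=10)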