def main(_):
    """Train a recurrent PPO agent on HalfCheetah-v2, then visualize it.

    Each iteration collects 30 complete episodes with the agent's collect
    policy on ``batch_size`` parallel environments, runs one PPO update on
    everything collected, and empties the replay buffer. Every 500 train steps
    the agent's ``policy`` is saved under ``checkpoints/`` and evaluated for 30
    episodes on a separate environment. Training stops once 25,000,000
    environment steps have been taken; three episodes are then rendered with
    OpenCV.

    Args:
        _: Unused positional argument (presumably argv from absl's app.run).
    """
    # The environment plays the role of the dataset in reinforcement learning.
    train_env = tf_py_environment.TFPyEnvironment(
        ParallelPyEnvironment([lambda: suite_mujoco.load('HalfCheetah-v2')] *
                              batch_size))
    eval_env = tf_py_environment.TFPyEnvironment(
        suite_mujoco.load('HalfCheetah-v2'))

    # Agent: recurrent actor and value networks trained with PPO.
    actor_net = ActorDistributionRnnNetwork(train_env.observation_spec(),
                                            train_env.action_spec(),
                                            lstm_size=(100, 100))
    value_net = ValueRnnNetwork(train_env.observation_spec())
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
    tf_agent = ppo_agent.PPOAgent(train_env.time_step_spec(),
                                  train_env.action_spec(),
                                  optimizer=optimizer,
                                  actor_net=actor_net,
                                  value_net=value_net,
                                  normalize_observations=False,
                                  normalize_rewards=False,
                                  use_gae=True,
                                  num_epochs=25)
    tf_agent.initialize()

    # Replay buffer: holds only the episodes of the current iteration, it is
    # cleared after every train step below.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=1000000)

    saver = policy_saver.PolicySaver(tf_agent.policy)

    # Training metrics and the trajectory collector.
    train_episode_count = tf_metrics.NumberOfEpisodes()
    train_total_steps = tf_metrics.EnvironmentSteps()
    train_avg_reward = tf_metrics.AverageReturnMetric(
        batch_size=train_env.batch_size)
    train_avg_episode_len = tf_metrics.AverageEpisodeLengthMetric(
        batch_size=train_env.batch_size)
    train_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        train_env,
        tf_agent.collect_policy,  # NOTE: use PPOPolicy to collect episode
        observers=[
            replay_buffer.add_batch, train_episode_count, train_total_steps,
            train_avg_reward, train_avg_episode_len
        ],
        num_episodes=30,  # how many episodes are collected in an iteration
    )

    # Evaluation metrics and driver are built once and the metrics are reset
    # before every evaluation round.
    eval_avg_reward = tf_metrics.AverageReturnMetric(buffer_size=30)
    eval_avg_episode_len = tf_metrics.AverageEpisodeLengthMetric(
        buffer_size=30)
    eval_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        eval_env,
        tf_agent.policy,
        observers=[eval_avg_reward, eval_avg_episode_len],
        num_episodes=30,  # how many episodes are collected per evaluation
    )

    while train_total_steps.result() < 25000000:
        train_driver.run()
        trajectories = replay_buffer.gather_all()
        loss, _ = tf_agent.train(experience=trajectories)
        # Clear collected episodes right after training.
        replay_buffer.clear()

        step = tf_agent.train_step_counter.numpy()
        if step % 50 == 0:
            print('step = {0}: loss = {1}'.format(step, loss))
        if step % 500 == 0:
            # Save a checkpoint of the current policy.
            saver.save('checkpoints/policy_%d' % step)
            # Evaluate the updated policy and report the evaluation metrics.
            eval_avg_reward.reset()
            eval_avg_episode_len.reset()
            eval_driver.run()
            print(
                'step = {0}: Average Return = {1} Average Episode Length = {2}'
                .format(step, eval_avg_reward.result(),
                        eval_avg_episode_len.result()))

    # Play HalfCheetah three more times with the trained policy and visualize.
    import cv2
    for _ in range(3):
        time_step = eval_env.reset()
        policy_state = tf_agent.policy.get_initial_state(eval_env.batch_size)
        while not time_step.is_last():
            # NOTE: use greedy policy to test
            action_step = tf_agent.policy.action(time_step, policy_state)
            time_step = eval_env.step(action_step.action)
            policy_state = action_step.state
            frame = eval_env.pyenv.envs[0].render()
            # NOTE(review): render() presumably returns an RGB array, while
            # OpenCV displays BGR; convert so colors are shown correctly.
            cv2.imshow('halfcheetah', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            cv2.waitKey(25)
    cv2.destroyAllWindows()
# Example #2
# 0
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        train_sequence_length=20,
        critic_learning_rate=3e-4,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=_DEFAULT_REWARD_SCALE,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control.

    Builds recurrent actor/critic networks and a SAC agent, restores any
    existing train and replay-buffer checkpoints from ``root_dir``, seeds the
    replay buffer with episodes from a random policy when it is empty, and then
    alternates between collecting episodes with the agent's collect policy and
    training on sequences of ``train_sequence_length + 1`` steps sampled from
    the buffer. For every environment step collected in an iteration,
    ``train_steps_per_iteration`` train steps are run; logging, evaluation and
    checkpointing are keyed on the global (train) step.

    Args:
        root_dir: Directory for summaries and checkpoints (``~`` is expanded).
        env_name: Domain name passed to ``suite_dm_control.load``.
        task_name: Task name passed to ``suite_dm_control.load``.
        observations_whitelist: Single observation key kept by
            ``FlattenObservationsWrapper``, or ``None`` for no wrapper.
        eval_env_name: Domain used for evaluation; defaults to ``env_name``.
        num_iterations: Number of collect/train iterations.
        num_parallel_environments: Number of training environments; more than
            one uses a ``ParallelPyEnvironment``.
        initial_collect_episodes: Episodes collected with a
            ``RandomTFPolicy`` when no environment steps or replay frames exist.
        collect_episodes_per_iteration: Episodes collected per iteration.
        reward_scale_factor: Reward scale; when left at the default sentinel it
            is 20.0 for humanoid domains and 5.0 otherwise.
        eval_metrics_callback: Optional ``callback(results, env_steps)`` called
            after each evaluation.
        The remaining arguments are forwarded to the networks, the
        ``SacAgent``, the replay-buffer dataset, or control intervals.
    """
    root_dir = os.path.expanduser(root_dir)

    if reward_scale_factor == _DEFAULT_REWARD_SCALE:
        # Use value recommended by https://arxiv.org/abs/1801.01290
        # DM control domain names are lowercase (e.g. 'humanoid'), so the
        # comparison is case-insensitive.
        if env_name.lower().startswith('humanoid'):
            reward_scale_factor = 20.0
        else:
            reward_scale_factor = 5.0

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=normal_projection_net)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer. tf_env.batch_size already counts the parallel
        # environments, so it is used as-is: the buffer's batch dimension must
        # match the batches handed to add_batch by the drivers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data (skipped when restored from checkpoint).
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d '
                'episodes with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Dataset yields sequences of train_sequence_length + 1 steps, since
        # the recurrent networks are trained on whole sub-trajectories.
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            # One training round per environment step just collected.
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                # Add only the time elapsed since the previous measurement;
                # re-adding time since the start of the iteration on every pass
                # would count the collection phase many times over.
                now = time.time()
                time_acc += now - start_time
                start_time = now

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        # env_steps is a metric object; its value comes from
                        # result().
                        eval_metrics_callback(results,
                                              env_steps.result().numpy())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
# Example #3
# 0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=100000,
        exploration_noise_std=0.1,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_update_period=2,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for TD3.

    Seeds a uniform replay buffer with ``initial_collect_steps`` steps from the
    agent's collect policy, then for ``num_iterations`` iterations collects
    ``collect_steps_per_iteration`` steps and runs ``train_steps_per_iteration``
    train steps on 2-step transitions sampled from the buffer. Train summaries
    go to ``root_dir/train`` and evaluation summaries to ``root_dir/eval``.

    Args:
        root_dir: Output directory for summaries (``~`` is expanded).
        env_name: Environment name passed to ``suite_mujoco.load``.
        num_iterations: Number of collect/train iterations.
        eval_metrics_callback: Optional ``callback(results, global_step)``
            called after each evaluation.
        The remaining arguments are forwarded to the networks, the
        ``Td3Agent``, the replay buffer/dataset, or control intervals.

    Returns:
        The value returned by the last ``tf_agent.train`` call, or ``None`` if
        no train step ran (``num_iterations`` or
        ``train_steps_per_iteration`` is 0).
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_mujoco.load(env_name))

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            exploration_noise_std=exploration_noise_std,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            actor_update_period=actor_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Initial data comes from the (exploration-noised) collect policy and
        # is not counted in the train metrics.
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps '
            'with the collect policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Bound up front so the log and the return value are defined even when
        # no train step runs.
        train_loss = None
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # The global step only changes inside train_step(), so read it
            # once per iteration.
            step = global_step.numpy()
            if step % log_interval == 0:
                if train_loss is not None:
                    logging.info('step = %d, loss = %f', step,
                                 train_loss.loss)
                steps_per_sec = (step - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = step
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if step % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, step)
                metric_utils.log_metrics(eval_metrics)

        return train_loss
# Example #4
# 0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Seeds a uniform replay buffer with ``initial_collect_steps`` steps from a
    ``RandomTFPolicy``, then for ``num_iterations`` iterations collects
    ``collect_steps_per_iteration`` steps with the agent's collect policy and
    runs ``train_steps_per_iteration`` train steps on 2-step transitions
    sampled from the buffer. Uses the TF1 ``tf.contrib.summary`` API; train
    summaries go to ``root_dir/train`` and evaluation summaries to
    ``root_dir/eval``.

    Args:
        root_dir: Output directory for summaries (``~`` is expanded).
        env_name: Environment name passed to ``suite_gym.load``.
        num_iterations: Number of collect/train iterations.
        eval_metrics_callback: Optional ``callback(metrics, global_step)``
            called after each evaluation.
        The remaining arguments are forwarded to the Q-network, the
        ``DqnAgent``, the replay buffer/dataset, or control intervals.

    Returns:
        The value returned by the last ``tf_agent.train`` call, or ``None`` if
        no train step ran.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):

        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        trajectory_spec = trajectory.from_transition(
            time_step=tf_env.time_step_spec(),
            action_step=policy_step.PolicyStep(action=tf_env.action_spec()),
            next_time_step=tf_env.time_step_spec())
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=trajectory_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        tf_agent.initialize()
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # NOTE(review): called as methods here; newer TF-Agents releases
        # expose policy/collect_policy as properties — confirm against the
        # pinned TF-Agents version.
        eval_policy = tf_agent.policy()
        collect_policy = tf_agent.collect_policy()

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        global_step = tf.compat.v1.train.get_or_create_global_step()

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(metrics, global_step.numpy())

        time_step = None
        policy_state = ()

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        # Bound up front so the log and the return value are defined even when
        # no train step runs.
        train_loss = None
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience,
                                            train_step_counter=global_step)
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                if train_loss is not None:
                    # train() returns a loss tuple; log its scalar `loss`
                    # field, which is what '%f' can format.
                    logging.info('step = %d, loss = %f', global_step.numpy(),
                                 train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.contrib.summary.scalar(name='global_steps/sec',
                                          tensor=steps_per_sec)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                metrics = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(metrics, global_step.numpy())
        return train_loss
def train_eval(
        root_dir,
        env_name='MultiGrid-Empty-5x5-v0',
        env_load_fn=multiagent_gym_suite.load,
        random_seed=0,
        # Architecture params
        actor_fc_layers=(64, 64),
        value_fc_layers=(64, 64),
        lstm_size=(64, ),
        conv_filters=64,
        conv_kernel=3,
        direction_fc=5,
        entropy_regularization=0.,
        use_attention_networks=False,
        # Specialized agents
        inactive_agent_ids=tuple(),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=5,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=2,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=2,
        eval_interval=5,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        log_interval=10,
        summary_interval=10,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None,
        reinit_checkpoint_dir=None,
        debug=True):
    """A simple train and eval for PPO.

    Trains a `multiagent_ppo.MultiagentPPO` agent on a multi-agent gym
    environment. Each iteration collects whole episodes from parallel
    environments, trains on everything gathered (`gather_all`), then clears
    the replay buffer. Evaluation, logging, summaries and checkpointing
    happen at their configured global-step intervals. A final evaluation
    runs before returning.

    Args:
        root_dir: Output directory. `train/`, `eval/` and
            `policy_saved_model/` paths are derived from it.
        env_name: Environment name passed to `env_load_fn`.
        env_load_fn: Callable that loads the environment. The returned env
            must expose `n_agents`.
        random_seed: Seed for TF and the environments. When None, no TF seed
            is set and every environment is created with `seed=None`.
        actor_fc_layers, value_fc_layers, lstm_size, conv_filters,
            conv_kernel, direction_fc, entropy_regularization, num_epochs,
            learning_rate, debug_summaries, summarize_grads_and_vars:
            Forwarded to `MultiagentPPO`.
        use_attention_networks: Forwarded to `MultiagentPPO`. Also wraps
            environments with `utils.LSTMStateWrapper` and collects with
            `utils.StateTFDriver`.
        inactive_agent_ids: Agent ids that are fixed (not trained) and whose
            per-agent loss is not logged.
        num_environment_steps: Training stops once the environment-step
            metric reaches this value.
        collect_episodes_per_iteration: Episodes collected per iteration.
        num_parallel_environments: Number of parallel training environments.
        replay_buffer_capacity: Replay buffer length per environment.
        num_eval_episodes: Episodes per evaluation.
        eval_interval: Evaluate when the global step is a multiple of this.
        train_checkpoint_interval: Save a train checkpoint when the global
            step is a multiple of this.
        policy_checkpoint_interval: Save a policy checkpoint and a
            SavedModel when the global step is a multiple of this.
        log_interval: Log losses/throughput when the global step is a
            multiple of this.
        summary_interval: Record summaries at this global-step interval.
        summaries_flush_secs: Flush period of the summary writers.
        use_tf_functions: Wrap training in `common.function`. This is forced
            off when `debug` is True.
        eval_metrics_callback: Optional `callback(results, step)` called
            after each evaluation.
        reinit_checkpoint_dir: If given, agents are restored from this
            directory. The agent is then rebuilt with the first
            `n_agents - 1` agents restored and marked non-learning. The last
            agent comes from the fresh `MultiagentPPO`. In this mode the
            train checkpoint in `train_dir` is not restored.
        debug: Adds progress logging and disables tf.functions.

    Raises:
        AttributeError: If `root_dir` is None.
    """
    tf.compat.v1.enable_v2_behavior()

    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    if debug:
        logging.info('In debug mode, turning tf_functions off')
        use_tf_functions = False

    for a in inactive_agent_ids:
        logging.info('Fixing and not training agent %d', a)

    # Load the multiagent gym environment once to determine the agent count.
    gym_env = env_load_fn(env_name)
    n_agents = gym_env.n_agents

    # Set up logging
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        multiagent_metrics.AverageReturnMetric(n_agents,
                                               buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        logging.info('Creating %d environments...', num_parallel_environments)
        wrappers = []
        if use_attention_networks:
            wrappers = [
                lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size)
            ]

        def _env_seed(index):
            """Seed for parallel env `index`, or None when unseeded."""
            # random_seed=None is a supported value (see the check above),
            # so it must not be used in arithmetic.
            if random_seed is None:
                return None
            return random_seed * 1234 + index

        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(env_name,
                        gym_kwargs=dict(seed=random_seed),
                        gym_env_wrappers=wrappers))
        # pylint: disable=g-complex-comprehension
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                functools.partial(env_load_fn,
                                  environment_name=env_name,
                                  gym_env_wrappers=wrappers,
                                  gym_kwargs=dict(seed=_env_seed(i)))
                for i in range(num_parallel_environments)
            ]))

        logging.info('Preparing to train...')
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            multiagent_metrics.AverageReturnMetric(
                n_agents, batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments)
        ]

        def _make_agent(**extra_kwargs):
            """Build a MultiagentPPO with this run's shared hyperparameters."""
            return multiagent_ppo.MultiagentPPO(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                n_agents=n_agents,
                learning_rate=learning_rate,
                actor_fc_layers=actor_fc_layers,
                value_fc_layers=value_fc_layers,
                lstm_size=lstm_size,
                conv_filters=conv_filters,
                conv_kernel=conv_kernel,
                direction_fc=direction_fc,
                entropy_regularization=entropy_regularization,
                num_epochs=num_epochs,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                inactive_agent_ids=inactive_agent_ids,
                use_attention_networks=use_attention_networks,
                **extra_kwargs)

        logging.info('Creating agent...')
        tf_agent = _make_agent()
        tf_agent.initialize()
        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        logging.info('Allocating replay buffer ...')
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        logging.info('RB capacity: %i', replay_buffer.capacity)

        # If reinit_checkpoint_dir is provided, every agent is first restored
        # from it. The first n_agents - 1 agents are then copied (via a
        # temporary checkpoint) into a rebuilt MultiagentPPO in which they are
        # non-learning. The last agent is the rebuilt agent's own.
        # Otherwise, all agents are restored from train_dir (if present).
        if reinit_checkpoint_dir:
            reinit_checkpointer = common.Checkpointer(
                ckpt_dir=reinit_checkpoint_dir,
                agent=tf_agent,
            )
            reinit_checkpointer.initialize_or_restore()
            temp_dir = os.path.join(train_dir, 'tmp')
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir,
                agent=tf_agent.agents[:-1],
            )
            agent_checkpointer.save(global_step=0)
            tf_agent = _make_agent(
                non_learning_agents=list(range(n_agents - 1)))
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir, agent=tf_agent.agents[:-1])
            agent_checkpointer.initialize_or_restore()
            tf.io.gfile.rmtree(temp_dir)
            eval_policy = tf_agent.policy
            collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=multiagent_metrics.MultiagentMetricsGroup(
                train_metrics, 'train_metrics'))
        if not reinit_checkpoint_dir:
            train_checkpointer.initialize_or_restore()
        logging.info('Successfully initialized train checkpointer')

        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)
        logging.info('Successfully initialized policy saver.')

        if use_attention_networks:
            logging.info('Using StateTFDriver')
            collect_driver = utils.StateTFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)
        else:
            logging.info('Using TFDriver')
            collect_driver = tf_driver.TFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        def _evaluate():
            """Run eval episodes, notify the callback and log eval metrics."""
            results = multiagent_metrics.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
                use_function=use_tf_functions,
                use_attention_networks=use_attention_networks)
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            multiagent_metrics.log_metrics(eval_metrics)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0

        # Save operative config as late as possible to include used configurables.
        if global_step.numpy() == 0:
            config_filename = os.path.join(
                train_dir,
                'operative_config-{}.gin'.format(global_step.numpy()))
            with tf.io.gfile.GFile(config_filename, 'wb') as f:
                f.write(gin.operative_config_str())

        total_episodes = 0
        logging.info('Commencing train loop!')
        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()

            # Evaluation
            if global_step_val % eval_interval == 0:
                if debug:
                    logging.info('Performing evaluation at step %d',
                                 global_step_val)
                _evaluate()

            # Collect data
            if debug:
                logging.info('Collecting at step %d', global_step_val)
            start_time = time.time()
            time_step = tf_env.reset()
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)
            if use_attention_networks:
                # Attention networks require previous policy state to compute attention
                # weights.
                time_step.observation['policy_state'] = (
                    policy_state['actor_network_state'][0],
                    policy_state['actor_network_state'][1])
            collect_driver.run(time_step, policy_state)
            collect_time += time.time() - start_time

            total_episodes += collect_episodes_per_iteration
            if debug:
                logging.info('Have collected a total of %d episodes',
                             total_episodes)

            # Train on everything collected this iteration, then discard it.
            if debug:
                logging.info('Training at step %d', global_step_val)
            start_time = time.time()
            total_loss, extra_loss = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    logging.info(
                        'Loss diverged for too many timesteps, breaking...')
                    break
            else:
                loss_divergence_counter = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, total loss = %f', global_step_val,
                             total_loss)
                for a in range(n_agents):
                    if a not in inactive_agent_ids:
                        logging.info('Loss for agent %d = %f', a,
                                     extra_loss[a].loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            # Checkpoint independently of the logging cadence. Otherwise a
            # checkpoint interval that is not a multiple of log_interval
            # never triggers a save.
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)

        # One final eval before exiting.
        _evaluate()
def train_eval(
    root_dir,
    environment_name="broken_reacher",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    real_initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    real_collect_interval=10,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    classifier_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    train_on_real=False,
    delta_r_warmup=0,
    random_seed=0,
    checkpoint_dir=None,
):
    """A simple train and eval for SAC.

    Trains a `darc_agent.DarcAgent` from two replay buffers:
    - One is filled every iteration from the training environment. This is
      the "sim" env, or the "real" env when `train_on_real` is set.
    - One is filled from the "real" environment on iterations where the
      global step is a multiple of `real_collect_interval` and at least
      `delta_r_warmup`.

    Every train step draws a batch of 2-step samples from each buffer. The
    greedy policy is evaluated on both the "sim" and "real" environments.

    Args:
        root_dir: Output directory. `train/`, `eval/` and `operative.gin`
            are written beneath it.
        environment_name: "broken_reacher", "half_cheetah_obstacle",
            "broken_joint_<name>" or "falling_<name>".
        num_iterations: Number of collect/train iterations.
        actor_fc_layers, critic_obs_fc_layers, critic_action_fc_layers,
            critic_joint_fc_layers: Network layer sizes.
        initial_collect_steps: Random-policy steps collected into the
            training-env buffer when it starts empty.
        real_initial_collect_steps: Random-policy steps collected into the
            real-env buffer when the buffers start empty.
        collect_steps_per_iteration: Steps per collect driver run.
        real_collect_interval: Collect from the real env when the global
            step is a multiple of this.
        replay_buffer_capacity: Capacity of each replay buffer.
        target_update_tau, target_update_period, td_errors_loss_fn, gamma,
            reward_scale_factor, gradient_clipping, debug_summaries,
            summarize_grads_and_vars: Forwarded to `DarcAgent`.
        train_steps_per_iteration: Train steps per iteration.
        batch_size: Sample batch size drawn from each buffer.
        actor_learning_rate, critic_learning_rate, classifier_learning_rate,
            alpha_learning_rate: Adam learning rates.
        use_tf_functions: Wrap drivers and training in `common.function`.
        num_eval_episodes: Episodes per evaluation. Also the metric buffer
            size of the training return/length metrics.
        eval_interval, train_checkpoint_interval, policy_checkpoint_interval,
            rb_checkpoint_interval, log_interval: Global-step intervals.
        summary_interval: Record summaries at this global-step interval.
        summaries_flush_secs: Flush period of the summary writers.
        train_on_real: Use the "real" env as the per-iteration training env.
        delta_r_warmup: Global step before which no real-env data is
            collected (other than the initial collection).
        random_seed: Seed for NumPy and TensorFlow.
        checkpoint_dir: If given, restore the train checkpoint from the
            latest checkpoint in this directory instead of `train_dir`.

    Returns:
        The value returned by the last `tf_agent.train` call (its `.loss` is
        logged), or None if no training step ran.

    Raises:
        NotImplementedError: If `environment_name` is not recognised.
        ValueError: If `checkpoint_dir` contains no checkpoint.
    """
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    if environment_name == "broken_reacher":
        get_env_fn = darc_envs.get_broken_reacher_env
    elif environment_name == "half_cheetah_obstacle":
        get_env_fn = darc_envs.get_half_cheetah_direction_env
    elif environment_name.startswith("broken_joint"):
        base_name = environment_name.split("broken_joint_")[1]
        get_env_fn = functools.partial(darc_envs.get_broken_joint_env,
                                       env_name=base_name)
    elif environment_name.startswith("falling"):
        base_name = environment_name.split("falling_")[1]
        get_env_fn = functools.partial(darc_envs.get_falling_env,
                                       env_name=base_name)
    else:
        raise NotImplementedError("Unknown environment: %s" % environment_name)

    eval_name_list = ["sim", "real"]
    eval_env_list = [get_env_fn(mode) for mode in eval_name_list]

    eval_metrics_list = []
    for name in eval_name_list:
        eval_metrics_list.append([
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           name="AverageReturn_%s" % name),
        ])

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env_real = get_env_fn("real")
        if train_on_real:
            tf_env = get_env_fn("real")
        else:
            tf_env = get_env_fn("sim")

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork),
        )
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer="glorot_uniform",
            last_kernel_initializer="glorot_uniform",
        )

        classifier = classifiers.build_classifier(observation_spec,
                                                  action_spec)

        tf_agent = darc_agent.DarcAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            classifier=classifier,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=classifier_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffers: one for the training env, one for real.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        replay_observer = [replay_buffer.add_batch]

        real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        real_replay_observer = [real_replay_buffer.add_batch]

        sim_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnSim",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthSim",
            ),
        ]
        real_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnReal",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthReal",
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(
                sim_train_metrics + real_train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "replay_buffer"),
            max_to_keep=1,
            replay_buffer=(replay_buffer, real_replay_buffer),
        )

        if checkpoint_dir is not None:
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            # Explicit check instead of `assert`, which is stripped under -O.
            if checkpoint_path is None:
                raise ValueError(
                    "No checkpoint found in checkpoint_dir: %s" %
                    checkpoint_dir)
            # Restore from an external directory while still saving to
            # train_dir. This reaches into Checkpointer internals.
            train_checkpointer._load_status = train_checkpointer._checkpoint.restore(  # pylint: disable=protected-access
                checkpoint_path)
            train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
        else:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Random-policy seeding is only needed when no replay-buffer
        # checkpoint was restored.
        needs_initial_collect = bool(replay_buffer.num_frames() == 0)

        if needs_initial_collect:
            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + sim_train_metrics,
                num_steps=initial_collect_steps,
            )
            real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env_real,
                initial_collect_policy,
                observers=real_replay_observer + real_train_metrics,
                num_steps=real_initial_collect_steps,
            )

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + sim_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        real_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env_real,
            collect_policy,
            observers=real_replay_observer + real_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        if use_tf_functions:
            # The initial-collect drivers exist only when the buffers were
            # empty. Wrapping them unconditionally raised UnboundLocalError
            # when resuming from a replay-buffer checkpoint.
            if needs_initial_collect:
                initial_collect_driver.run = common.function(
                    initial_collect_driver.run)
                real_initial_collect_driver.run = common.function(
                    real_initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            real_collect_driver.run = common.function(real_collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if needs_initial_collect:
            logging.info(
                "Initializing replay buffers by collecting experience for %d "
                "training-env steps and %d real-env steps with a random "
                "policy.",
                initial_collect_steps,
                real_initial_collect_steps,
            )
            initial_collect_driver.run()
            real_initial_collect_driver.run()

        def _evaluate_all():
            """Evaluate `eval_policy` on every eval env and log the metrics."""
            for eval_name, eval_env, eval_metrics in zip(
                    eval_name_list, eval_env_list, eval_metrics_list):
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix="Metrics-%s" % eval_name,
                )
                metric_utils.log_metrics(eval_metrics)

        _evaluate_all()

        time_step = None
        real_time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Keep only 2-step samples whose first step is not an episode
            # boundary. Presumably such samples straddle two episodes.
            return ~trajectories.is_boundary()[0]

        dataset = (replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
        real_dataset = (real_replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))

        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)
        real_iterator = iter(real_dataset)

        def train_step():
            experience, _ = next(iterator)
            real_experience, _ = next(real_iterator)
            return tf_agent.train(experience, real_experience=real_experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Defined up front so the function returns None (rather than raising
        # UnboundLocalError) when no training step runs.
        train_loss = None
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            assert not policy_state  # We expect policy_state == ().
            if (global_step.numpy() % real_collect_interval == 0
                    and global_step.numpy() >= delta_r_warmup):
                real_time_step, policy_state = real_collect_driver.run(
                    time_step=real_time_step,
                    policy_state=policy_state,
                )

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                loss_value = (train_loss.loss
                              if train_loss is not None else float("nan"))
                logging.info("step = %d, loss = %f", global_step_val,
                             loss_value)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info("%.3f steps/sec", steps_per_sec)
                tf.compat.v2.summary.scalar(name="global_steps_per_sec",
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in sim_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=sim_train_metrics[:2])
            for train_metric in real_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=real_train_metrics[:2])

            if global_step_val % eval_interval == 0:
                _evaluate_all()

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
示例#7
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC.

    TF1-style version: the agent, replay buffer, collect/train ops and
    checkpointers are first built as graph ops. They are then executed inside
    a `tf.compat.v1.Session` for `num_iterations` collect+train iterations.

    Args:
      root_dir: Base directory. Training summaries and checkpoints are written
        under `root_dir/train` (policy and replay-buffer checkpoints in its
        `policy` and `replay_buffer` subdirectories). Eval summaries are
        written under `root_dir/eval`.
      env_name: Name passed to `suite_mujoco.load` for both the training and
        the evaluation environment.
      num_iterations: Number of iterations of the main collect+train loop.
      actor_fc_layers: Fully-connected layer sizes of the actor network.
      critic_obs_fc_layers: Critic layer sizes applied to the observation.
      critic_action_fc_layers: Critic layer sizes applied to the action.
      critic_joint_fc_layers: Critic layer sizes applied after observation and
        action are joined.
      initial_collect_steps: Environment steps collected with the agent's
        collect policy before training. This only happens when the global step
        read after restoring is 0.
      collect_steps_per_iteration: Environment steps collected per iteration.
      replay_buffer_capacity: `max_length` of the uniform replay buffer.
      target_update_tau: Forwarded to `SacAgent`.
      target_update_period: Forwarded to `SacAgent`.
      train_steps_per_iteration: Number of train-op runs per iteration.
        Values < 1 leave `total_loss` unbound, so the logging branch would
        fail.
      batch_size: Number of transitions per training batch.
      actor_learning_rate: Adam learning rate for the actor.
      critic_learning_rate: Adam learning rate for the critic.
      alpha_learning_rate: Adam learning rate for alpha.
      td_errors_loss_fn: Forwarded to `SacAgent`.
      gamma: Forwarded to `SacAgent`.
      reward_scale_factor: Forwarded to `SacAgent`.
      gradient_clipping: Forwarded to `SacAgent`.
      num_eval_episodes: Episodes per evaluation. Also the buffer size of the
        eval metrics.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Save the agent/metrics checkpoint when the
        global step is a multiple of this.
      policy_checkpoint_interval: Save the policy checkpoint when the global
        step is a multiple of this.
      rb_checkpoint_interval: Save the replay-buffer checkpoint when the
        global step is a multiple of this.
      log_interval: Log loss and steps/sec when the global step is a multiple
        of this.
      summary_interval: Summaries created inside the main `record_if` scope
        are only recorded when the global step is a multiple of this.
      summaries_flush_secs: Flush period of both summary writers.
      debug_summaries: Forwarded to `SacAgent`.
      summarize_grads_and_vars: Forwarded to `SacAgent`.
      eval_metrics_callback: Optional callable forwarded as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # Training summaries go to the train writer unless another writer is
    # explicitly made the default (as is done for eval below).
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    # Python-side metrics: evaluation runs on the Python env below.
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    # Graph op, run after every evaluation to push eval summaries to disk.
    eval_summary_flush_op = eval_summary_writer.flush()

    # Created before the `record_if` scope because its predicate reads it.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_py_env = suite_mujoco.load(env_name)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # `normal_projection_net` is defined elsewhere in the original file.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        # The agent shares `global_step` as its train step counter.
        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        # batch_size=1 matches the single (non-parallel) training env.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Greedy version of the agent policy, callable from Python for eval.
        eval_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        # The first two entries are used below as the step metrics for
        # summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        collect_policy = tf_agent.collect_policy

        # Graph mode: `.run()` only builds the op; it is executed later via
        # sess.run.
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        # Keeps a sampled 2-step window only if its first step is not a
        # boundary step.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        # Samples 5x batch_size windows, then unbatches, filters and
        # re-batches. Presumably the oversampling leaves enough valid windows
        # after filtering to fill a batch.
        dataset = replay_buffer.as_dataset(
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(
                    batch_size * 5)
        # Initializable iterator: its initializer is run inside the session.
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval metric summaries go to the eval writer and are not gated by
        # summary_interval.
        # NOTE(review): the ops returned here are not stored; confirm how
        # they get executed in graph mode.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        # Only the most recent replay-buffer checkpoint is kept.
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            # Initialize or restore agent/metrics and replay buffer from
            # their checkpoint directories (per the Checkpointer API).
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # May be non-zero if a training checkpoint was restored above.
            global_step_val = sess.run(global_step)

            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            # Pre-bound callables avoid re-resolving fetches on every call.
            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            # steps/sec is computed in Python, so it is fed into the summary
            # through a placeholder.
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                # `total_loss` is only bound if this loop runs at least once.
                for _ in range(train_steps_per_iteration):
                    total_loss, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_flush_op)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
示例#8
0
def train_eval(
        root_dir,
        env_name='gym_orbital_system:solarsystem-v0',
        eval_env_name=None,
        env_load_fn=suite_gym.load,
        # The SAC paper reported:
        # Hopper and Cartpole results up to 1000000 iters,
        # Humanoid results up to 10000000 iters,
        # Other mujoco tasks up to 3000000 iters.
        num_iterations=3000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
        # HalfCheetah and Ant take 10000 initial collection steps.
        # Other mujoco tasks take 1000.
        # Different choices roughly keep the initial episodes about the same.
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=50000,
        policy_checkpoint_interval=50000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC.

    Eager (TF2) version. Collect and train calls are optionally wrapped with
    `common.function`. Training continues until the global step (the agent's
    train step counter) reaches `num_iterations`. An evaluation runs once
    before the loop and then every `eval_interval` steps.

    Args:
      root_dir: Base directory. Training data is written under
        `root_dir/td3_eval/train` and eval summaries under
        `root_dir/td3_eval/eval`.
      env_name: Training environment name passed to `env_load_fn`.
      eval_env_name: Evaluation environment name; defaults to `env_name`.
      env_load_fn: Callable `env_load_fn(name)` returning a Python env.
      num_iterations: Stop once the global step reaches this value.
      actor_fc_layers: Fully-connected layer sizes of the actor network.
      critic_obs_fc_layers: Critic layer sizes applied to the observation.
      critic_action_fc_layers: Critic layer sizes applied to the action.
      critic_joint_fc_layers: Critic layer sizes after the join.
      initial_collect_steps: Random-policy steps collected when the replay
        buffer is empty (i.e. not restored from a checkpoint).
      collect_steps_per_iteration: Environment steps collected per iteration.
      replay_buffer_capacity: `max_length` of the uniform replay buffer.
      target_update_tau: Forwarded to `SacAgent`.
      target_update_period: Forwarded to `SacAgent`.
      train_steps_per_iteration: Train steps per iteration; must be >= 1.
      batch_size: Number of transitions per training batch.
      actor_learning_rate: Adam learning rate for the actor.
      critic_learning_rate: Adam learning rate for the critic.
      alpha_learning_rate: Adam learning rate for alpha.
      td_errors_loss_fn: Forwarded to `SacAgent`.
      gamma: Forwarded to `SacAgent`.
      reward_scale_factor: Forwarded to `SacAgent`.
      gradient_clipping: Forwarded to `SacAgent`.
      use_tf_functions: Wrap driver/train calls with `common.function`.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Agent/metrics checkpoint interval (steps).
      policy_checkpoint_interval: Greedy-policy checkpoint interval (steps).
      rb_checkpoint_interval: Replay-buffer checkpoint interval (steps).
      log_interval: Log loss and steps/sec when the global step is a multiple
        of this. The steps/sec summary is written at the same time.
      summary_interval: Other summaries are only recorded when the global
        step is a multiple of this.
      summaries_flush_secs: Flush period of both summary writers.
      debug_summaries: Forwarded to `SacAgent`.
      summarize_grads_and_vars: Forwarded to `SacAgent`.
      eval_metrics_callback: Optional `callback(results, step)` invoked after
        every evaluation.

    Returns:
      The loss info returned by the last training step, or None if the
      training loop did not run at all. That happens when resuming from a
      checkpoint whose global step is already >= `num_iterations`.

    Raises:
      ValueError: If `train_steps_per_iteration` is less than 1.
    """
    if train_steps_per_iteration < 1:
        # Without a train step `train_loss` is never produced for logging.
        # NOTE(review): the global step is expected to advance only through
        # the agent's train call, so the loop below could also never finish.
        raise ValueError('train_steps_per_iteration must be >= 1, got %r.' %
                         (train_steps_per_iteration,))

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'td3_eval/train')
    eval_dir = os.path.join(root_dir, 'td3_eval/eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    # Created before the `record_if` scope because its predicate reads it.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        # The agent shares `global_step` as its train step counter.
        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # A restored replay buffer is not refilled with random data.
        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        # Driver state carried across iterations so episodes continue
        # between collect calls.
        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        # Keeps a 2-step window only if its first step is not a boundary.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Bound up front so `return train_loss` is valid even when the loop
        # body never executes (e.g. resuming an already finished run).
        train_loss = None
        global_step_val = global_step.numpy()
        while global_step_val < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                # Record unconditionally so the throughput summary is not
                # dropped when log_interval and summary_interval differ.
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
示例#9
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=None,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=25000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-3,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=500,
    policy_checkpoint_interval=500,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """A simple train and eval for PPO.

  Trains a `PPOClipAgent` on `num_parallel_environments` copies of the
  environment until the `EnvironmentSteps` metric reaches
  `num_environment_steps`. Each iteration does three things:
  - collects `collect_episodes_per_iteration` episodes into the replay buffer;
  - trains on everything gathered;
  - clears the buffer.

  Interval checks (eval, log, checkpoints) use the global step read at the
  start of the iteration, before training.

  Args:
    root_dir: Base directory. Training summaries and checkpoints go under
      `root_dir/train`, eval summaries under `root_dir/eval`, and exported
      policies under `root_dir/policy_saved_model`.
    env_name: Environment name passed to `env_load_fn`.
    env_load_fn: Callable `env_load_fn(env_name)` returning a Python env.
    random_seed: If not None, passed to `tf.compat.v1.set_random_seed`.
    actor_fc_layers: FC layer sizes of the actor network (input FC layers
      when `use_rnns` is True).
    value_fc_layers: FC layer sizes of the value network (input FC layers
      when `use_rnns` is True).
    use_rnns: Use the RNN variants of the actor and value networks.
    num_environment_steps: Stop once the environment-steps metric reaches
      this value.
    collect_episodes_per_iteration: Episodes collected per iteration.
    num_parallel_environments: Number of environments stepped in parallel;
      also the replay buffer batch size and the train metrics' batch size.
    replay_buffer_capacity: Replay buffer `max_length` per environment.
    num_epochs: Forwarded to `PPOClipAgent`.
    learning_rate: Adam learning rate.
    num_eval_episodes: Episodes per evaluation.
    eval_interval: Evaluate when the global step is a multiple of this.
    train_checkpoint_interval: Save the agent/metrics checkpoint when the
      global step is a multiple of this.
    policy_checkpoint_interval: Save the policy checkpoint and a SavedModel
      when the global step is a multiple of this.
    log_interval: Log loss and throughput when the global step is a multiple
      of this.
    summary_interval: Summaries are recorded only when the global step is a
      multiple of this. The steps/sec summary is the exception: it is always
      recorded when logged.
    summaries_flush_secs: Flush period of both summary writers.
    use_tf_functions: Wrap collect/train calls with `common.function`.
    debug_summaries: Forwarded to `PPOClipAgent`.
    summarize_grads_and_vars: Forwarded to `PPOClipAgent`.

  Raises:
    AttributeError: If `root_dir` is None.
  """
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  # Created before the `record_if` scope because its predicate reads it.
  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)
    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
          activation_fn=tf.keras.activations.tanh)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(),
          fc_layer_params=value_fc_layers,
          activation_fn=tf.keras.activations.tanh)

    # The agent shares `global_step` as its train step counter.
    tf_agent = ppo_clip_agent.PPOClipAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        entropy_regularization=0.0,
        importance_ratio_clipping=0.2,
        normalize_observations=False,
        normalize_rewards=False,
        use_gae=True,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Drives the loop termination condition below.
    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]

    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(
            batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(
        eval_policy, train_step=global_step)

    train_checkpointer.initialize_or_restore()

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    def train_step():
      # Train on everything currently in the buffer.
      trajectories = replay_buffer.gather_all()
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      # TODO(b/123828980): Enable once the cause for slowdown was identified.
      collect_driver.run = common.function(collect_driver.run, autograph=False)
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    while environment_steps_metric.result() < num_environment_steps:
      # Read before training: the eval, log and checkpoint decisions in this
      # iteration all use this value.
      global_step_val = global_step.numpy()
      if global_step_val % eval_interval == 0:
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )

      start_time = time.time()
      collect_driver.run()
      collect_time += time.time() - start_time

      start_time = time.time()
      total_loss, _ = train_step()
      # Only this iteration's data is trained on; drop it before the next
      # collect.
      replay_buffer.clear()
      train_time += time.time() - start_time

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss)
        steps_per_sec = (
            (global_step_val - timed_at_step) / (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                     train_time)
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec, step=global_step)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

      # Checkpoints run on their own schedule rather than only on logging
      # steps, so their intervals need not be multiples of log_interval.
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
        saved_model_path = os.path.join(
            saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
        saved_model.save(saved_model_path)

    # One final eval before exiting.
    metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
示例#10
0
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=1000,
    actor_fc_layers=(100,),
    value_net_fc_layers=(100,),
    use_value_network=False,
    use_tf_functions=True,
    # Params for collect
    collect_episodes_per_iteration=2,
    replay_buffer_capacity=2000,
    # Params for train
    learning_rate=1e-3,
    gamma=0.9,
    gradient_clipping=None,
    normalize_returns=True,
    value_estimation_loss_coef=0.2,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=100,
    # Params for checkpoints, summaries, and logging
    log_interval=100,
    summary_interval=100,
    summaries_flush_secs=1,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for Reinforce.

  Trains a REINFORCE agent on a Gym environment loaded with `suite_gym`.
  Each of the `num_iterations` iterations collects
  `collect_episodes_per_iteration` episodes with the collect policy, trains
  once on everything gathered in the replay buffer, then clears the buffer.
  The agent is evaluated once before training and then whenever the global
  step is a multiple of `eval_interval`.

  Args:
    root_dir: Base directory (`~` is expanded). Train summaries are written
      to `root_dir/train`, eval summaries to `root_dir/eval`.
    env_name: Name passed to `suite_gym.load` for both train and eval envs.
    num_iterations: Number of collect/train iterations to run.
    actor_fc_layers: Fully connected layer sizes of the actor network.
    value_net_fc_layers: Fully connected layer sizes of the value network;
      only used when `use_value_network` is True.
    use_value_network: Whether to build a value network and pass it to the
      agent (otherwise `value_network=None`).
    use_tf_functions: Whether to wrap collection and training in
      `common.function`.
    collect_episodes_per_iteration: Episodes collected per iteration.
    replay_buffer_capacity: `max_length` of the replay buffer.
    learning_rate: Learning rate of the Adam optimizer.
    gamma: Passed to the agent as `gamma`.
    gradient_clipping: Passed to the agent as `gradient_clipping`.
    normalize_returns: Passed to the agent as `normalize_returns`.
    value_estimation_loss_coef: Passed to the agent as
      `value_estimation_loss_coef`.
    num_eval_episodes: Episodes per evaluation; also the buffer size of the
      eval metrics.
    eval_interval: Evaluate whenever the global step is a multiple of this.
    log_interval: Log loss and steps/sec whenever the global step is a
      multiple of this.
    summary_interval: Summaries are only recorded on global steps that are a
      multiple of this.
    summaries_flush_secs: Flush interval of both summary writers, in seconds.
    debug_summaries: Passed to the agent.
    summarize_grads_and_vars: Passed to the agent.
    eval_metrics_callback: Optional callable, invoked as
      `eval_metrics_callback(metrics, step)` after every evaluation.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  # The train writer becomes the default, so un-targeted summaries (agent
  # summaries, train metrics, steps/sec) go to `train_dir`.
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  # The eval writer is passed explicitly to `eager_compute` below.
  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # `global_step` is assigned further down. The callable is evaluated lazily
  # by `record_if`, so the name only has to exist by the time a summary is
  # written.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers)

    # `value_net` is only bound on this path; the agent constructor below
    # reads it only when `use_value_network` is True.
    if use_value_network:
      value_net = value_network.ValueNetwork(
          tf_env.time_step_spec().observation,
          fc_layer_params=value_net_fc_layers)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    tf_agent = reinforce_agent.ReinforceAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        value_network=value_net if use_value_network else None,
        value_estimation_loss_coef=value_estimation_loss_coef,
        gamma=gamma,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        normalize_returns=normalize_returns,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    tf_agent.initialize()

    # The first two metrics (episode count, env steps) are also used as the
    # step metrics for the train summaries below.
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    # Train on everything currently stored in the replay buffer.
    def train_step():
      experience = replay_buffer.gather_all()
      return tf_agent.train(experience)

    if use_tf_functions:
      # To speed up collect use TF function.
      collect_driver.run = common.function(collect_driver.run)
      # To speed up train use TF function.
      tf_agent.train = common.function(tf_agent.train)
      train_step = common.function(train_step)

    # Compute evaluation metrics.
    metrics = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    # TODO(b/126590894): Move this functionality into eager_compute_summaries
    if eval_metrics_callback is not None:
      eval_metrics_callback(metrics, global_step.numpy())

    # The driver's time step and policy state are threaded through
    # successive `collect_driver.run` calls.
    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      total_loss = train_step()
      # Each iteration trains only on freshly collected episodes.
      replay_buffer.clear()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()
      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      # NOTE(review): there is no final evaluation after the loop; the last
      # eval happens only on a step that is a multiple of eval_interval.
      if global_step_val % eval_interval == 0:
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        # TODO(b/126590894): Move this functionality into
        # eager_compute_summaries.
        if eval_metrics_callback is not None:
          eval_metrics_callback(metrics, global_step_val)
def train_eval_doom_simple(
        videos_dir,
        # Params for collect
        num_environment_steps=30000000,
        collect_episodes_per_iteration=32,
        num_parallel_environments=32,
        replay_buffer_capacity=301,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=4e-4,
        # Params for eval
        eval_interval=500,
        num_video_episodes=10,
        # Params for summaries and logging
        log_interval=50):
    """A simple train and eval for PPO on Doom.

    Runs until the environment-step metric reaches ``num_environment_steps``.
    Each iteration does three things:

    * collects ``collect_episodes_per_iteration`` episodes from
      ``num_parallel_environments`` parallel ``DoomEnvironment`` instances
      into a replay buffer;
    * trains the PPO agent on everything gathered;
    * clears the buffer.

    An evaluation video is recorded every ``eval_interval`` global steps and
    once more after training ends.

    Args:
        videos_dir: Directory where evaluation videos are written
            (``video_<global_step>.mp4``). Created if missing.
        num_environment_steps: Stop once this many environment steps have
            been collected.
        collect_episodes_per_iteration: Episodes collected per iteration.
        num_parallel_environments: Number of parallel Doom environments; also
            the replay buffer batch size.
        replay_buffer_capacity: Replay buffer ``max_length`` (per
            environment).
        num_epochs: Passed to the PPO agent as ``num_epochs``.
        learning_rate: Learning rate of the Adam optimizer.
        eval_interval: Record a video whenever the global step is a multiple
            of this.
        num_video_episodes: Episodes recorded per video.
        log_interval: Log loss and throughput whenever the global step is a
            multiple of this.
    """
    # exist_ok avoids a race between the existence check and creation.
    os.makedirs(videos_dir, exist_ok=True)

    eval_py_env = DoomEnvironment()
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [DoomEnvironment] * num_parallel_environments))

    actor_net, value_net = create_networks(tf_env.observation_spec(),
                                           tf_env.action_spec())

    global_step = tf.compat.v1.train.get_or_create_global_step()
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate,
                                                 epsilon=1e-5)

    tf_agent = ppo_agent.PPOAgent(tf_env.time_step_spec(),
                                  tf_env.action_spec(),
                                  optimizer,
                                  actor_net,
                                  value_net,
                                  num_epochs=num_epochs,
                                  train_step_counter=global_step,
                                  discount_factor=0.99,
                                  gradient_clipping=0.5,
                                  entropy_regularization=1e-2,
                                  importance_ratio_clipping=0.2,
                                  use_gae=True,
                                  use_td_lambda_return=True)
    tf_agent.initialize()

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)
    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        tf_agent.collect_policy,
        observers=[replay_buffer.add_batch] + step_metrics,
        num_episodes=collect_episodes_per_iteration)

    def train_step():
        """Train the agent on every trajectory currently in the buffer."""
        trajectories = replay_buffer.gather_all()
        return tf_agent.train(experience=trajectories)

    def evaluate():
        """Record a video of the current greedy policy."""
        # Read the counter here instead of closing over the loop-local
        # `global_step_val`. That name is unbound if the training loop never
        # runs, which would make the final evaluate() raise NameError.
        step = global_step.numpy()
        create_video(eval_py_env,
                     eval_tf_env,
                     tf_agent.policy,
                     num_episodes=num_video_episodes,
                     video_filename=os.path.join(
                         videos_dir, "video_%d.mp4" % step))

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    while environment_steps_metric.result() < num_environment_steps:

        start_time = time.time()
        collect_driver.run()
        collect_time += time.time() - start_time

        start_time = time.time()
        total_loss, _ = train_step()
        # PPO is on-policy: discard the data once it has been trained on.
        replay_buffer.clear()
        train_time += time.time() - start_time

        global_step_val = global_step.numpy()

        if global_step_val % log_interval == 0:
            logging.info('step = %d, loss = %f', global_step_val, total_loss)
            steps_per_sec = ((global_step_val - timed_at_step) /
                             (collect_time + train_time))
            logging.info('%.3f steps/sec', steps_per_sec)
            logging.info('collect_time = %.3f, train_time = %.3f',
                         collect_time, train_time)

            timed_at_step = global_step_val
            collect_time = 0
            train_time = 0

        if global_step_val % eval_interval == 0:
            evaluate()

    evaluate()
def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    agent_name='sac',  # agent's name
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor_type='non-hierarchical',  # model net
    input_names=['camera', 'lidar'],  # names for inputs
    mask_names=['birdeye'],  # names for masks
    preprocessing_combiner=tf.keras.layers.Add(),  # takes a flat list of tensors and combines them
    actor_lstm_size=(40,),  # lstm size for actor
    critic_lstm_size=(40,),  # lstm size for critic
    actor_output_fc_layers=(100,),  # lstm output
    critic_output_fc_layers=(100,),  # lstm output
    epsilon_greedy=0.1,  # exploration parameter for DQN
    q_learning_rate=1e-3,  # q learning rate for DQN
    ou_stddev=0.2,  # exploration parameter for DDPG
    ou_damping=0.15,  # exploration parameter for DDPG
    dqda_clipping=None,  # for DDPG
    exploration_noise_std=0.1,  # exploration parameter for td3
    actor_update_period=2,  # for td3
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    initial_model_train_steps=100000,  # initial model training
    batch_size=256,
    model_batch_size=32,  # model training batch size
    sequence_length=4,  # number of timesteps to train model
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,  # learning rate for model training
    td_errors_loss_fn=tf.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1):  # action repeat passed to load_carla_env
  """Train and evaluate an agent on the Carla environment.

  `agent_name` selects the agent: 'latent_sac', 'dqn', 'ddpg', 'td3' or
  'sac'. Any other value raises NotImplementedError. All outputs go under
  `root_dir/env_name/experiment_name`:

  * summaries are written to that directory;
  * checkpoints go to its 'train', 'policy' and 'replay_buffer'
    sub-directories.

  Existing train and replay-buffer checkpoints are restored on start.

  If the restored environment-step metric or the replay buffer is empty,
  `initial_collect_steps` random-policy steps are collected first. After
  that, each iteration trains the agent as follows:

  * For 'latent_sac', the first `initial_model_train_steps` iterations only
    train the latent model.
  * Otherwise, each iteration collects `collect_steps_per_iteration` steps
    with the collect policy, then runs `train_steps_per_iteration` training
    steps.

  Training batches are sampled as sequences of `sequence_length + 1` steps.
  Logging, evaluation and checkpointing are keyed on the global step
  counter, which every agent shares as its `train_step_counter`.

  Args:
    root_dir: Base output directory (`~` is expanded).
    experiment_name: Sub-directory name for this run.
    env_name: Used in the output path.
    agent_name: One of 'latent_sac', 'dqn', 'ddpg', 'td3', 'sac'.
    num_iterations: Number of training-loop iterations.
    gpu_allow_growth: Enable memory growth on every visible GPU.
    gpu_memory_limit: If set, configure each GPU as a virtual device with
      this memory limit.
    action_repeat: Passed to `load_carla_env` and used to compute the
      summary fps.
    The remaining arguments configure the networks, agents, replay buffer,
    evaluation, summaries and checkpoint intervals, as named.

  Raises:
    NotImplementedError: For an unknown `agent_name` or
      `model_network_ctor_type`.
  """
  # NOTE(review): `input_names`, `mask_names` and `preprocessing_combiner`
  # defaults are created once at definition time. The lists are never
  # mutated here. The shared Keras `Add()` layer, however, is reused if this
  # function is called more than once in a process.

  # Setup GPU
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpu_allow_growth:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  if gpu_memory_limit:
    for gpu in gpus:
      tf.config.experimental.set_virtual_device_configuration(
          gpu,
          [tf.config.experimental.VirtualDeviceConfiguration(
              memory_limit=gpu_memory_limit)])

  # Get the output directory: root_dir/env_name/experiment_name
  root_dir = os.path.expanduser(root_dir)
  root_dir = os.path.join(root_dir, env_name, experiment_name)

  # A single summary writer is used for both train and eval summaries.
  summary_writer = tf.summary.create_file_writer(
      root_dir, flush_millis=summaries_flush_secs * 1000)
  summary_writer.set_as_default()

  # Eval metrics
  eval_metrics = [
      tf_metrics.AverageReturnMetric(
        name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(
        name='AverageEpisodeLengthEvalPolicy',
        buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()

  # Only record summaries on steps that are a multiple of summary_interval.
  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Create Carla environment. latent_sac additionally observes the mask
    # channels; dqn requests the discrete-action variant.
    # NOTE(review): the environment name is hard-coded to 'carla-v0'; the
    # `env_name` argument only affects the output directory.
    if agent_name == 'latent_sac':
      py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names+mask_names, action_repeat=action_repeat)
    elif agent_name == 'dqn':
      py_env, eval_py_env = load_carla_env(env_name='carla-v0', discrete=True, obs_channels=input_names, action_repeat=action_repeat)
    else:
      py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names, action_repeat=action_repeat)

    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
    # Frames per second of the effective (action-repeated) control rate.
    fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

    # Specs
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    ## Make tf agent
    if agent_name == 'latent_sac':
      # Get model network for latent sac
      if model_network_ctor_type == 'hierarchical':
        model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical
      elif model_network_ctor_type == 'non-hierarchical':
        model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical
      else:
        raise NotImplementedError
      model_net = model_network_ctor(input_names, input_names+mask_names)

      # Get the latent spec: the inner SAC agent acts on a flat latent vector
      latent_size = model_net.latent_size
      latent_observation_spec = tensor_spec.TensorSpec((latent_size,), dtype=tf.float32)
      latent_time_step_spec = ts.time_step_spec(observation_spec=latent_observation_spec)

      # Get actor and critic net
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          latent_observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=normal_projection_net)
      critic_net = critic_network.CriticNetwork(
          (latent_observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers)

      # Build the inner SAC agent based on latent space
      inner_agent = sac_agent.SacAgent(
          latent_time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step)
      inner_agent.initialize()

      # Build the latent sac agent
      tf_agent = latent_sac_agent.LatentSACAgent(
          time_step_spec,
          action_spec,
          inner_agent=inner_agent,
          model_network=model_net,
          model_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=model_learning_rate),
          model_batch_size=model_batch_size,
          num_images_per_summary=num_images_per_summary,
          sequence_length=sequence_length,
          gradient_clipping=gradient_clipping,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          fps=fps)

    else:
      # Set up preprocessing layers for dictionary observation inputs
      preprocessing_layers = collections.OrderedDict()
      for name in input_names:
        preprocessing_layers[name] = Preprocessing_Layer(32,256)
      # A combiner is only needed when there is more than one input.
      if len(input_names) < 2:
        preprocessing_combiner = None

      if agent_name == 'dqn':
        q_rnn_net = q_rnn_network.QRnnNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            input_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers)

        tf_agent = dqn_agent.DqnAgent(
            time_step_spec,
            action_spec,
            q_network=q_rnn_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=1,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=q_learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

      elif agent_name == 'ddpg' or agent_name == 'td3':
        actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
          observation_spec,
          action_spec,
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          input_fc_layer_params=actor_fc_layers,
          lstm_size=actor_lstm_size,
          output_fc_layer_params=actor_output_fc_layers)

        critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
          (observation_spec, action_spec),
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          lstm_size=critic_lstm_size,
          output_fc_layer_params=critic_output_fc_layers)

        if agent_name == 'ddpg':
          tf_agent = ddpg_agent.DdpgAgent(
              time_step_spec,
              action_spec,
              actor_network=actor_rnn_net,
              critic_network=critic_rnn_net,
              actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=actor_learning_rate),
              critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=critic_learning_rate),
              ou_stddev=ou_stddev,
              ou_damping=ou_damping,
              target_update_tau=target_update_tau,
              target_update_period=target_update_period,
              dqda_clipping=dqda_clipping,
              td_errors_loss_fn=None,
              gamma=gamma,
              reward_scale_factor=reward_scale_factor,
              gradient_clipping=gradient_clipping,
              debug_summaries=debug_summaries,
              summarize_grads_and_vars=summarize_grads_and_vars,
              # Without this the agent keeps a private counter and
              # `global_step` stays at 0. Every interval check below would
              # then fire on every iteration.
              train_step_counter=global_step)
        elif agent_name == 'td3':
          tf_agent = td3_agent.Td3Agent(
              time_step_spec,
              action_spec,
              actor_network=actor_rnn_net,
              critic_network=critic_rnn_net,
              actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=actor_learning_rate),
              critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=critic_learning_rate),
              exploration_noise_std=exploration_noise_std,
              target_update_tau=target_update_tau,
              target_update_period=target_update_period,
              actor_update_period=actor_update_period,
              dqda_clipping=dqda_clipping,
              td_errors_loss_fn=None,
              gamma=gamma,
              reward_scale_factor=reward_scale_factor,
              gradient_clipping=gradient_clipping,
              debug_summaries=debug_summaries,
              summarize_grads_and_vars=summarize_grads_and_vars,
              train_step_counter=global_step)

      elif agent_name == 'sac':
        actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=normal_projection_net)

        critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
          (observation_spec, action_spec),
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          lstm_size=critic_lstm_size,
          output_fc_layer_params=critic_output_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_distribution_rnn_net,
            critic_network=critic_rnn_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.math.squared_difference,  # make critic loss dimension compatible
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

      else:
        raise NotImplementedError

    tf_agent.initialize()

    # Get replay buffer
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,  # No parallel environments
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    # Train metrics
    env_steps = tf_metrics.EnvironmentSteps()
    average_return = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes,
        batch_size=tf_env.batch_size)
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        env_steps,
        average_return,
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]

    # Get policies
    eval_policy = tf_agent.policy
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        time_step_spec, action_spec)
    collect_policy = tf_agent.collect_policy

    # Checkpointers
    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'train'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        max_to_keep=2)
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step,
        max_to_keep=2)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    # Resume agent, train metrics and replay buffer from a previous run.
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    # Collect driver
    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    # Optimize the performance by using tf functions
    initial_collect_driver.run = common.function(initial_collect_driver.run)
    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data (skipped when resuming with data restored).
    if (env_steps.result() == 0 or replay_buffer.num_frames() == 0):
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    # Initial evaluation before training.
    if agent_name == 'latent_sac':
      compute_summaries(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        train_step=global_step,
        summary_writer=summary_writer,
        num_episodes=1,
        num_episodes_to_render=1,
        model_net=model_net,
        fps=10,
        image_keys=input_names+mask_names)
    else:
      metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=1,
          train_step=env_steps.result(),
          summary_writer=summary_writer,
          summary_prefix='Eval',
      )
      metric_utils.log_metrics(eval_metrics)

    # Dataset generates trajectories with shape [Bxslx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    # Get train step
    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)
    train_step = common.function(train_step)

    if agent_name == 'latent_sac':
      def train_model_step():
        experience, _ = next(iterator)
        return tf_agent.train_model(experience)
      train_model_step = common.function(train_model_step)

    # Training initializations
    time_step = None
    time_acc = 0
    env_steps_before = env_steps.result().numpy()

    # Start training
    for iteration in range(num_iterations):
      start_time = time.time()

      if agent_name == 'latent_sac' and iteration < initial_model_train_steps:
        # Pre-train only the latent model before any RL updates.
        train_model_step()
      else:
        # Run collect; the env time step carries over between iterations.
        time_step, _ = collect_driver.run(time_step=time_step)

        # Train an iteration
        for _ in range(train_steps_per_iteration):
          train_step()

      time_acc += time.time() - start_time

      # Log training information
      if global_step.numpy() % log_interval == 0:
        logging.info('env steps = %d, average return = %f', env_steps.result(),
                     average_return.result())
        env_steps_per_sec = (env_steps.result().numpy() -
                             env_steps_before) / time_acc
        logging.info('%.3f env steps/sec', env_steps_per_sec)
        tf.summary.scalar(
            name='env_steps_per_sec',
            data=env_steps_per_sec,
            step=env_steps.result())
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

      # Get training metrics
      for train_metric in train_metrics:
        train_metric.tf_summaries(train_step=env_steps.result())

      # Evaluation
      if global_step.numpy() % eval_interval == 0:
        # Log evaluation metrics
        if agent_name == 'latent_sac':
          compute_summaries(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            train_step=global_step,
            summary_writer=summary_writer,
            num_episodes=num_eval_episodes,
            num_episodes_to_render=num_images_per_summary,
            model_net=model_net,
            fps=10,
            image_keys=input_names+mask_names)
        else:
          metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=env_steps.result(),
              summary_writer=summary_writer,
              summary_prefix='Eval',
          )
          metric_utils.log_metrics(eval_metrics)

      # Save checkpoints
      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
示例#13
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=100000,
        exploration_noise_std=0.1,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_update_period=2,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for TD3.

    Builds a TF1-style graph (Td3Agent, uniform replay buffer, collect and
    train ops) for a MuJoCo environment and runs it in a
    `tf.compat.v1.Session`. Train and replay-buffer checkpoints under
    `root_dir/train` are restored if present. Then `initial_collect_steps`
    warm-up steps are collected with the agent's collect policy. Each of the
    `num_iterations` iterations runs one collect op
    (`collect_steps_per_iteration` env steps) followed by
    `train_steps_per_iteration` train steps on batches sampled from the
    replay buffer.

    Args:
      root_dir: Base directory (`~` is expanded). Train summaries and
        checkpoints go under `root_dir/train`, eval summaries under
        `root_dir/eval`.
      env_name: Name passed to `suite_mujoco.load` for both the train and
        eval environments.
      num_iterations: Number of collect/train iterations to run.
      actor_fc_layers, critic_obs_fc_layers, critic_action_fc_layers,
        critic_joint_fc_layers: Layer sizes forwarded to the actor/critic
        network constructors.
      initial_collect_steps: Env steps collected once before training.
      collect_steps_per_iteration: Env steps collected per iteration.
      replay_buffer_capacity: Replay buffer `max_length`.
      exploration_noise_std, target_update_tau, target_update_period,
        actor_update_period, dqda_clipping, td_errors_loss_fn, gamma,
        reward_scale_factor, gradient_clipping, debug_summaries,
        summarize_grads_and_vars: Forwarded to `td3_agent.Td3Agent`.
      train_steps_per_iteration: Train steps run per iteration.
      batch_size: Number of sequences sampled per train step.
      actor_learning_rate, critic_learning_rate: Adam learning rates.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluation period, in global steps.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Checkpoint periods, in global steps.
      log_interval: Logging period, in global steps.
      summary_interval: Train summaries are recorded only on global steps
        that are multiples of this.
      summaries_flush_secs: Summary writer flush period in seconds.
      eval_metrics_callback: Forwarded as `callback` to
        `metric_utils.compute_summaries` on every evaluation.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # Train summaries use the default writer. Eval summaries are written
    # explicitly under eval_summary_writer.as_default() below.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    # Python-side metrics, filled by metric_utils.compute_summaries on the
    # Python eval environment.
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Summary ops built inside this block record only on global steps that
    # are multiples of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_py_env = suite_mujoco.load(env_name)

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        # The critic receives (observation, action) pairs.
        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            exploration_noise_std=exploration_noise_std,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            actor_update_period=actor_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Python wrapper so the TF greedy policy can drive eval_py_env.
        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy
        # Warm-up collection op; run once in the session before training.
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Per-iteration collection op.
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Each element is batch_size sequences of 2 consecutive steps
        # ([B x 2 x ...]), i.e. one transition per sequence.
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()

        train_fn = common.function(tf_agent.train)
        train_op = train_fn(experience=trajectories)

        # Checkpoints the agent, global step and train metrics together.
        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        # Only one replay-buffer checkpoint is kept (max_to_keep=1).
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        # Train-metric summary ops, fetched together with train_op below.
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval-metric summaries go to the eval writer and are always recorded.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(b/126239733): Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            # Evaluate once before the training loop starts.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            # Pre-bound callables for the ops run every iteration.
            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_ops, global_step])

            # steps/sec is measured in Python and fed into this summary op.
            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                # NOTE(review): if train_steps_per_iteration is 0 this inner
                # loop never runs, so loss_info_value is unbound and the log
                # below raises NameError.
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
def train_eval(
        root_dir,
        env_name='SocialBot-GroceryGround-v0',
        env_load_fn=suite_socialbot.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(192, 64),
        value_fc_layers=(192, 64),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=8,
        num_parallel_environments=8,
        replay_buffer_capacity=2001,  # Per-environment
        # Params for train
        num_epochs=16,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=500,
        # Params for summaries and logging
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for GroceryGround.

    Trains a PPO agent on `num_parallel_environments` parallel copies of
    `env_name`. Each iteration does three things:

    * collects `collect_episodes_per_iteration` episodes into a replay buffer;
    * trains once on everything gathered;
    * clears the buffer.

    Training continues while the EnvironmentSteps metric is below
    `num_environment_steps` (checked before each iteration). The greedy policy
    is evaluated on a separate environment at the start of any iteration whose
    global step is a multiple of `eval_interval`, and once more after training
    finishes.

    Args:
      root_dir: Base directory (`~` is expanded). Train summaries go under
        `root_dir/train`, eval summaries under `root_dir/eval`.
      env_name: Environment name passed to `env_load_fn`.
      env_load_fn: Callable returning a Python environment for a name.
      random_seed: Graph-level seed set with `tf.compat.v1.set_random_seed`.
      actor_fc_layers: Actor layer sizes. These are the input layers when
        `use_rnns` is True.
      value_fc_layers: Value-network layer sizes. These are the input layers
        when `use_rnns` is True.
      use_rnns: Use recurrent actor/value networks instead of feed-forward.
      num_environment_steps: Env-step budget for training.
      collect_episodes_per_iteration: Episodes collected per iteration.
      num_parallel_environments: Number of parallel training environments.
      replay_buffer_capacity: Replay buffer `max_length` per environment.
      num_epochs: Forwarded to `ppo_agent.PPOAgent`.
      learning_rate: Adam learning rate.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluation period, in global steps.
      log_interval: Logging period, in global steps.
      summary_interval: Train summaries are recorded only on global steps
        that are multiples of this.
      summaries_flush_secs: Summary writer flush period in seconds.
      use_tf_functions: Wrap collection and training in `common.function`.
      debug_summaries: Forwarded to `ppo_agent.PPOAgent`.
      summarize_grads_and_vars: Forwarded to `ppo_agent.PPOAgent`.
    """
    # Set summary writer and eval metrics
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Train summaries are recorded only when global_step is a multiple of
    # summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create envs and optimizer
        tf.compat.v1.set_random_seed(random_seed)
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)
        # Create actor and value network
        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        # Create ppo agent
        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Create metrics, replay_buffer and collect_driver
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]
        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)
        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)

        def _evaluate():
            """Run num_eval_episodes with the greedy policy; write eval summaries."""
            metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        # Evaluate and train
        # NOTE(review): eval/log use exact modulus checks on the global step.
        # If one train call advances it by more than 1 (e.g. once per epoch),
        # they fire only when it lands exactly on a multiple. Confirm.
        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                _evaluate()

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            # On-policy update: train on everything just collected, then
            # drop it.
            start_time = time.time()
            trajectories = replay_buffer.gather_all()
            total_loss, _ = tf_agent.train(experience=trajectories)
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %s, train_time = %s',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # The in-loop evaluation runs before each training pass, so the
        # policy from the last iteration would otherwise never be evaluated.
        # Record unconditionally (as for steps_per_sec above): the final step
        # need not be a multiple of summary_interval.
        with tf.compat.v2.summary.record_if(True):
            _evaluate()
# Example #15
# 0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Uses a recurrent Q-network when `train_sequence_length > 1` and a
    feed-forward one otherwise. The steps are:

    1. Restore train and replay-buffer checkpoints under `root_dir/train`,
       if present.
    2. Collect `initial_collect_steps` steps with a random policy.
    3. Run `num_iterations` iterations. Each collects
       `collect_steps_per_iteration` steps with the epsilon-greedy collect
       policy, then runs `train_steps_per_iteration` train steps on batches
       sampled from the replay buffer.

    Args:
      root_dir: Base directory (`~` is expanded). Train summaries and
        checkpoints go under `root_dir/train`, eval summaries under
        `root_dir/eval`.
      env_name: Name passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations.
      train_sequence_length: Sequence length for the recurrent network. A
        value of 1 selects the feed-forward network.
      fc_layer_params: Feed-forward Q-network layer sizes.
      input_fc_layer_params, lstm_size, output_fc_layer_params: Recurrent
        Q-network layout.
      initial_collect_steps: Random-policy warm-up steps.
      collect_steps_per_iteration: Env steps collected per iteration.
      epsilon_greedy, n_step_update, target_update_tau, target_update_period,
        gamma, reward_scale_factor, gradient_clipping, debug_summaries,
        summarize_grads_and_vars: Forwarded to `dqn_agent.DqnAgent`.
      replay_buffer_capacity: Replay buffer `max_length`.
      train_steps_per_iteration: Train steps per iteration.
      batch_size: Sequences sampled per train step.
      learning_rate: Adam learning rate.
      use_tf_functions: Wrap collection and training in `common.function`.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluation period, in global steps.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Checkpoint periods, in global steps.
      log_interval: Logging period, in global steps.
      summary_interval: Train summaries are recorded only on global steps
        that are multiples of this.
      summaries_flush_secs: Summary writer flush period in seconds.
      eval_metrics_callback: Optional callable invoked as
        `callback(results, global_step)` after every evaluation.

    Returns:
      The value returned by the last `tf_agent.train` call, or None if no
      train step ran (e.g. `num_iterations` or `train_steps_per_iteration`
      is 0).

    Raises:
      NotImplementedError: If both `train_sequence_length` and
        `n_step_update` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Train summaries are recorded only when global_step is a multiple of
    # summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        action_spec = tf_env.action_spec()
        num_actions = action_spec.maximum - action_spec.minimum + 1

        if train_sequence_length > 1:
            q_net = create_recurrent_network(input_fc_layer_params, lstm_size,
                                             output_fc_layer_params,
                                             num_actions)
        else:
            q_net = create_feedforward_network(fc_layer_params, num_actions)
            # Feed-forward case: sample n_step_update + 1 consecutive steps
            # per item (see the dataset below) for n-step targets.
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        def _evaluate():
            """Run eval episodes, write eval summaries, notify callback, log."""
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)

        # Baseline evaluation before training.
        _evaluate()

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Stays None when no train step runs, so the return below and the
        # loss log never touch an unbound name.
        train_loss = None
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # Nothing below advances the global step, so read it once.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                if train_loss is not None:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                _evaluate()
        return train_loss
# Example #16
# 0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        # TODO(kbanoop): rename to policy_fc_layers.
        actor_fc_layers=(100, ),
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gradient_clipping=None,
        normalize_returns=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        rb_checkpoint_interval=200,
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for Reinforce.

    Builds a TF1 graph (using `tf.contrib.summary`) and runs it in a
    `tf.Session`. Each of the `num_iterations` iterations:

    * collects `collect_episodes_per_iteration` full episodes with the
      collect policy;
    * runs one train op on everything in the replay buffer;
    * clears the buffer.

    Args:
      root_dir: Base directory (`~` is expanded). Train summaries and
        checkpoints go under `root_dir/train`, eval summaries under
        `root_dir/eval`.
      env_name: Name passed to `suite_gym.load` for the train and eval envs.
      num_iterations: Number of collect/train iterations.
      actor_fc_layers: Actor network layer sizes.
      collect_episodes_per_iteration: Episodes collected per iteration.
      replay_buffer_capacity: Replay buffer `max_length`.
      learning_rate: Adam learning rate.
      gradient_clipping, normalize_returns, debug_summaries,
        summarize_grads_and_vars: Forwarded to `ReinforceAgent`.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluation period, in global steps.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Checkpoint periods, in global steps.
      log_interval: Logging period, in global steps.
      summary_interval: Passed to
        `tf.contrib.summary.record_summaries_every_n_global_steps`.
      summaries_flush_secs: Summary writer flush period in seconds.
      eval_metrics_callback: Forwarded as `callback` to
        `metric_utils.compute_summaries` on every evaluation.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # TODO(kbanoop): Figure out if it is possible to avoid the with block.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):

        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        # TODO(kbanoop): Handle distributions without gin.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Python wrapper so the TF policy can drive the Python eval env.
        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()
        collect_policy = tf_agent.collect_policy()

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Train on everything collected this iteration; the buffer is cleared
        # afterwards by clear_rb_op.
        experience = replay_buffer.gather_all()
        train_op = tf_agent.train(experience, train_step_counter=global_step)
        clear_rb_op = replay_buffer.clear()

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        summary_op = tf.contrib.summary.all_summary_ops()

        # Eval-metric summaries go to the eval writer and are always recorded.
        with eval_summary_writer.as_default(), \
             tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)

            # Compute evaluation metrics.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])
            clear_rb_call = sess.make_callable(clear_rb_op)

            # steps/sec is measured in Python and fed into this summary op.
            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                total_loss, _, global_step_val = train_step_call()
                clear_rb_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    # Lazy %-args, consistent with the call above.
                    tf.logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
# Example #17
# 0
def train_eval(
        root_dir,
        env_name='MaskedCartPole-v0',
        num_iterations=100000,
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),
        train_sequence_length=10,
        # Params for collect
        initial_collect_steps=50,
        collect_episodes_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=10,
        batch_size=128,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=100,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN with a recurrent (QRnnNetwork) Q-net.

    The training graph is built once and then driven from a
    `tf.compat.v1.Session`. Each iteration collects
    `collect_episodes_per_iteration` episodes with the agent's collect policy
    and runs `train_steps_per_iteration` train steps. Each train step uses
    sequences of `train_sequence_length + 1` steps sampled from the replay
    buffer.

    Args:
      root_dir: Base output directory (`~` is expanded). Checkpoints and train
        summaries go to `<root_dir>/train`. Eval summaries go to
        `<root_dir>/eval`.
      env_name: Gym environment id passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations to run.
      input_fc_layer_params, lstm_size, output_fc_layer_params: Layer sizes
        forwarded to `q_rnn_network.QRnnNetwork`.
      train_sequence_length: Length of the sampled training sequences; one
        extra step is sampled so transitions can be formed.
      initial_collect_steps: Passed as `num_episodes` to the initial
        random-policy DynamicEpisodeDriver, so it counts *episodes* despite
        its name.
      collect_episodes_per_iteration: Episodes collected per iteration.
      epsilon_greedy, target_update_tau, target_update_period, gamma,
        reward_scale_factor, gradient_clipping, debug_summaries,
        summarize_grads_and_vars: Forwarded to `dqn_agent.DqnAgent`.
      replay_buffer_capacity: `max_length` of the uniform replay buffer.
      batch_size: Number of sequences sampled per train step.
      learning_rate: Adam learning rate.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Save the corresponding checkpoint when the
        global step is a multiple of the interval.
      log_interval: Log loss and steps/sec when the global step is a
        multiple of this.
      summary_interval: Train summaries are recorded only when the global
        step is a multiple of this.
      summaries_flush_secs: Flush period of the summary file writers.
      eval_metrics_callback: Optional callable forwarded as `callback` to
        `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        q_net = q_rnn_network.QRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=input_fc_layer_params,
            lstm_size=lstm_size,
            output_fc_layer_params=output_fc_layer_params)

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        # NOTE: this driver counts episodes, so `initial_collect_steps` is
        # the number of random-policy *episodes* collected up front.
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        experience, _ = iterator.get_next()
        loss_info = common.function(tf_agent.train)(experience=experience)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=train_metrics[:2])

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            logging.info('Collecting initial experience.')
            sess.run(initial_collect_op)

            # Compute evaluation metrics.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(loss_info)
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # Use the tf.compat.v2 summary API, like the file writers and
            # record_if above, rather than tf.contrib.summary (TF 1.x only).
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps/sec',
                data=steps_per_second_ph,
                step=global_step)

            # Stays None when train_steps_per_iteration == 0; the loss log
            # below must not assume a train step has already run.
            loss_info_value = None
            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    if loss_info_value is not None:
                        logging.info('step = %d, loss = %f', global_step_val,
                                     loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        log=True,
                        callback=eval_metrics_callback,
                    )
示例#18
0
def train_eval(
    root_dir,
    env_name="HalfCheetah-v2",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=100000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=500000000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    relabel_type=None,
    num_future_states=4,
    max_episode_steps=100,
    random_seed=0,
    eval_task_list=None,
    constant_task=None,  # Whether to train on a single task
    clip_critic=None,
):
    """A simple train and eval for SAC with optional goal relabelling.

    Experience is collected one step at a time by `utils.DataCollector` into a
    `GoalRelabellingReplayBuffer`. After every collect step the agent runs
    `train_steps_per_iteration` train steps on batches from that buffer.
    Evaluation uses a greedy wrapper of the agent policy. It runs once for the
    default task (`None`) and once for each task in `eval_task_list`.

    Args:
      root_dir: Base output directory (`~` is expanded). Checkpoints and train
        summaries go to `<root_dir>/train`, eval summaries to
        `<root_dir>/eval`. The operative gin config is written to
        `<root_dir>/operative.gin`.
      env_name: Environment name passed to `utils.get_env`.
      relabel_type: One of None, "none" (treated as None), "future", "last",
        "soft" or "random". Forwarded to the relabelling replay buffer.
      constant_task: If truthy, train on a single task; relabelling must then
        be disabled.
      clip_critic: None, a float `c` (critic output squashed to (0, c)), or a
        `(min, max)` tuple (critic output squashed to (min, max)).
      use_tf_functions: If True, wrap `tf_agent.train` with
        `common.function`. Otherwise force eager execution of functions.
      eval_task_list: Extra tasks to evaluate on, in addition to the default.
      Other arguments are forwarded to the networks, agent, replay buffer,
      metrics and checkpointers under the same names.

    Returns:
      The LossInfo from the last `tf_agent.train` call, or None if no train
      step ran (e.g. `num_iterations` or `train_steps_per_iteration` is 0).

    Raises:
      ValueError: If `relabel_type` is unsupported, if `constant_task` is
        combined with relabelling, or if `clip_critic` is a tuple whose
        length is not 2.
    """
    np.random.seed(random_seed)
    if relabel_type == "none":
        relabel_type = None
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if relabel_type not in (None, "future", "last", "soft", "random"):
        raise ValueError("Unsupported relabel_type: %r" % (relabel_type,))
    if constant_task and relabel_type is not None:
        raise ValueError(
            "constant_task requires relabel_type=None, got %r" %
            (relabel_type,))
    if eval_task_list is None:
        eval_task_list = []
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        utils.AverageSuccessMetric(max_episode_steps=max_episode_steps,
                                   buffer_size=num_eval_episodes),
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env, task_distribution = utils.get_env(env_name,
                                                  constant_task=constant_task)
        eval_tf_env, _ = utils.get_env(env_name,
                                       max_episode_steps,
                                       constant_task=constant_task)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=utils.normal_projection_net,
        )
        if isinstance(clip_critic, float):
            output_activation_fn = lambda x: clip_critic * tf.sigmoid(x)
        elif isinstance(clip_critic, tuple):
            if len(clip_critic) != 2:
                raise ValueError(
                    "clip_critic tuple must be (min, max), got %r" %
                    (clip_critic,))
            min_val, max_val = clip_critic
            output_activation_fn = (
                lambda x:  # pylint: disable=g-long-lambda
                (max_val - min_val) * tf.sigmoid(x) + min_val)
        else:
            output_activation_fn = None
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            output_activation_fn=output_activation_fn,
        )

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = relabelling_replay_buffer.GoalRelabellingReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
            task_distribution=task_distribution,
            actor=actor_net,
            critic=critic_net,
            gamma=gamma,
            relabel_type=relabel_type,
            sample_batch_size=batch_size,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
            num_future_states=num_future_states,
        )

        env_steps = tf_metrics.EnvironmentSteps(prefix="Train")
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix="Train"),
            env_steps,
            utils.AverageSuccessMetric(
                prefix="Train",
                max_episode_steps=max_episode_steps,
                buffer_size=num_eval_episodes,
            ),
            tf_metrics.AverageReturnMetric(
                prefix="Train",
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix="Train",
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )

        train_checkpointer.initialize_or_restore()

        data_collector = utils.DataCollector(
            tf_env,
            tf_agent.collect_policy,
            replay_buffer,
            max_episode_steps=max_episode_steps,
            observers=train_metrics,
        )

        # Wrap exactly once; wrapping an already-wrapped function again would
        # nest one tf.function inside another for no benefit.
        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train)
        else:
            tf.config.experimental_run_functions_eagerly(True)

        # Save the config string as late as possible to catch
        # as many object instantiations as possible.
        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        def _run_eval(log_prefix):
            """Evaluates eval_policy on the default task and every eval task."""
            for task in [None] + eval_task_list:
                with utils.FixedTask(eval_tf_env, task):
                    prefix = "Metrics" if task is None else "Metrics-%s" % str(
                        task)
                    if log_prefix:
                        logging.info(prefix)
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix=prefix,
                    )
                    metric_utils.log_metrics(eval_metrics)

        # Collect initial replay data.
        logging.info(
            "Initializing replay buffer by collecting experience for %d steps with "
            "a random policy.",
            initial_collect_steps,
        )
        for _ in range(initial_collect_steps):
            data_collector.step(initial_collect_policy)
        data_collector.reset()
        logging.info("Replay buffer initial size: %d",
                     replay_buffer.num_frames())

        logging.info("Computing initial eval metrics")
        _run_eval(log_prefix=False)

        time_acc = 0
        env_time_acc = 0
        train_time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Both stay None if no train step runs, so the loss log and the return
        # value below do not hit an unbound local.
        train_loss = None
        total_loss = None

        logging.info("Starting training")
        for _ in range(num_iterations):
            start_time = time.time()
            data_collector.step()
            env_time_acc += time.time() - start_time
            train_time_start = time.time()
            for _ in range(train_steps_per_iteration):
                experience = replay_buffer.get_batch()
                train_loss = tf_agent.train(experience)
                total_loss = train_loss.loss
            train_time_acc += time.time() - train_time_start
            time_acc += time.time() - start_time

            # Logging, summaries and eval below do not advance the global
            # step, so read it once per iteration.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                if total_loss is not None:
                    logging.info("step = %d, loss = %f", global_step_val,
                                 total_loss)

                env_steps_done = env_steps.result().numpy() - env_steps_before
                combined_steps_per_sec = env_steps_done / time_acc
                train_steps_per_sec = env_steps_done / train_time_acc
                env_steps_per_sec = env_steps_done / env_time_acc
                logging.info(
                    "%.3f combined steps / sec: %.3f env steps/sec, %.3f train steps/sec",
                    combined_steps_per_sec,
                    env_steps_per_sec,
                    train_steps_per_sec,
                )
                tf.compat.v2.summary.scalar(
                    name="combined_steps_per_sec",
                    data=combined_steps_per_sec,
                    step=env_steps.result(),
                )
                tf.compat.v2.summary.scalar(
                    name="env_steps_per_sec",
                    data=env_steps_per_sec,
                    step=env_steps.result(),
                )
                tf.compat.v2.summary.scalar(
                    name="train_steps_per_sec",
                    data=train_steps_per_sec,
                    step=env_steps.result(),
                )
                time_acc = 0
                env_time_acc = 0
                train_time_acc = 0
                env_steps_before = env_steps.result().numpy()

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                _run_eval(log_prefix=True)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

        return train_loss
示例#19
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        # TODO(kbanoop): rename to policy_fc_layers.
        actor_fc_layers=(100, ),
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gradient_clipping=None,
        normalize_returns=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for Reinforce.

    Each iteration collects `collect_episodes_per_iteration` episodes with the
    collect policy. It then trains once on everything gathered from the replay
    buffer and clears the buffer, so each update only uses experience from the
    current policy.

    NOTE(review): this example targets an older TF-Agents/TF 1.x API
    (`tf.contrib.summary`, `collect_data_spec()` / `policy()` called as
    methods, `train(..., train_step_counter=...)`). Confirm against the
    installed versions before running.

    Args:
      root_dir: Base output directory (`~` is expanded). Train summaries go
        to `<root_dir>/train`, eval summaries to `<root_dir>/eval`.
      env_name: Gym environment id passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations.
      actor_fc_layers: Hidden layer sizes of the actor network.
      collect_episodes_per_iteration: Episodes collected per iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      learning_rate: Adam learning rate.
      gradient_clipping, normalize_returns, debug_summaries,
        summarize_grads_and_vars: Forwarded to `ReinforceAgent`.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluate when the global step is a multiple of this.
      log_interval: Log loss and steps/sec when the global step is a
        multiple of this.
      summary_interval: Record summaries every this many global steps.
      summaries_flush_secs: Flush period of the summary file writers.
      eval_metrics_callback: Optional callable invoked as
        `callback(metrics, global_step_value)` after each evaluation.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # TODO(kbanoop): Figure out if it is possible to avoid the with block.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):

        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        # TODO(kbanoop): Handle distributions without gin.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        eval_policy = tf_agent.policy()
        collect_policy = tf_agent.collect_policy()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def _evaluate(step_value):
            """Runs eval episodes, writes eval summaries, calls the callback."""
            metrics = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            # TODO(sfishman): Move this functionality into eager_compute_summaries
            if eval_metrics_callback is not None:
                eval_metrics_callback(metrics, step_value)

        # Compute evaluation metrics. Evaluation does not train, so the
        # global step read here is the same as after eager_compute.
        _evaluate(global_step.numpy())

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            experience = replay_buffer.gather_all()
            total_loss = tf_agent.train(experience,
                                        train_step_counter=global_step)
            # On-policy: discard the experience once it has been used.
            replay_buffer.clear()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()
            if global_step_val % log_interval == 0:
                tf.logging.info('step = %d, loss = %f', global_step_val,
                                total_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                # Lazy %-args: formatting is left to the logging framework.
                tf.logging.info('%.3f steps/sec', steps_per_sec)
                tf.contrib.summary.scalar(name='global_steps/sec',
                                          tensor=steps_per_sec)
                timed_at_step = global_step_val
                time_acc = 0

            if global_step_val % eval_interval == 0:
                _evaluate(global_step_val)
示例#20
0
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO.

    Builds a TF1 graph for a PPO-clip agent and runs it in a session.
    Each iteration collects episodes from `num_parallel_environments`
    parallel copies of `env_name` into a replay buffer. It then trains on
    everything gathered and clears the buffer. Evaluation, logging and
    checkpointing happen whenever the global step is a multiple of the
    matching interval. One last evaluation runs after training ends.

    Args:
      root_dir: Base directory. Checkpoints and train summaries go to
        `<root_dir>/train`, eval summaries to `<root_dir>/eval`.
        `~` is expanded.
      tf_master: Target passed to `tf.compat.v1.Session`.
      env_name: Environment name passed to `env_load_fn`.
      env_load_fn: Callable that builds one py environment from `env_name`.
      random_seed: If not None, passed to `tf.compat.v1.set_random_seed`.
      actor_fc_layers: Layer sizes for the actor network.
      value_fc_layers: Layer sizes for the value network.
      use_rnns: Use LSTM-based actor/value networks instead of MLPs.
      num_environment_steps: Training stops once the environment-steps
        metric reaches this value.
      collect_episodes_per_iteration: Passed to the episode driver as
        `num_episodes`.
      num_parallel_environments: Number of parallel train and eval envs.
      replay_buffer_capacity: Replay buffer `max_length` per environment.
      num_epochs: Passed to the agent as `num_epochs`.
      learning_rate: Adam learning rate.
      num_eval_episodes: Number of episodes per evaluation.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Save the train checkpoint when the global
        step is a multiple of this.
      policy_checkpoint_interval: Save the policy checkpoint when the
        global step is a multiple of this.
      log_interval: Log loss and throughput when the global step is a
        multiple of this.
      summary_interval: Record train summaries when the global step is a
        multiple of this.
      summaries_flush_secs: Flush period of both summary writers.
      debug_summaries: Passed through to the agent.
      summarize_grads_and_vars: Passed through to the agent.
      eval_metrics_callback: Optional callback passed to
        `metric_utils.compute_summaries`.

    Raises:
      AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Train summaries are only recorded on multiples of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                lstm_size=(40, ),
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        # Clear the buffer only after training has consumed it, and make
        # the returned train_op depend on the clear.
        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            # Pre-initialize so the final eval below has a step value even
            # when the loop body never runs (e.g. a restored checkpoint has
            # already reached num_environment_steps).
            global_step_val = sess.run(global_step)
            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss, _ = sess.run([train_op, summary_ops])
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
示例#21
0
def _gather_eval_episodes(eval_py_env, num_eval_episodes):
    """Return (episodes, metrics) for the episodes stored by `eval_py_env`.

    The result of `eval_py_env.get_stored_episodes()` is flattened one
    level (presumably one list per parallel eval environment; verify
    against the env implementation). It is then truncated to the first
    `num_eval_episodes` episodes and summarized with
    `episode_utils.get_metrics`.
    """
    stored = eval_py_env.get_stored_episodes()
    episodes = [episode for sublist in stored
                for episode in sublist][:num_eval_episodes]
    return episodes, episode_utils.get_metrics(episodes)


def train_eval(
        root_dir,
        gpu='1',
        env_load_fn=None,
        model_ids=None,
        eval_env_mode='headless',
        conv_layer_params=None,
        encoder_fc_layers=(256,),
        actor_fc_layers=(256, 256),
        value_fc_layers=(256, 256),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        eval_only=False,
        eval_deterministic=False,
        num_parallel_environments_eval=1,
        model_ids_eval=None,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        rb_checkpoint_interval=500,
        log_interval=10,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO.

    Builds a TF1 graph for a PPO agent with two preprocessing encoders:
    a conv/MLP encoder for the 'depth_seg' observation and an MLP for the
    'sensor' observation. It then runs a session loop that alternates
    episode collection and training. Stored eval episodes are summarized
    with `episode_utils.get_metrics`, both in `eval_only` mode and at each
    periodic evaluation.

    Args:
      root_dir: Base directory. Checkpoints and train summaries go to
        `<root_dir>/train`, eval summaries to `<root_dir>/eval`.
        `~` is expanded.
      gpu: Third positional argument passed to `env_load_fn`.
      env_load_fn: Callable `(model_id, mode, gpu)` returning a py env.
        Required.
      model_ids: Optional per-train-env model ids. If given, its length
        must equal `num_parallel_environments`.
      eval_env_mode: Mode passed to `env_load_fn` for eval envs. 'gui'
        requires `num_parallel_environments_eval == 1`.
      conv_layer_params: Conv params of the 'depth_seg' encoder.
      encoder_fc_layers: Layer sizes of both encoders.
      actor_fc_layers: Layer sizes of the actor network.
      value_fc_layers: Layer sizes of the value network.
      use_rnns: Use RNN actor/value networks instead of feed-forward ones.
      num_environment_steps: Training stops once the environment-steps
        metric reaches this value.
      collect_episodes_per_iteration: Multiplied by
        `num_parallel_environments` to give the driver's `num_episodes`.
      num_parallel_environments: Number of parallel training envs.
      replay_buffer_capacity: Replay buffer `max_length` per environment.
      num_epochs: Passed to the agent as `num_epochs`.
      learning_rate: Adam learning rate.
      num_eval_episodes: Number of episodes per evaluation.
      eval_interval: Evaluate when the global step is a multiple of this.
      eval_only: If True, restore checkpoints, evaluate once, save the
        eval episodes to `<root_dir>/eval/episodes_eval.pkl`, and return.
      eval_deterministic: Evaluate `tf_agent.policy` if True, otherwise
        `tf_agent.collect_policy`.
      num_parallel_environments_eval: Number of parallel eval envs.
      model_ids_eval: Optional per-eval-env model ids. If given, its length
        must equal `num_parallel_environments_eval`.
      train_checkpoint_interval: Train checkpoint period in global steps.
      policy_checkpoint_interval: Policy checkpoint period in global steps.
      rb_checkpoint_interval: Replay buffer checkpoint period in global
        steps.
      log_interval: Logging period in global steps.
      summary_interval: Train summary recording period in global steps.
      summaries_flush_secs: Flush period of both summary writers.
      debug_summaries: Passed through to the agent.
      summarize_grads_and_vars: Passed through to the agent.
      eval_metrics_callback: Optional callback passed to
        `metric_utils.compute_summaries`.

    Raises:
      AttributeError: If `root_dir` is None.
      ValueError: If `env_load_fn` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')
    if env_load_fn is None:
        # Would otherwise fail later, inside the environment subprocesses.
        raise ValueError('train_eval requires an env_load_fn.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()
    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Train summaries are only recorded on multiples of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if model_ids is None:
            model_ids = [None] * num_parallel_environments
        else:
            assert len(model_ids) == num_parallel_environments,\
                'model ids provided, but length not equal to num_parallel_environments'

        if model_ids_eval is None:
            model_ids_eval = [None] * num_parallel_environments_eval
        else:
            assert len(model_ids_eval) == num_parallel_environments_eval,\
                'model ids eval provided, but length not equal to num_parallel_environments_eval'

        # The default argument binds each model_id at definition time,
        # avoiding the late-binding-closure pitfall.
        tf_py_env = [lambda model_id=model_id: env_load_fn(model_id, 'headless', gpu)
                     for model_id in model_ids]
        tf_env = tf_py_environment.TFPyEnvironment(parallel_py_environment.ParallelPyEnvironment(tf_py_env))

        if eval_env_mode == 'gui':
            assert num_parallel_environments_eval == 1, 'only one GUI env is allowed'
        eval_py_env = [lambda model_id=model_id: env_load_fn(model_id, eval_env_mode, gpu)
                       for model_id in model_ids_eval]
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(eval_py_env)

        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = tf_env.observation_spec()
        action_spec = tf_env.action_spec()
        print('observation_spec', observation_spec)
        print('action_spec', action_spec)

        glorot_uniform_initializer = tf.compat.v1.keras.initializers.glorot_uniform()
        preprocessing_layers = {
            'depth_seg': tf.keras.Sequential(mlp_layers(
                conv_layer_params=conv_layer_params,
                fc_layer_params=encoder_fc_layers,
                kernel_initializer=glorot_uniform_initializer,
            )),
            'sensor': tf.keras.Sequential(mlp_layers(
                conv_layer_params=None,
                fc_layer_params=encoder_fc_layers,
                kernel_initializer=glorot_uniform_initializer,
            )),
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                observation_spec,
                action_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                observation_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                observation_spec,
                action_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                fc_layer_params=actor_fc_layers,
                kernel_initializer=glorot_uniform_initializer
            )
            value_net = value_network.ValueNetwork(
                observation_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                fc_layer_params=value_fc_layers,
                kernel_initializer=glorot_uniform_initializer
            )

        tf_agent = ppo_agent.PPOAgent(
            time_step_spec,
            action_spec,
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        if eval_deterministic:
            eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)
        else:
            eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.collect_policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                buffer_size=100,
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=100,
                batch_size=num_parallel_environments),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration * num_parallel_environments).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        # Clear the buffer only after training has consumed it, and make
        # the returned train_op depend on the clear.
        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(train_metric.tf_summaries(
                train_step=global_step, step_metrics=step_metrics))

        with eval_summary_writer.as_default(), tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(
                    train_step=global_step, step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=config)
        # try/finally so the session is closed on every exit path,
        # including the eval_only early return and exceptions during
        # training (previously both skipped sess.close()).
        try:
            with sess.as_default():
                # Initialize graph.
                train_checkpointer.initialize_or_restore(sess)
                rb_checkpointer.initialize_or_restore(sess)

                if eval_only:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=0,
                        callback=eval_metrics_callback,
                        tf_summaries=False,
                        log=True,
                    )
                    episodes, metrics = _gather_eval_episodes(
                        eval_py_env, num_eval_episodes)
                    for key in sorted(metrics):
                        print(key, ':', metrics[key])

                    save_path = os.path.join(eval_dir, 'episodes_eval.pkl')
                    episode_utils.save(episodes, save_path)
                    print('EVAL DONE')
                    return

                common.initialize_uninitialized_variables(sess)
                sess.run(init_agent_op)
                sess.run(train_summary_writer.init())
                sess.run(eval_summary_writer.init())

                collect_time = 0
                train_time = 0
                timed_at_step = sess.run(global_step)
                steps_per_second_ph = tf.compat.v1.placeholder(
                    tf.float32, shape=(), name='steps_per_sec_ph')
                steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec', data=steps_per_second_ph,
                    step=global_step)
                # Eval metric name -> (placeholder, summary op). Each op is
                # built once and then fed, rather than adding new ops to
                # the graph at every evaluation.
                eval_metric_summaries = {}

                global_step_val = sess.run(global_step)
                while sess.run(environment_steps_count) < num_environment_steps:
                    global_step_val = sess.run(global_step)
                    if global_step_val % eval_interval == 0:
                        metric_utils.compute_summaries(
                            eval_metrics,
                            eval_py_env,
                            eval_py_policy,
                            num_episodes=num_eval_episodes,
                            global_step=global_step_val,
                            callback=eval_metrics_callback,
                            log=True,
                        )
                        _, metrics = _gather_eval_episodes(
                            eval_py_env, num_eval_episodes)
                        for key in sorted(metrics):
                            print(key, ':', metrics[key])
                            if key not in eval_metric_summaries:
                                with eval_summary_writer.as_default(), \
                                        tf.compat.v2.summary.record_if(True), \
                                        tf.name_scope('Metrics/'):
                                    value_ph = tf.compat.v1.placeholder(
                                        tf.float32, shape=())
                                    # step=global_step reads the same value
                                    # as global_step_val: nothing trains
                                    # between that read and this run.
                                    eval_metric_summaries[key] = (
                                        value_ph,
                                        tf.compat.v2.summary.scalar(
                                            name=key, data=value_ph,
                                            step=global_step))
                            value_ph, metric_op = eval_metric_summaries[key]
                            sess.run(metric_op,
                                     feed_dict={value_ph: metrics[key]})
                        sess.run(eval_summary_writer_flush_op)

                    start_time = time.time()
                    sess.run(collect_op)
                    collect_time += time.time() - start_time
                    start_time = time.time()
                    total_loss, _ = sess.run([train_op, summary_ops])
                    train_time += time.time() - start_time

                    global_step_val = sess.run(global_step)
                    if global_step_val % log_interval == 0:
                        logging.info('step = %d, loss = %f', global_step_val, total_loss)
                        steps_per_sec = (
                                (global_step_val - timed_at_step) / (collect_time + train_time))
                        logging.info('%.3f steps/sec', steps_per_sec)
                        sess.run(
                            steps_per_second_summary,
                            feed_dict={steps_per_second_ph: steps_per_sec})
                        logging.info('%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                        timed_at_step = global_step_val
                        collect_time = 0
                        train_time = 0

                    if global_step_val % train_checkpoint_interval == 0:
                        train_checkpointer.save(global_step=global_step_val)

                    if global_step_val % policy_checkpoint_interval == 0:
                        policy_checkpointer.save(global_step=global_step_val)

                    if global_step_val % rb_checkpoint_interval == 0:
                        rb_checkpointer.save(global_step=global_step_val)

                # One final eval before exiting.
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_writer_flush_op)
        finally:
            sess.close()
示例#22
0
def train_eval(
        root_dir,
        env_name=ENV_NAME,
        num_iterations=ITERATIONS,
        fc_layer_params=LAYER_PARAMETERS,
        # Parameters for collect
        initial_collect_steps=COLLECT_STEPS,
        collect_steps_per_iteration=COLLECT_STEPS,
        epsilon_greedy=GREEDY,
        replay_buffer_capacity=BUFFER,
        # Parameters for target update
        target_update_tau=TAU,
        target_update_period=UPDATE_PERIOD,
        # Parameters for train
        train_steps_per_iteration=TRAIN_ITERATIONS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARN_RATE,
        gamma=GAMMA,
        reward_scale_factor=SCALE,
        gradient_clipping=GRADIENT,
        # Parameters for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Parameters for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        agent_class=dqn_agent.DqnAgent,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Builds a TF1 graph for an agent created by `agent_class` over a
    `suite_gym` environment. The replay buffer is seeded with
    `initial_collect_steps` steps from a `RandomTFPolicy`. Each of the
    `num_iterations` iterations then collects `collect_steps_per_iteration`
    steps with the agent's collect policy and runs
    `train_steps_per_iteration` training steps on 2-step replay samples.
    Logging, checkpointing and evaluation happen when the global step is a
    multiple of the matching interval.

    Args:
      root_dir: Base directory. Checkpoints and train summaries go to
        `<root_dir>/train`, eval summaries to `<root_dir>/eval`.
        `~` is expanded.
      env_name: Gym environment name, loaded with `suite_gym.load`.
      num_iterations: Number of collect/train iterations.
      fc_layer_params: Layer sizes of the Q network.
      initial_collect_steps: Random-policy steps collected before training.
      collect_steps_per_iteration: Collect-policy steps per iteration.
      epsilon_greedy: Passed to the agent.
      replay_buffer_capacity: Replay buffer `max_length`.
      target_update_tau: Passed to the agent.
      target_update_period: Passed to the agent.
      train_steps_per_iteration: Training steps per iteration; must be >= 1.
      batch_size: Replay sample batch size.
      learning_rate: Adam learning rate.
      gamma: Passed to the agent.
      reward_scale_factor: Passed to the agent.
      gradient_clipping: Passed to the agent.
      num_eval_episodes: Number of episodes per evaluation.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Train checkpoint period in global steps.
      policy_checkpoint_interval: Policy checkpoint period in global steps.
      rb_checkpoint_interval: Replay buffer checkpoint period in global
        steps.
      log_interval: Logging period in global steps.
      summary_interval: Train summary recording period in global steps.
      summaries_flush_secs: Flush period of both summary writers.
      agent_class: Agent constructor (defaults to `dqn_agent.DqnAgent`).
      debug_summaries: Passed through to the agent.
      summarize_grads_and_vars: Passed through to the agent.
      eval_metrics_callback: Optional callback passed to
        `metric_utils.compute_summaries`.

    Raises:
      AttributeError: If `root_dir` is None.
      ValueError: If `train_steps_per_iteration` is less than 1.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')
    if train_steps_per_iteration < 1:
        # With zero train steps the loss is never computed, so the first
        # log line would fail on an unbound variable.
        raise ValueError(
            'train_steps_per_iteration must be >= 1, got {}.'.format(
                train_steps_per_iteration))

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Train summaries are only recorded on multiples of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        replay_observer = [replay_buffer.add_batch]
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        experience, _ = iterator.get_next()
        train_op = common.function(tf_agent.train)(experience=experience)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            # The first two metrics (episodes, env steps) also serve as
            # step metrics for the summaries.
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)

            collect_call = sess.make_callable(collect_op)
            global_step_call = sess.make_callable(global_step)
            train_step_call = sess.make_callable([train_op, summary_ops])

            timed_at_step = global_step_call()
            collect_time = 0
            train_time = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                collect_time += time.time() - start_time
                start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                train_time += time.time() - start_time

                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info('%.3f steps/sec', steps_per_sec)
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
                    sess.run(eval_summary_writer_flush_op)
示例#23
0
# Step counter handed to the agent below as its `train_step_counter`.
train_step_counter = tf.compat.v2.Variable(0)
# DQN agent over the training environment's specs. NOTE(review):
# `train_env`, `q_net` and `optimizer` are not defined in this chunk;
# presumably they are created earlier in the script.
agent = dqn_agent.DqnAgent(train_env.time_step_spec(),
                           train_env.action_spec(),
                           q_network=q_net,
                           optimizer=optimizer,
                           td_errors_loss_fn=common.element_wise_squared_loss,
                           train_step_counter=train_step_counter)
agent.initialize()

# Policy, written by Josh Gendein, from the tutorial in https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
# `eval_policy` is the agent's `policy`; `collect_policy` is the agent's
# separate `collect_policy` (the one used when gathering experience).
eval_policy = agent.policy
collect_policy = agent.collect_policy

# Training metrics, written by Josh Gendein, borrowed from the tutorial in https://towardsdatascience.com/tf-agents-tutorial-a63399218309
# Episode count, environment step count, average return and average
# episode length.
train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric(),
]


#Average return method, written by Josh Gendein, from the tutorial in https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
def compute_avg_return(environment, policy, num_episodes=10):
    """Run `policy` for `num_episodes` episodes and return the mean return.

    Each episode starts from `environment.reset()` and repeatedly steps the
    environment with `policy.action(time_step).action` until
    `time_step.is_last()`. The rewards of all steps taken are summed into the
    episode return, and the episode returns are averaged.

    Args:
      environment: Object exposing `reset()` and `step(action)`, both returning
        time steps that have a `reward` attribute and an `is_last()` method.
      policy: Object exposing `action(time_step)`, whose result has an `action`
        attribute.
      num_episodes: Number of episodes to average over; must be positive.

    Returns:
      The total return divided by `num_episodes`. Its type follows the reward
      type: a float for Python float rewards. For TF environments it is
      presumably a tensor; convert it with `.numpy()` if needed.

    Raises:
      ValueError: If `num_episodes` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError(
            'num_episodes must be positive, got {}'.format(num_episodes))

    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            # Reward received for the step just taken.
            episode_return += time_step.reward
        total_return += episode_return
    return total_return / num_episodes
示例#24
0
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v1',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(kbanoop): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=50,
        rb_checkpoint_interval=200,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO.

    Builds a TF1 graph around a `ppo_agent.PPOAgent`. Each iteration collects
    `collect_episodes_per_iteration` episodes from `num_parallel_environments`
    parallel copies of `env_name` and trains on everything in the replay
    buffer. The buffer is then cleared. Iterations continue until the
    `EnvironmentSteps` metric reaches `num_environment_steps`. Evaluation runs
    whenever the global step is a multiple of `eval_interval` at the start of
    an iteration, and once more before returning.

    Args:
      root_dir: Output directory; `~` is expanded. Training checkpoints and
        summaries go to `<root_dir>/train`, eval summaries to `<root_dir>/eval`.
      tf_master: Target passed to `tf.Session`.
      env_name: Environment name passed to `env_load_fn`.
      env_load_fn: Callable `env_name -> py environment`.
      random_seed: Graph-level seed passed to `tf.set_random_seed`.
      use_rnns: If True, use RNN actor/value networks; otherwise feed-forward.
      num_environment_steps: Stop once the environment-steps metric reaches
        this value.
      replay_buffer_capacity: Per-environment replay buffer length.
      eval_metrics_callback: Forwarded to `metric_utils.compute_summaries`.

    Raises:
      AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    # Eval runs on `num_parallel_environments` envs at once, so the metrics are
    # wrapped in BatchedPyMetric with that batch size.
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    # TODO(kbanoop): Figure out if it is possible to avoid the with block.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):
        tf.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(
            tf_agent.policy(), batch_size=num_parallel_environments)

        # TODO(sguada): Reenable metrics when ready for batch data.
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_metric.build()
        # Read by the outer training loop as its termination condition.
        environment_steps_count = environment_steps_metric.value()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        global_step = tf.train.get_or_create_global_step()
        collect_policy = tf_agent.collect_policy()

        # NOTE(review): the initial state comes from `policy()` but is used by
        # `collect_policy`; presumably both share the same state spec.
        policy_state = tf_agent.policy().get_initial_state(
            batch_size=tf_env.batch_size)

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run(
                policy_state=policy_state)

        # Train on everything currently in the buffer; it is cleared after
        # each train step (see the control dependencies below).
        trajectories = replay_buffer.gather_all()

        train_op = tf_agent.train(experience=trajectories,
                                  train_step_counter=global_step)

        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        # Re-expose train_op so that running it also runs the buffer clear.
        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries()
        summary_op = tf.contrib.summary.all_summary_ops()

        with eval_summary_writer.as_default(), \
             tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            # Bound before the loop so the final eval below still has a step
            # value when the loop body never runs. This happens, for example,
            # when restored metrics already reach `num_environment_steps`.
            global_step_val = timed_at_step
            steps_per_second_ph = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss, _ = sess.run([train_op, summary_op])
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    tf.logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    tf.logging.info('collect_time = %s, train_time = %s',
                                    collect_time, train_time)
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )
            sess.run(eval_summary_writer_flush_op)
示例#25
0
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        # NOTE: `initial_collect_steps` is passed to an episode driver as
        # `num_episodes`, so it counts episodes, not steps.
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        eval_metrics_callback=None):
    """A simple train and eval for TD3 with recurrent actor and critic networks.

    Builds a TF1 graph for a `td3_agent.Td3Agent`. Its actor is an
    `ActorRnnNetwork` and its critic a `CriticRnnNetwork`. The agent is trained
    on the dm_control task `suite_dm_control.load(env_name, task_name)`.

    Training proceeds as follows:
      1. An initial collection of `initial_collect_steps` episodes runs.
      2. One evaluation runs.
      3. Each of `num_iterations` iterations collects
         `collect_episodes_per_iteration` episodes with the agent's collect
         policy.
      4. It then runs `train_steps_per_iteration` train steps on batches of
         `batch_size` sequences of `train_sequence_length + 1` steps sampled
         from the replay buffer.

    Logging, checkpointing and evaluation happen when the global step read
    after an iteration is an exact multiple of the corresponding interval.

    Args (selected):
      root_dir: Output directory; `~` is expanded. Training checkpoints and
        summaries go to `<root_dir>/train`, eval summaries to `<root_dir>/eval`.
      observations_whitelist: If not None, passed as a one-element whitelist to
        `wrappers.FlattenObservationsWrapper` for both train and eval envs.
      initial_collect_steps: Number of *episodes* (not steps) collected before
        training starts.
      eval_metrics_callback: Forwarded to `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Train summaries are recorded only when the global step is a multiple of
    # `summary_interval`.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []
        environment = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        # The critic consumes (observation, action) pairs.
        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            debug_summaries=debug_summaries,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # TODO(oars): Refactor drivers to better handle policy states. Remove the
        # policy reset and passing down an empty policy state to the driver.
        collect_policy = tf_agent.collect_policy
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run(policy_state=policy_state)

        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run(
                policy_state=policy_state)

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        # The first two train metrics (episodes, environment steps) serve as
        # step axes for the metric summaries.
        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])

        # Eval summaries are always recorded, regardless of summary_interval.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            # Baseline evaluation before any training iteration.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(train_op)
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                # NOTE(review): `loss_info_value` is only bound inside this
                # loop; with train_steps_per_iteration=0 the logging below
                # would raise NameError.
                for _ in range(train_steps_per_iteration):
                    loss_info_value = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
示例#26
0
def train(
    root_dir,
    agent,
    environment,
    training_loops,
    steps_per_loop=1,
    additional_metrics=(),
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10,
    policy_checkpoint_interval=10,
    log_interval=10,
    summary_interval=10):
  """A training driver.

  Builds a TF1 graph and runs `training_loops` iterations inside a
  `tf.compat.v1.Session`. Each iteration does three things:
    1. Collects `steps_per_loop * environment.batch_size` driver steps with
       `agent.policy` into a replay buffer.
    2. Trains `agent` once on the buffer contents, read back in a single
       deterministic pass.
    3. Clears the buffer.

  Args:
    root_dir: Output directory; `~` is expanded. Checkpoints and summaries go
      to `<root_dir>/train`.
    agent: Agent to train; its `policy` is also used for collection.
    environment: Batched TF environment to collect from.
    training_loops: Number of collect/train/clear iterations to run.
    steps_per_loop: Multiplied by `environment.batch_size` to give the
      driver's `num_steps`. Also passed to `build_replay_buffer` and used as
      the dataset's `num_steps`.
    additional_metrics: Extra metrics appended to the default train metrics;
      they are also driver observers and are checkpointed.
    train_checkpoint_interval: Save the train checkpoint when the global step
      is a multiple of this value.
    policy_checkpoint_interval: Save the policy checkpoint when the global
      step is a multiple of this value.
    log_interval: Log loss and steps/sec when the global step is a multiple
      of this value.
    summary_interval: Summaries are recorded only when the global step is a
      multiple of this value.

  Raises:
    RuntimeError: If resource variables are not enabled.
  """

  if not common.resource_variables_enabled():
    raise RuntimeError(common.MISSING_RESOURCE_VARIABLES_ERROR)

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(train_dir)
  train_summary_writer.set_as_default()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(batch_size=environment.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=environment.batch_size),
    ] + list(additional_metrics)

    # Add to replay buffer and other agent specific observers.
    replay_buffer = build_replay_buffer(agent, environment.batch_size,
                                        steps_per_loop)
    agent_observers = [replay_buffer.add_batch] + train_metrics

    # Presumably `num_steps` counts steps across all batch entries, i.e. about
    # `steps_per_loop` steps per environment -- confirm against
    # DynamicStepDriver.
    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=agent.policy,
        num_steps=steps_per_loop * environment.batch_size,
        observers=agent_observers)

    collect_op, _ = driver.run()
    batch_size = driver.env.batch_size
    # Read the buffer back in one deterministic pass. The dataset is expected
    # to yield exactly one element (`get_single_element`) of `batch_size` x
    # `steps_per_loop` trajectories.
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=steps_per_loop,
        single_deterministic_pass=True)
    trajectories, unused_info = tf.data.experimental.get_single_element(dataset)
    train_op = agent.train(experience=trajectories)
    clear_replay_op = replay_buffer.clear()

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        max_to_keep=1,
        agent=agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        max_to_keep=None,
        policy=agent.policy,
        global_step=global_step)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    init_agent_op = agent.initialize()

    config_saver = utils.GinConfigSaverHook(train_dir, summarize_config=True)
    config_saver.begin()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      config_saver.after_create_session(sess)

      global_step_call = sess.make_callable(global_step)
      global_step_val = global_step_call()

      sess.run(train_summary_writer.init())
      # NOTE(review): one collection runs here, before the loop, which collects
      # again on its first iteration. How much of this data survives depends
      # on the capacity chosen by `build_replay_buffer` (defined elsewhere).
      sess.run(collect_op)

      if global_step_val == 0:
        # Save an initial checkpoint so the evaluator runs for global_step=0.
        policy_checkpointer.save(global_step=global_step_val)
        # Agent initialization only runs on a fresh start, not after restoring
        # a checkpoint with a non-zero global step.
        sess.run(init_agent_op)

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      clear_replay_call = sess.make_callable(clear_replay_op)

      timed_at_step = global_step_val
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(training_loops):
        # Collect and train.
        start_time = time.time()
        collect_call()
        total_loss, _ = train_step_call()
        # Empty the buffer so the next iteration trains only on fresh data.
        clear_replay_call()
        global_step_val = global_step_call()

        time_acc += time.time() - start_time

        # The train op's result exposes the scalar loss as `.loss`.
        total_loss = total_loss.loss

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
示例#27
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        agent_class=dqn_agent.DqnAgent,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Builds a TF1 graph for `agent_class` (by default `dqn_agent.DqnAgent`). The
    agent uses a `q_network.QNetwork` on the gym environment `env_name`.

    Training proceeds as follows:
      1. The replay buffer is seeded with `initial_collect_steps` steps from a
         `RandomTFPolicy`. These steps are not counted in the train metrics.
      2. One evaluation runs.
      3. Each of `num_iterations` iterations collects
         `collect_steps_per_iteration` steps with the agent's collect policy.
      4. It then runs `train_steps_per_iteration` train steps, each on a batch
         of `batch_size` two-step trajectories sampled from the buffer.

    Logging, checkpointing and evaluation happen when the global step returned
    by the last train step is an exact multiple of the corresponding interval.

    Args (selected):
      root_dir: Output directory; `~` is expanded. Training checkpoints and
        summaries go to `<root_dir>/train`, eval summaries to `<root_dir>/eval`.
      agent_class: Agent constructor; called with the DQN-style keyword
        arguments used below.
      eval_metrics_callback: Forwarded to `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):

        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
            # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        # The random warm-up collection only feeds the replay buffer, so the
        # train metrics reflect the agent's own collect policy.
        replay_observer = [replay_buffer.add_batch]
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy()
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)

        iterator = dataset.make_initializable_iterator()
        trajectories, _ = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories,
                                  train_step_counter=global_step)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        summary_op = tf.contrib.summary.all_summary_ops()

        # Eval summaries are always recorded, regardless of summary_interval.
        with eval_summary_writer.as_default(), \
             tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            # Baseline evaluation before any training iteration.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])

            timed_at_step = sess.run(global_step)
            collect_time = 0
            train_time = 0
            steps_per_second_ph = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                collect_time += time.time() - start_time
                start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                train_time += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    # Lazy %-args: the logger formats only if the record is
                    # emitted.
                    tf.logging.info('%.3f steps/sec', steps_per_sec)
                    tf.logging.info('collect_time = %s, train_time = %s',
                                    collect_time, train_time)
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
示例#28
0
def train(root_dir,
          agent,
          environment,
          training_loops,
          steps_per_loop,
          additional_metrics=(),
          training_data_spec_transformation_fn=None):
  """Perform `training_loops` iterations of training.

  Checkpoint results.

  Each iteration does the following:
  - It runs the loop returned by
    `get_training_loop_fn(driver, replay_buffer, agent, steps_per_loop)`.
    The driver is a `DynamicStepDriver` over the agent's collect policy with
    `num_steps=steps_per_loop * environment.batch_size`.
  - It logs every metric and writes its summaries to `root_dir`.
  - It saves a checkpoint through the manager returned by
    `restore_and_get_checkpoint_manager`.
  - It exports `agent.policy` with `PolicySaver` to
    `root_dir/policy_<step>`, where `<step>` is the current `EnvironmentSteps`
    count.

  Args:
    root_dir: path to the directory where checkpoints and metrics will be
      written.
    agent: an instance of `TFAgent`.
    environment: an instance of `TFEnvironment`.
    training_loops: an integer indicating how many training loops should be run.
    steps_per_loop: an integer indicating how many driver steps should be
      executed and presented to the trainer during each training loop.
    additional_metrics: Tuple of metric objects to log. They are logged in
      addition to the default metrics `NumberOfEpisodes`,
      `AverageEpisodeLengthMetric` and `AverageReturnMetric`. When the
      environment's reward spec is a dict, `AverageReturnMultiMetric`
      replaces `AverageReturnMetric`.
    training_data_spec_transformation_fn: Optional function that transforms the
      data items before they get to the replay buffer. It is applied to the
      policy's trajectory spec to build the replay buffer's data spec. It is
      also applied to every collected batch before that batch is added to
      the buffer.
  """

  # TODO(b/127641485): create evaluation loop with configurable metrics.
  if training_data_spec_transformation_fn is None:
    data_spec = agent.policy.trajectory_spec
  else:
    data_spec = training_data_spec_transformation_fn(
        agent.policy.trajectory_spec)
  replay_buffer = get_replay_buffer(data_spec, environment.batch_size,
                                    steps_per_loop)

  # `step_metric` records the number of individual rounds of bandit interaction;
  # that is, (number of trajectories) * batch_size.
  step_metric = tf_metrics.EnvironmentSteps()
  metrics = [
      tf_metrics.NumberOfEpisodes(),
      tf_metrics.AverageEpisodeLengthMetric(batch_size=environment.batch_size)
  ] + list(additional_metrics)

  if isinstance(environment.reward_spec(), dict):
    # Dict-valued rewards are tracked per key by the multi-return metric.
    metrics += [tf_metrics.AverageReturnMultiMetric(
        reward_spec=environment.reward_spec(),
        batch_size=environment.batch_size)]
  else:
    metrics += [
        tf_metrics.AverageReturnMetric(batch_size=environment.batch_size)]

  if training_data_spec_transformation_fn is not None:
    # Transform collected data so it matches the replay buffer's data spec.
    add_batch_fn = lambda data: replay_buffer.add_batch(  # pylint: disable=g-long-lambda
        training_data_spec_transformation_fn(data))
  else:
    add_batch_fn = replay_buffer.add_batch

  observers = [add_batch_fn, step_metric] + metrics

  driver = dynamic_step_driver.DynamicStepDriver(
      env=environment,
      policy=agent.collect_policy,
      num_steps=steps_per_loop * environment.batch_size,
      observers=observers)

  training_loop = get_training_loop_fn(
      driver, replay_buffer, agent, steps_per_loop)
  checkpoint_manager = restore_and_get_checkpoint_manager(
      root_dir, agent, metrics, step_metric)
  saver = policy_saver.PolicySaver(agent.policy)

  summary_writer = tf.summary.create_file_writer(root_dir)
  summary_writer.set_as_default()
  for _ in range(training_loops):
    training_loop()
    metric_utils.log_metrics(metrics)
    for metric in metrics:
      metric.tf_summaries(train_step=step_metric.result())
    checkpoint_manager.save()
    saver.save(os.path.join(root_dir, 'policy_%d' % step_metric.result()))
示例#29
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC.

    The function first builds its components:
    - A training environment and a separate evaluation environment, each from
      `suite_mujoco.load(env_name)`.
    - An actor network and a critic network, and from them a `SacAgent`.
    - A uniform replay buffer.

    It restores the agent, train metrics and replay buffer from checkpoints
    under `root_dir/train` when they exist. If the replay buffer is still
    empty, it fills it with `initial_collect_steps` steps of a random policy.

    Each iteration then does the following:
    - It collects `collect_steps_per_iteration` steps with the agent's
      collect policy.
    - It runs `train_steps_per_iteration` training steps on batches of
      2-step trajectories sampled from the buffer.
    - It logs, evaluates and checkpoints on the configured step intervals.

    Args:
      root_dir: Base directory. Train summaries and checkpoints go to
        `root_dir/train`, eval summaries to `root_dir/eval`.
      env_name: Name passed to `suite_mujoco.load`.
      num_iterations: Number of collect/train iterations to run.
      actor_fc_layers: Fully connected layer sizes of the actor network.
      critic_obs_fc_layers: Critic layer sizes for the observation branch.
      critic_action_fc_layers: Critic layer sizes for the action branch.
      critic_joint_fc_layers: Critic layer sizes after the two branches are
        joined.
      initial_collect_steps: Random-policy steps used to seed an empty
        replay buffer.
      collect_steps_per_iteration: Environment steps collected per iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Passed to `SacAgent`.
      target_update_period: Passed to `SacAgent`.
      train_steps_per_iteration: Training steps per iteration.
      batch_size: Number of trajectories sampled per training step.
      actor_learning_rate: Adam learning rate for the actor.
      critic_learning_rate: Adam learning rate for the critic.
      alpha_learning_rate: Adam learning rate for alpha.
      td_errors_loss_fn: Passed to `SacAgent`.
      gamma: Passed to `SacAgent`.
      reward_scale_factor: Passed to `SacAgent`.
      gradient_clipping: Passed to `SacAgent`.
      use_tf_functions: If True, wrap the drivers, `tf_agent.train` and the
        train step with `common.function`.
      num_eval_episodes: Episodes per evaluation. Also used as the eval
        metrics' buffer size.
      eval_interval: Evaluate whenever the global step is a multiple of this.
      train_checkpoint_interval: Step interval for agent and metric
        checkpoints.
      policy_checkpoint_interval: Step interval for eval-policy checkpoints.
      rb_checkpoint_interval: Step interval for replay-buffer checkpoints.
      log_interval: Step interval for loss and steps/sec logging.
      summary_interval: Summaries are recorded only on steps where
        `global_step % summary_interval == 0`.
      summaries_flush_secs: Flush period of both summary writers.
      debug_summaries: Passed to `SacAgent`.
      summarize_grads_and_vars: Passed to `SacAgent`.
      eval_metrics_callback: Optional callable, invoked as
        `callback(results, global_step_value)` after every evaluation.

    Returns:
      The value returned by the last training step, or `None` if no
      training step ran (e.g. `num_iterations == 0`).
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Summaries inside this context are only recorded every `summary_interval`
    # global steps.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_mujoco.load(env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer. Its batch dimension is taken from the
        # environment rather than hard-coded, so it always matches the
        # batches the collect drivers add.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        if replay_buffer.num_frames() == 0:
            # Only seed an empty buffer. A buffer restored by `rb_checkpointer`
            # already holds experience; re-running the random policy on every
            # restart would dilute it with extra random transitions.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps with '
                'a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        def _evaluate():
            """Runs the greedy eval policy, reports to the callback, logs."""
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)

        _evaluate()

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Keeps `return train_loss` well-defined when no training step runs.
        train_loss = None
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # Nothing below trains the agent, so the step is read once.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                _evaluate()

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
示例#30
0
def train_eval(
    root_dir,
    env_name='sawyer_reach',
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    gamma=0.99,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=200000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    random_seed=0,
    max_future_steps=50,
    actor_std=None,
    log_subset=None,
):
    """A simple train and eval for SAC.

    This function trains a `c_learning_agent.CLearningAgent`. Its critic is
    a `c_learning_utils.ClassifierCriticNetwork` trained with `bce_loss`.
    The environments come from `c_learning_envs.load(env_name)`.

    Training runs until the global step reaches `num_iterations`, so resuming
    from a checkpoint continues where it left off. The training dataset is
    built as follows:
    - Windows of `max_future_steps` steps are sampled from the replay buffer.
    - Windows whose first step is a boundary are dropped.
    - The remaining windows are re-batched to `batch_size`.
    - The batches are passed through `c_learning_utils.goal_fn`.

    Args:
      root_dir: Base directory. Train summaries and checkpoints go to
        `root_dir/train`, eval summaries to `root_dir/eval`, and the operative
        gin config to `root_dir/operative.gin`.
      env_name: Name passed to `c_learning_envs.load`.
      num_iterations: Global step at which training stops.
      actor_fc_layers: Fully connected layer sizes of the actor network.
      critic_obs_fc_layers: Critic layer sizes for the observation branch.
      critic_action_fc_layers: Critic layer sizes for the action branch.
      critic_joint_fc_layers: Critic layer sizes after the two branches are
        joined.
      initial_collect_steps: Random-policy steps used to seed an empty
        replay buffer.
      collect_steps_per_iteration: Environment steps collected per iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      target_update_tau: Passed to the agent.
      target_update_period: Passed to the agent.
      train_steps_per_iteration: Training steps per iteration.
      batch_size: Training batch size. Also passed to `goal_fn`.
      actor_learning_rate: Adam learning rate for the actor.
      critic_learning_rate: Adam learning rate for the critic.
      gamma: Passed to the agent and to `goal_fn`.
      gradient_clipping: Passed to the agent.
      use_tf_functions: If True, wrap the drivers, `tf_agent.train` and the
        train step with `common.function`.
      num_eval_episodes: Episodes per evaluation. Also the metrics' buffer
        size.
      eval_interval: Evaluate whenever the global step is a multiple of this.
      train_checkpoint_interval: Step interval for training checkpoints.
      log_interval: Step interval for loss and steps/sec logging.
      summary_interval: Summaries are recorded only on steps where
        `global_step % summary_interval == 0`.
      summaries_flush_secs: Flush period of both summary writers.
      debug_summaries: Passed to the agent.
      summarize_grads_and_vars: Passed to the agent.
      random_seed: Seed for both NumPy and TensorFlow.
      max_future_steps: Number of steps in each sampled replay window.
      actor_std: If set, the actor's projection network uses this constant
        standard deviation instead of a learned one.
      log_subset: Optional `(start_index, end_index)`. When given, extra
        "Subset*" distance metrics are added to both the train and the eval
        metrics.

    Returns:
      The value returned by the last training step, or `None` if no
      training step ran (e.g. when resuming at or beyond `num_iterations`).
    """
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Summaries inside this context are only recorded every `summary_interval`
    # global steps.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env, eval_tf_env, obs_dim = c_learning_envs.load(env_name)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        if actor_std is None:
            proj_net = tanh_normal_projection_network.TanhNormalProjectionNetwork
        else:
            # Replace the learned std with the constant `actor_std`.
            proj_net = functools.partial(
                tanh_normal_projection_network.TanhNormalProjectionNetwork,
                std_transform=lambda t: actor_std * tf.ones_like(t))

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=proj_net)
        critic_net = c_learning_utils.ClassifierCriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = c_learning_agent.CLearningAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=bce_loss,
            gamma=gamma,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
            c_learning_utils.FinalDistance(buffer_size=num_eval_episodes,
                                           obs_dim=obs_dim),
            c_learning_utils.MinimumDistance(buffer_size=num_eval_episodes,
                                             obs_dim=obs_dim),
            c_learning_utils.DeltaDistance(buffer_size=num_eval_episodes,
                                           obs_dim=obs_dim),
        ]
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
            c_learning_utils.InitialDistance(buffer_size=num_eval_episodes,
                                             batch_size=tf_env.batch_size,
                                             obs_dim=obs_dim),
            c_learning_utils.FinalDistance(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size,
                                           obs_dim=obs_dim),
            c_learning_utils.MinimumDistance(buffer_size=num_eval_episodes,
                                             batch_size=tf_env.batch_size,
                                             obs_dim=obs_dim),
            c_learning_utils.DeltaDistance(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size,
                                           obs_dim=obs_dim),
        ]
        if log_subset is not None:
            start_index, end_index = log_subset
            # (metric class, summary name) pairs for the extra distance
            # metrics. These are presumably restricted to the observation
            # slice [start_index, end_index); confirm in c_learning_utils.
            subset_metric_types = [
                (c_learning_utils.InitialDistance, 'SubsetInitialDistance'),
                (c_learning_utils.FinalDistance, 'SubsetFinalDistance'),
                (c_learning_utils.MinimumDistance, 'SubsetMinimumDistance'),
                (c_learning_utils.DeltaDistance, 'SubsetDeltaDistance'),
            ]
            for name, metrics in [('train', train_metrics),
                                  ('eval', eval_metrics)]:
                # NOTE(review): eval-side subset metrics use a hard-coded
                # batch size of 10, while the other eval metrics use the
                # default; confirm this matches eval_tf_env.
                subset_batch_size = tf_env.batch_size if name == 'train' else 10
                metrics.extend([
                    metric_cls(
                        buffer_size=num_eval_episodes,
                        batch_size=subset_batch_size,
                        obs_dim=obs_dim,
                        start_index=start_index,
                        end_index=end_index,
                        name=metric_name)
                    for metric_cls, metric_name in subset_metric_types
                ])

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=None)

        train_checkpointer.initialize_or_restore()

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Save the hyperparameters
        operative_filename = os.path.join(root_dir, 'operative.gin')
        with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
            f.write(gin.operative_config_str())
            logging.info(gin.operative_config_str())

        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        def _evaluate():
            """Runs the greedy eval policy and logs the eval metrics."""
            metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            metric_utils.log_metrics(eval_metrics)

        _evaluate()

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        def _filter_invalid_transition(trajectories, unused_arg1):
            # Drop windows whose first step is an episode boundary.
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(sample_batch_size=batch_size,
                                           num_steps=max_future_steps)
        dataset = dataset.unbatch().filter(_filter_invalid_transition)
        dataset = dataset.batch(batch_size, drop_remainder=True)
        goal_fn = functools.partial(c_learning_utils.goal_fn,
                                    batch_size=batch_size,
                                    obs_dim=obs_dim,
                                    gamma=gamma)
        dataset = dataset.map(goal_fn)
        dataset = dataset.prefetch(5)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        global_step_val = global_step.numpy()
        # The loop body may never run, e.g. when resuming from a checkpoint
        # that already reached `num_iterations`. Initializing here keeps
        # `return train_loss` from raising UnboundLocalError.
        train_loss = None
        while global_step_val < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                _evaluate()

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

        return train_loss