Code Example #1
File: on_policy_driver_test.py  Project: runjerry/alf
    def test_actor_critic_continuous_policy(self):
        batch_size = 100
        steps_per_episode = 13
        env = PolicyUnittestEnv(batch_size,
                                steps_per_episode,
                                action_type=ActionType.Continuous)
        # We need to wrap env in TFPyEnvironment because the methods of env
        # have side effects (e.g., env._current_time_step can be changed).
        env = TFPyEnvironment(env)
        action_spec = env.action_spec()
        observation_spec = env.observation_spec()
        algorithm = ActorCriticAlgorithm(
            observation_spec=observation_spec,
            action_spec=action_spec,
            actor_network=ActorDistributionNetwork(observation_spec,
                                                   action_spec,
                                                   fc_layer_params=()),
            value_network=ValueNetwork(observation_spec, fc_layer_params=()),
            optimizer=tf.optimizers.Adam(learning_rate=1e-2))
        driver = OnPolicyDriver(env, algorithm, train_interval=2)
        eval_driver = OnPolicyDriver(env, algorithm, training=False)

        driver.run = tf.function(driver.run)

        t0 = time.time()
        driver.run(max_num_steps=2600 * batch_size)
        print("time=%s" % (time.time() - t0))

        env.reset()
        time_step, _ = eval_driver.run(max_num_steps=4 * batch_size)
        print("reward=%s" % tf.reduce_mean(time_step.reward))
        self.assertAlmostEqual(1.0,
                               float(tf.reduce_mean(time_step.reward)),
                               delta=5e-2)
Code Example #2
File: merlin_algorithm_test.py  Project: runjerry/alf
    def test_merlin_algorithm(self):
        batch_size = 100
        steps_per_episode = 15
        gap = 10
        env = RNNPolicyUnittestEnv(
            batch_size, steps_per_episode, gap, obs_dim=3)
        env = TFPyEnvironment(env)

        common.set_global_env(env)

        algorithm = _create_merlin_algorithm(
            learning_rate=1e-3, debug_summaries=False)
        driver = OnPolicyDriver(env, algorithm, train_interval=6)

        eval_driver = OnPolicyDriver(env, algorithm, training=False)

        proc = psutil.Process(os.getpid())

        policy_state = driver.get_initial_policy_state()
        time_step = driver.get_initial_time_step()
        for i in range(100):
            t0 = time.time()
            time_step, policy_state, _ = driver.run(
                max_num_steps=150 * batch_size,
                time_step=time_step,
                policy_state=policy_state)
            mem = proc.memory_info().rss // 1e6
            logging.info('%s time=%.3f mem=%s' % (i, time.time() - t0, mem))

        env.reset()
        time_step, _ = eval_driver.run(max_num_steps=14 * batch_size)
        logging.info("eval reward=%.3f" % tf.reduce_mean(time_step.reward))
        self.assertAlmostEqual(
            1.0, float(tf.reduce_mean(time_step.reward)), delta=1e-2)
Code Example #3
    def test_ppo(self):
        env_class = PolicyUnittestEnv
        learning_rate = 1e-1
        iterations = 20
        batch_size = 100
        steps_per_episode = 13
        env = env_class(batch_size, steps_per_episode)
        env = TFPyEnvironment(env)

        eval_env = env_class(batch_size, steps_per_episode)
        eval_env = TFPyEnvironment(eval_env)

        algorithm = create_algorithm(env, learning_rate=learning_rate)
        driver = SyncOffPolicyDriver(env,
                                     algorithm,
                                     debug_summaries=DEBUGGING,
                                     summarize_grads_and_vars=DEBUGGING)
        replayer = driver.exp_replayer
        eval_driver = OnPolicyDriver(eval_env,
                                     algorithm,
                                     training=False,
                                     greedy_predict=True)

        env.reset()
        eval_env.reset()
        time_step = driver.get_initial_time_step()
        policy_state = driver.get_initial_policy_state()
        for i in range(iterations):
            time_step, policy_state = driver.run(max_num_steps=batch_size *
                                                 steps_per_episode,
                                                 time_step=time_step,
                                                 policy_state=policy_state)

            experience = replayer.replay_all()
            driver.train(experience, num_updates=4, mini_batch_size=25)
            replayer.clear()
            eval_env.reset()
            eval_time_step, _ = eval_driver.run(
                max_num_steps=(steps_per_episode - 1) * batch_size)
            logging.info("%d reward=%f", i,
                         float(tf.reduce_mean(eval_time_step.reward)))

        eval_env.reset()
        eval_time_step, _ = eval_driver.run(
            max_num_steps=(steps_per_episode - 1) * batch_size)
        logging.info("reward=%f", float(tf.reduce_mean(eval_time_step.reward)))
        self.assertAlmostEqual(1.0,
                               float(tf.reduce_mean(eval_time_step.reward)),
                               delta=1e-1)
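
For reference, one iteration of the training loop above reduces to the following collect-replay-train-clear cycle (same driver and replayer objects; the numeric arguments are the ones used in the test):

# One iteration of the synchronous PPO update cycle from Code Example #3.
# 1. Unroll the policy for one episode in every environment of the batch.
time_step, policy_state = driver.run(
    max_num_steps=batch_size * steps_per_episode,
    time_step=time_step,
    policy_state=policy_state)

# 2. Replay everything collected since the last update, train on it in
#    mini-batches, then clear the buffer so the next update uses fresh rollouts.
experience = replayer.replay_all()
driver.train(experience, num_updates=4, mini_batch_size=25)
replayer.clear()
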
Code Example #4
    def init_driver(self):
        return OnPolicyDriver(
            env=self._env,
            algorithm=self._algorithm,
            train_interval=self._unroll_length,
            debug_summaries=self._debug_summaries,
            summarize_grads_and_vars=self._summarize_grads_and_vars)
Code Example #5
    def init_driver(self):
        return OnPolicyDriver(env=self._envs[0],
                              algorithm=self._algorithm,
                              train_interval=self._unroll_length)
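
Code Examples #4 and #5 differ only in where the environment comes from (a single self._env versus the first entry of self._envs). Pulled out of the test code, a minimal standalone sketch of the same construction pattern could look as follows; make_env() and make_algorithm() are hypothetical placeholders for the TFPyEnvironment and algorithm setup shown in Code Example #1:

# Minimal sketch of the OnPolicyDriver construction pattern used above.
# make_env() and make_algorithm() are assumed helpers, not part of ALF.
env = make_env()
algorithm = make_algorithm(env)

# Training driver: trains every `train_interval` unrolled environment steps
# (Code Examples #1 and #2 use 2 and 6, respectively).
train_driver = OnPolicyDriver(env=env, algorithm=algorithm, train_interval=2)

# Evaluation driver: same algorithm, no training, optionally greedy actions.
eval_driver = OnPolicyDriver(env=env,
                             algorithm=algorithm,
                             training=False,
                             greedy_predict=True)
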
Code Example #6
def play(root_dir,
         env,
         algorithm,
         checkpoint_name=None,
         greedy_predict=True,
         random_seed=None,
         num_episodes=10,
         sleep_time_per_step=0.01,
         record_file=None,
         use_tf_functions=True):
    """Play using the latest checkpoint under `train_dir`.

    The following example records the play of a trained model to an mp4 video:
    ```bash
    python -m alf.bin.play \
    --root_dir=~/tmp/bullet_humanoid/ppo2/ppo2-11 \
    --num_episodes=1 \
    --record_file=ppo_bullet_humanoid.mp4
    ```
    Args:
        root_dir (str): same as the root_dir used for `train()`
        env (TFEnvironment): the environment
        algorithm (OnPolicyAlgorithm): the training algorithm
        checkpoint_name (str): name of the checkpoint (e.g. 'ckpt-12800').
            If None, the latest checkpoint under train_dir will be used.
        greedy_predict (bool): use greedy action for evaluation.
        random_seed (None|int): random seed; if None, a random seed is used
        num_episodes (int): number of episodes to play
        sleep_time_per_step (float): sleep this many seconds after each step
        record_file (str): if provided, video will be recorded to a file
            instead of being shown on the screen.
        use_tf_functions (bool): whether to use tf.function
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)
        tf.random.set_seed(random_seed)

    global_step = get_global_counter()

    driver = OnPolicyDriver(env=env,
                            algorithm=algorithm,
                            training=False,
                            greedy_predict=greedy_predict)

    ckpt_dir = os.path.join(train_dir, 'algorithm')
    checkpoint = tf.train.Checkpoint(algorithm=algorithm,
                                     metrics=metric_utils.MetricsGroup(
                                         driver.get_metrics(), 'metrics'),
                                     global_step=global_step)
    if checkpoint_name is not None:
        ckpt_path = os.path.join(ckpt_dir, checkpoint_name)
    else:
        ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_path is not None:
        logging.info("Restore from checkpoint %s" % ckpt_path)
        checkpoint.restore(ckpt_path)
    else:
        logging.info("Checkpoint is not found at %s" % ckpt_dir)

    if not use_tf_functions:
        tf.config.experimental_run_functions_eagerly(True)

    recorder = None
    if record_file is not None:
        recorder = VideoRecorder(env.pyenv.envs[0], path=record_file)
    else:
        # pybullet_envs need to render() before reset() to enable mode='human'
        env.pyenv.envs[0].render(mode='human')
    env.reset()
    if recorder:
        recorder.capture_frame()
    time_step = driver.get_initial_time_step()
    policy_state = driver.get_initial_policy_state()
    episode_reward = 0.
    episode_length = 0
    episodes = 0
    while episodes < num_episodes:
        time_step, policy_state = driver.run(max_num_steps=1,
                                             time_step=time_step,
                                             policy_state=policy_state)
        if recorder:
            recorder.capture_frame()
        else:
            env.pyenv.envs[0].render(mode='human')
            time.sleep(sleep_time_per_step)

        episode_reward += float(time_step.reward)

        if time_step.is_last():
            logging.info("episode_length=%s episode_reward=%s" %
                         (episode_length, episode_reward))
            episode_reward = 0.
            episode_length = 0
            episodes += 1
        else:
            episode_length += 1
    if recorder:
        recorder.close()
    env.reset()
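
Besides the command-line invocation shown in the docstring, play() can also be called directly from Python. A minimal sketch, assuming hypothetical create_environment() and create_algorithm() helpers for building the TFEnvironment and OnPolicyAlgorithm arguments:

# create_environment() and create_algorithm() are assumed helpers; play() only
# needs a TFEnvironment, an OnPolicyAlgorithm and the root_dir used for training.
env = create_environment()
algorithm = create_algorithm(env)

# Record one episode to a video file instead of rendering on screen, mirroring
# the command-line example in the docstring above.
play(root_dir='~/tmp/bullet_humanoid/ppo2/ppo2-11',
     env=env,
     algorithm=algorithm,
     num_episodes=1,
     record_file='ppo_bullet_humanoid.mp4')
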
Code Example #7
    def test_off_policy_algorithm(self, algorithm_ctor, use_rollout_state,
                                  sync_driver):
        logging.info("{} {}".format(algorithm_ctor.__name__, sync_driver))

        batch_size = 128
        if use_rollout_state:
            steps_per_episode = 5
            mini_batch_length = 8
            unroll_length = 8
            env_class = RNNPolicyUnittestEnv
        else:
            steps_per_episode = 12
            mini_batch_length = 2
            unroll_length = 12
            env_class = PolicyUnittestEnv
        env = TFPyEnvironment(
            env_class(
                batch_size,
                steps_per_episode,
                action_type=ActionType.Continuous))

        eval_env = TFPyEnvironment(
            env_class(
                batch_size,
                steps_per_episode,
                action_type=ActionType.Continuous))

        common.set_global_env(env)
        algorithm = algorithm_ctor()
        algorithm.set_summary_settings(summarize_grads_and_vars=True)
        algorithm.use_rollout_state = use_rollout_state

        if sync_driver:
            driver = SyncOffPolicyDriver(env, algorithm)
        else:
            driver = AsyncOffPolicyDriver([env],
                                          algorithm,
                                          num_actor_queues=1,
                                          unroll_length=unroll_length,
                                          learn_queue_cap=1,
                                          actor_queue_cap=1)
        eval_driver = OnPolicyDriver(eval_env, algorithm, training=False)

        eval_env.reset()
        driver.start()
        if sync_driver:
            time_step = driver.get_initial_time_step()
            policy_state = driver.get_initial_policy_state()
            for i in range(5):
                time_step, policy_state = driver.run(
                    max_num_steps=batch_size * steps_per_episode,
                    time_step=time_step,
                    policy_state=policy_state)

        for i in range(500):
            if sync_driver:
                time_step, policy_state = driver.run(
                    max_num_steps=batch_size * mini_batch_length * 2,
                    time_step=time_step,
                    policy_state=policy_state)
                whole_replay_buffer_training = False
                clear_replay_buffer = False
            else:
                driver.run_async()
                whole_replay_buffer_training = True
                clear_replay_buffer = True

            driver.algorithm.train(
                mini_batch_size=128,
                mini_batch_length=mini_batch_length,
                whole_replay_buffer_training=whole_replay_buffer_training,
                clear_replay_buffer=clear_replay_buffer)
            eval_env.reset()
            eval_time_step, _ = eval_driver.run(
                max_num_steps=(steps_per_episode - 1) * batch_size)
            logging.log_every_n_seconds(
                logging.INFO,
                "%d reward=%f" %
                (i, float(tf.reduce_mean(eval_time_step.reward))),
                n_seconds=1)
        driver.stop()

        self.assertAlmostEqual(
            1.0, float(tf.reduce_mean(eval_time_step.reward)), delta=2e-1)
Code Example #8
    def test_off_policy_algorithm(self, algorithm_ctor, use_rollout_state,
                                  sync_driver):
        logging.info("{} {}".format(algorithm_ctor.__name__, sync_driver))

        batch_size = 128
        if use_rollout_state:
            steps_per_episode = 5
            mini_batch_length = 8
            unroll_length = 8
            env_class = RNNPolicyUnittestEnv
        else:
            steps_per_episode = 12
            mini_batch_length = 2
            unroll_length = 12
            env_class = PolicyUnittestEnv
        env = TFPyEnvironment(
            env_class(batch_size,
                      steps_per_episode,
                      action_type=ActionType.Continuous))

        eval_env = TFPyEnvironment(
            env_class(batch_size,
                      steps_per_episode,
                      action_type=ActionType.Continuous))

        algorithm = algorithm_ctor(env)
        algorithm.use_rollout_state = use_rollout_state

        if sync_driver:
            driver = SyncOffPolicyDriver(env,
                                         algorithm,
                                         use_rollout_state=use_rollout_state,
                                         debug_summaries=True,
                                         summarize_grads_and_vars=True)
        else:
            driver = AsyncOffPolicyDriver(
                [env],
                algorithm,
                use_rollout_state=algorithm.use_rollout_state,
                num_actor_queues=1,
                unroll_length=unroll_length,
                learn_queue_cap=1,
                actor_queue_cap=1,
                debug_summaries=True,
                summarize_grads_and_vars=True)
        replayer = driver.exp_replayer
        eval_driver = OnPolicyDriver(eval_env,
                                     algorithm,
                                     training=False,
                                     greedy_predict=True)

        eval_env.reset()
        driver.start()
        if sync_driver:
            time_step = driver.get_initial_time_step()
            policy_state = driver.get_initial_policy_state()
            for i in range(5):
                time_step, policy_state = driver.run(max_num_steps=batch_size *
                                                     steps_per_episode,
                                                     time_step=time_step,
                                                     policy_state=policy_state)

        for i in range(500):
            if sync_driver:
                time_step, policy_state = driver.run(max_num_steps=batch_size *
                                                     mini_batch_length * 2,
                                                     time_step=time_step,
                                                     policy_state=policy_state)
                experience, _ = replayer.replay(
                    sample_batch_size=128, mini_batch_length=mini_batch_length)
            else:
                driver.run_async()
                experience = replayer.replay_all()

            driver.train(experience,
                         mini_batch_size=128,
                         mini_batch_length=mini_batch_length)
            eval_env.reset()
            eval_time_step, _ = eval_driver.run(
                max_num_steps=(steps_per_episode - 1) * batch_size)
            logging.info("%d reward=%f", i,
                         float(tf.reduce_mean(eval_time_step.reward)))
        driver.stop()

        self.assertAlmostEqual(1.0,
                               float(tf.reduce_mean(eval_time_step.reward)),
                               delta=2e-1)