Example #1
    def test_ppo(self):
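        # Train PPO on the batched PolicyUnittestEnv for a fixed number of
        # iterations and check that greedy evaluation reaches an average
        # reward close to 1.0.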
        env_class = PolicyUnittestEnv
        learning_rate = 1e-1
        iterations = 20
        batch_size = 100
        steps_per_episode = 13
        env = env_class(batch_size, steps_per_episode)
        env = TFPyEnvironment(env)

        eval_env = env_class(batch_size, steps_per_episode)
        eval_env = TFPyEnvironment(eval_env)

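        # Build the algorithm under test, a synchronous off-policy driver that
        # collects experience into a replayer, and a greedy on-policy driver
        # used only for evaluation.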
        algorithm = create_algorithm(env, learning_rate=learning_rate)
        driver = SyncOffPolicyDriver(env,
                                     algorithm,
                                     debug_summaries=DEBUGGING,
                                     summarize_grads_and_vars=DEBUGGING)
        replayer = driver.exp_replayer
        eval_driver = OnPolicyDriver(eval_env,
                                     algorithm,
                                     training=False,
                                     greedy_predict=True)

        env.reset()
        eval_env.reset()
        time_step = driver.get_initial_time_step()
        policy_state = driver.get_initial_policy_state()
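        # Each iteration collects one batch of full episodes, trains on the
        # replayed experience, and logs the evaluation reward.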
        for i in range(iterations):
            time_step, policy_state = driver.run(
                max_num_steps=batch_size * steps_per_episode,
                time_step=time_step,
                policy_state=policy_state)

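            # Train on everything gathered this iteration, then clear the
            # replayer so the next update only sees fresh on-policy data.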
            experience = replayer.replay_all()
            driver.train(experience, num_updates=4, mini_batch_size=25)
            replayer.clear()
            eval_env.reset()
            eval_time_step, _ = eval_driver.run(
                max_num_steps=(steps_per_episode - 1) * batch_size)
            logging.info("%d reward=%f", i,
                         float(tf.reduce_mean(eval_time_step.reward)))

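        # Final greedy evaluation; the mean reward should be close to 1.0.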
        eval_env.reset()
        eval_time_step, _ = eval_driver.run(
            max_num_steps=(steps_per_episode - 1) * batch_size)
        logging.info("reward=%f", float(tf.reduce_mean(eval_time_step.reward)))
        self.assertAlmostEqual(1.0,
                               float(tf.reduce_mean(eval_time_step.reward)),
                               delta=1e-1)
Example #2
    def test_off_policy_algorithm(self, algorithm_ctor, use_rollout_state,
                                  sync_driver):
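        # Parameterized test: train the given off-policy algorithm with either
        # a synchronous or an asynchronous driver and check that greedy
        # evaluation ends up with an average reward close to 1.0.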
        logging.info("{} {}".format(algorithm_ctor.__name__, sync_driver))

        batch_size = 128
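        # With rollout-state replay, use the RNN unit-test environment with
        # shorter episodes and longer mini-batches; otherwise use the plain
        # feed-forward environment.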
        if use_rollout_state:
            steps_per_episode = 5
            mini_batch_length = 8
            unroll_length = 8
            env_class = RNNPolicyUnittestEnv
        else:
            steps_per_episode = 12
            mini_batch_length = 2
            unroll_length = 12
            env_class = PolicyUnittestEnv
        env = TFPyEnvironment(
            env_class(batch_size,
                      steps_per_episode,
                      action_type=ActionType.Continuous))

        eval_env = TFPyEnvironment(
            env_class(batch_size,
                      steps_per_episode,
                      action_type=ActionType.Continuous))

        algorithm = algorithm_ctor(env)
        algorithm.use_rollout_state = use_rollout_state

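        # Use either a synchronous driver (experience is collected by explicit
        # run() calls) or an asynchronous driver backed by actor queues.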
        if sync_driver:
            driver = SyncOffPolicyDriver(env,
                                         algorithm,
                                         use_rollout_state=use_rollout_state,
                                         debug_summaries=True,
                                         summarize_grads_and_vars=True)
        else:
            driver = AsyncOffPolicyDriver(
                [env],
                algorithm,
                use_rollout_state=algorithm.use_rollout_state,
                num_actor_queues=1,
                unroll_length=unroll_length,
                learn_queue_cap=1,
                actor_queue_cap=1,
                debug_summaries=True,
                summarize_grads_and_vars=True)
        replayer = driver.exp_replayer
        eval_driver = OnPolicyDriver(eval_env,
                                     algorithm,
                                     training=False,
                                     greedy_predict=True)

        eval_env.reset()
        driver.start()
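        # With the synchronous driver, warm up by collecting a few batches of
        # full episodes before training starts.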
        if sync_driver:
            time_step = driver.get_initial_time_step()
            policy_state = driver.get_initial_policy_state()
            for i in range(5):
                time_step, policy_state = driver.run(
                    max_num_steps=batch_size * steps_per_episode,
                    time_step=time_step,
                    policy_state=policy_state)

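        # Main loop: gather experience (synchronously or from the actor
        # queues), train on mini-batches, and log the greedy evaluation reward.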
        for i in range(500):
            if sync_driver:
                time_step, policy_state = driver.run(
                    max_num_steps=batch_size * mini_batch_length * 2,
                    time_step=time_step,
                    policy_state=policy_state)
                experience, _ = replayer.replay(
                    sample_batch_size=128, mini_batch_length=mini_batch_length)
            else:
                driver.run_async()
                experience = replayer.replay_all()

            driver.train(experience,
                         mini_batch_size=128,
                         mini_batch_length=mini_batch_length)
            eval_env.reset()
            eval_time_step, _ = eval_driver.run(
                max_num_steps=(steps_per_episode - 1) * batch_size)
            logging.info("%d reward=%f", i,
                         float(tf.reduce_mean(eval_time_step.reward)))
        driver.stop()

        self.assertAlmostEqual(1.0,
                               float(tf.reduce_mean(eval_time_step.reward)),
                               delta=2e-1)