示例#1
0
    def test_off_policy_algorithm(self, algorithm_ctor, use_rollout_state,
                                  sync_driver):
        """Train an off-policy algorithm on a unittest env and check reward.

        The algorithm returned by ``algorithm_ctor`` is trained for 500
        iterations. Experience is collected by either a ``SyncOffPolicyDriver``
        or an ``AsyncOffPolicyDriver``. After every iteration the algorithm is
        evaluated with a non-training ``OnPolicyDriver``. The test passes when
        the mean reward of the final evaluation step is within 0.2 of 1.0.

        Args:
            algorithm_ctor: zero-argument callable returning the algorithm
                under test. It is called after ``common.set_global_env(env)``.
            use_rollout_state: if True, use ``RNNPolicyUnittestEnv`` with
                shorter episodes and longer mini-batches, and set
                ``algorithm.use_rollout_state``. Otherwise use
                ``PolicyUnittestEnv``.
            sync_driver: if True, collect experience with a
                ``SyncOffPolicyDriver``; otherwise use an
                ``AsyncOffPolicyDriver``.
        """
        logging.info("%s %s", algorithm_ctor.__name__, sync_driver)

        batch_size = 128
        # Samples per training mini-batch. Kept separate from the env batch
        # size even though both happen to be 128 here.
        mini_batch_size = 128
        if use_rollout_state:
            steps_per_episode = 5
            mini_batch_length = 8
            unroll_length = 8
            env_class = RNNPolicyUnittestEnv
        else:
            steps_per_episode = 12
            mini_batch_length = 2
            unroll_length = 12
            env_class = PolicyUnittestEnv
        env = TFPyEnvironment(
            env_class(
                batch_size,
                steps_per_episode,
                action_type=ActionType.Continuous))

        eval_env = TFPyEnvironment(
            env_class(
                batch_size,
                steps_per_episode,
                action_type=ActionType.Continuous))

        # The global env must be registered before the algorithm is built.
        common.set_global_env(env)
        algorithm = algorithm_ctor()
        algorithm.set_summary_settings(summarize_grads_and_vars=True)
        algorithm.use_rollout_state = use_rollout_state

        if sync_driver:
            driver = SyncOffPolicyDriver(env, algorithm)
        else:
            driver = AsyncOffPolicyDriver([env],
                                          algorithm,
                                          num_actor_queues=1,
                                          unroll_length=unroll_length,
                                          learn_queue_cap=1,
                                          actor_queue_cap=1)
        eval_driver = OnPolicyDriver(eval_env, algorithm, training=False)

        eval_env.reset()
        driver.start()
        # Always pair start() with stop(). Otherwise an exception during
        # training or evaluation would leave whatever start() launched
        # running.
        try:
            if sync_driver:
                # Pre-fill experience before training starts. Each call asks
                # for batch_size * steps_per_episode steps (presumably about
                # one episode per env; TODO confirm max_num_steps counts
                # steps across the whole batch).
                time_step = driver.get_initial_time_step()
                policy_state = driver.get_initial_policy_state()
                for _ in range(5):
                    time_step, policy_state = driver.run(
                        max_num_steps=batch_size * steps_per_episode,
                        time_step=time_step,
                        policy_state=policy_state)

            for i in range(500):
                if sync_driver:
                    # time_step / policy_state exist only on the sync path,
                    # which is also the only path that reads them.
                    time_step, policy_state = driver.run(
                        max_num_steps=batch_size * mini_batch_length * 2,
                        time_step=time_step,
                        policy_state=policy_state)
                    whole_replay_buffer_training = False
                    clear_replay_buffer = False
                else:
                    driver.run_async()
                    # Async path: presumably trains on the whole buffer and
                    # then clears it (as the flag names suggest).
                    whole_replay_buffer_training = True
                    clear_replay_buffer = True

                driver.algorithm.train(
                    mini_batch_size=mini_batch_size,
                    mini_batch_length=mini_batch_length,
                    whole_replay_buffer_training=whole_replay_buffer_training,
                    clear_replay_buffer=clear_replay_buffer)
                eval_env.reset()
                eval_time_step, _ = eval_driver.run(
                    max_num_steps=(steps_per_episode - 1) * batch_size)
                logging.log_every_n_seconds(
                    logging.INFO,
                    "%d reward=%f" %
                    (i, float(tf.reduce_mean(eval_time_step.reward))),
                    n_seconds=1)
        finally:
            driver.stop()

        self.assertAlmostEqual(
            1.0, float(tf.reduce_mean(eval_time_step.reward)), delta=2e-1)
示例#2
0
    def test_off_policy_algorithm(self, algorithm_ctor, use_rollout_state,
                                  sync_driver):
        """Train an off-policy algorithm on a unittest env and check reward.

        The algorithm returned by ``algorithm_ctor(env)`` is trained for 500
        iterations. Experience is collected by either a ``SyncOffPolicyDriver``
        (sampled mini-batches from the driver's replayer) or an
        ``AsyncOffPolicyDriver`` (``replay_all``). After every iteration the
        algorithm is evaluated with a greedy, non-training
        ``OnPolicyDriver``. The test passes when the mean reward of the final
        evaluation step is within 0.2 of 1.0.

        Args:
            algorithm_ctor: callable taking the training env and returning
                the algorithm under test.
            use_rollout_state: if True, use ``RNNPolicyUnittestEnv`` with
                shorter episodes and longer mini-batches. The flag is also
                forwarded to the algorithm and the training driver. Otherwise
                use ``PolicyUnittestEnv``.
            sync_driver: if True, collect experience with a
                ``SyncOffPolicyDriver``; otherwise use an
                ``AsyncOffPolicyDriver``.
        """
        logging.info("%s %s", algorithm_ctor.__name__, sync_driver)

        batch_size = 128
        # Samples per training mini-batch. Kept separate from the env batch
        # size even though both happen to be 128 here.
        mini_batch_size = 128
        if use_rollout_state:
            steps_per_episode = 5
            mini_batch_length = 8
            unroll_length = 8
            env_class = RNNPolicyUnittestEnv
        else:
            steps_per_episode = 12
            mini_batch_length = 2
            unroll_length = 12
            env_class = PolicyUnittestEnv
        env = TFPyEnvironment(
            env_class(batch_size,
                      steps_per_episode,
                      action_type=ActionType.Continuous))

        eval_env = TFPyEnvironment(
            env_class(batch_size,
                      steps_per_episode,
                      action_type=ActionType.Continuous))

        algorithm = algorithm_ctor(env)
        algorithm.use_rollout_state = use_rollout_state

        if sync_driver:
            driver = SyncOffPolicyDriver(env,
                                         algorithm,
                                         use_rollout_state=use_rollout_state,
                                         debug_summaries=True,
                                         summarize_grads_and_vars=True)
        else:
            driver = AsyncOffPolicyDriver(
                [env],
                algorithm,
                use_rollout_state=use_rollout_state,
                num_actor_queues=1,
                unroll_length=unroll_length,
                learn_queue_cap=1,
                actor_queue_cap=1,
                debug_summaries=True,
                summarize_grads_and_vars=True)
        replayer = driver.exp_replayer
        eval_driver = OnPolicyDriver(eval_env,
                                     algorithm,
                                     training=False,
                                     greedy_predict=True)

        eval_env.reset()
        driver.start()
        # Always pair start() with stop(). Otherwise an exception during
        # training or evaluation would leave whatever start() launched
        # running.
        try:
            if sync_driver:
                # Pre-fill the replayer before training starts. Each call asks
                # for batch_size * steps_per_episode steps (presumably about
                # one episode per env; TODO confirm max_num_steps counts
                # steps across the whole batch).
                time_step = driver.get_initial_time_step()
                policy_state = driver.get_initial_policy_state()
                for _ in range(5):
                    time_step, policy_state = driver.run(
                        max_num_steps=batch_size * steps_per_episode,
                        time_step=time_step,
                        policy_state=policy_state)

            for i in range(500):
                if sync_driver:
                    # time_step / policy_state exist only on the sync path,
                    # which is also the only path that reads them.
                    time_step, policy_state = driver.run(
                        max_num_steps=batch_size * mini_batch_length * 2,
                        time_step=time_step,
                        policy_state=policy_state)
                    experience, _ = replayer.replay(
                        sample_batch_size=mini_batch_size,
                        mini_batch_length=mini_batch_length)
                else:
                    driver.run_async()
                    experience = replayer.replay_all()

                driver.train(experience,
                             mini_batch_size=mini_batch_size,
                             mini_batch_length=mini_batch_length)
                eval_env.reset()
                eval_time_step, _ = eval_driver.run(
                    max_num_steps=(steps_per_episode - 1) * batch_size)
                logging.info("%d reward=%f", i,
                             float(tf.reduce_mean(eval_time_step.reward)))
        finally:
            driver.stop()

        self.assertAlmostEqual(1.0,
                               float(tf.reduce_mean(eval_time_step.reward)),
                               delta=2e-1)