Example #1
def main(_):
    seed = common.set_random_seed(FLAGS.random_seed)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    env.reset()
    common.set_global_env(env)
    config = policy_trainer.TrainerConfig(root_dir="")
    data_transformer = create_data_transformer(config.data_transformer_ctor,
                                               env.observation_spec())
    config.data_transformer = data_transformer
    observation_spec = data_transformer.transformed_observation_spec
    common.set_transformed_observation_spec(observation_spec)
    algorithm = algorithm_ctor(
        observation_spec=observation_spec,
        action_spec=env.action_spec(),
        config=config)
    try:
        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            ignored_parameter_prefixes=FLAGS.ignored_parameter_prefixes.split(
                ",") if FLAGS.ignored_parameter_prefixes else [])
    finally:
        env.close()
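Example #1 reads a number of absl flags (FLAGS.root_dir, FLAGS.random_seed, FLAGS.gin_param, and so on) that are defined elsewhere in play.py. The sketch below shows one plausible way such flags could be declared with absl.flags; the defaults and help strings are assumptions for illustration, not ALF's actual definitions.

# Hypothetical flag declarations for a play script like Example #1.
# Defaults and help strings are assumed, not copied from ALF.
from absl import app, flags

flags.DEFINE_string('root_dir', None, 'Directory containing checkpoints and gin configs.')
flags.DEFINE_integer('random_seed', None, 'Random seed; None means draw a fresh seed.')
flags.DEFINE_multi_string('gin_param', None, 'Extra gin bindings, e.g. "a.b=1".')
flags.DEFINE_integer('checkpoint_step', None, 'Checkpoint step to load; defaults to the latest.')
flags.DEFINE_float('epsilon_greedy', 0.1, 'Epsilon used when sampling actions during play.')
flags.DEFINE_integer('num_episodes', 10, 'Number of episodes to play.')
flags.DEFINE_integer('max_episode_length', 0, 'Max steps per episode; 0 means no limit.')
flags.DEFINE_float('sleep_time_per_step', 0.01, 'Seconds to sleep between steps for rendering.')
flags.DEFINE_string('record_file', None, 'If set, record a video to this path.')
flags.DEFINE_string('ignored_parameter_prefixes', '',
                    'Comma-separated parameter prefixes to skip when restoring.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)  # main(_) as defined in Example #1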
Example #2
    def initialize(self):
        """Initializes the Trainer."""
        if self._random_seed is not None:
            random.seed(self._random_seed)
            np.random.seed(self._random_seed)
            tf.random.set_seed(self._random_seed)

        tf.config.experimental_run_functions_eagerly(
            not self._use_tf_functions)
        env = self._create_environment()
        common.set_global_env(env)

        self._algorithm = self._algorithm_ctor(
            debug_summaries=self._debug_summaries)
        self._algorithm.set_summary_settings(
            summarize_grads_and_vars=self._summarize_grads_and_vars,
            summarize_action_distributions=(
                self._config.summarize_action_distributions))
        self._algorithm.use_rollout_state = self._config.use_rollout_state

        self._driver = self.init_driver()

        # Create an unwrapped env to expose subprocess gin confs which otherwise
        # will be marked as "inoperative". This env should be created last.
        unwrapped_env = self._create_environment(nonparallel=True)
        if self._evaluate:
            self._eval_env = unwrapped_env
Example #3
    def test_alf_metrics(self, num_envs, learn_queue_cap, unroll_length,
                         actor_queue_cap, num_actors, num_iterations):
        episode_length = 5
        env_f = lambda: TFPyEnvironment(
            ValueUnittestEnv(batch_size=1, episode_length=episode_length))

        envs = [env_f() for _ in range(num_envs)]
        common.set_global_env(envs[0])
        alg = _create_ac_algorithm()
        driver = AsyncOffPolicyDriver(envs, alg, num_actors, unroll_length,
                                      learn_queue_cap, actor_queue_cap)
        driver.start()
        total_num_steps_ = 0
        for _ in range(num_iterations):
            total_num_steps_ += driver.run_async()
        driver.stop()

        total_num_steps = int(driver.get_metrics()[1].result())
        self.assertGreaterEqual(total_num_steps_, total_num_steps)

        # An exp is only put in the log queue after it's put in the learning queue.
        # So when we stop the driver (which will force all queues to stop),
        # some exps might be missing from the metric. Here we assert an arbitrary
        # lower bound of 2/5. The upper bound is due to the fact that StepType.LAST
        # is not recorded by the metric (episode_length==5).
        self.assertLessEqual(total_num_steps, int(total_num_steps_ * 4 // 5))
        self.assertGreaterEqual(total_num_steps,
                                int(total_num_steps_ * 2 // 5))

        average_reward = int(driver.get_metrics()[2].result())
        self.assertEqual(average_reward, episode_length - 1)

        # Store the metric value in a separate variable so the assertion
        # compares it against the expected episode length instead of
        # comparing the reassigned variable with itself.
        metric_episode_length = int(driver.get_metrics()[3].result())
        self.assertEqual(metric_episode_length, episode_length)
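The bounds asserted above follow directly from the comment: with episode_length == 5, the metric never records the StepType.LAST step, so at most 4 of every 5 unrolled steps can be counted, and the 2/5 lower bound is arbitrary slack for experiences dropped when the queues are force-stopped. A small illustrative helper (not part of the test) restates that arithmetic:

# Illustrative only: the upper/lower step-count bounds used in Example #3.
def metric_step_bounds(total_unrolled_steps, episode_length=5):
    # LAST steps are not recorded, so at most (episode_length - 1) out of
    # every episode_length steps can appear in the metric.
    upper = total_unrolled_steps * (episode_length - 1) // episode_length
    # Arbitrary 2/5 slack for experiences lost when the queues are stopped.
    lower = total_unrolled_steps * 2 // episode_length
    return lower, upper

assert metric_step_bounds(1000) == (400, 800)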
Example #4
    def initialize(self):
        """Initializes the Trainer."""
        self._random_seed = common.set_random_seed(self._random_seed,
                                                   not self._use_tf_functions)

        tf.config.experimental_run_functions_eagerly(
            not self._use_tf_functions)
        env = self._create_environment(random_seed=self._random_seed)
        common.set_global_env(env)

        self._algorithm = self._algorithm_ctor(
            observation_spec=env.observation_spec(),
            action_spec=env.action_spec(),
            debug_summaries=self._debug_summaries)
        self._algorithm.set_summary_settings(
            summarize_grads_and_vars=self._summarize_grads_and_vars,
            summarize_action_distributions=(
                self._config.summarize_action_distributions))
        self._algorithm.use_rollout_state = self._config.use_rollout_state

        self._driver = self._init_driver()

        # Create an unwrapped env to expose subprocess gin confs which otherwise
        # will be marked as "inoperative". This env should be created last.
        # DO NOT register this env in self._envs because AsyncOffPolicyTrainer
        # will use all self._envs to init AsyncOffPolicyDriver!
        self._unwrapped_env = self._create_environment(
            nonparallel=True, random_seed=self._random_seed, register=False)
        if self._evaluate:
            self._eval_env = self._unwrapped_env
Example #5
    def test_merlin_algorithm(self):
        batch_size = 100
        steps_per_episode = 15
        gap = 10
        env = RNNPolicyUnittestEnv(
            batch_size, steps_per_episode, gap, obs_dim=3)
        env = TFPyEnvironment(env)

        common.set_global_env(env)

        algorithm = _create_merlin_algorithm(
            learning_rate=1e-3, debug_summaries=False)
        driver = OnPolicyDriver(env, algorithm, train_interval=6)

        eval_driver = OnPolicyDriver(env, algorithm, training=False)

        proc = psutil.Process(os.getpid())

        policy_state = driver.get_initial_policy_state()
        time_step = driver.get_initial_time_step()
        for i in range(100):
            t0 = time.time()
            time_step, policy_state, _ = driver.run(
                max_num_steps=150 * batch_size,
                time_step=time_step,
                policy_state=policy_state)
            mem = proc.memory_info().rss // 1e6
            logging.info('%s time=%.3f mem=%s' % (i, time.time() - t0, mem))

        env.reset()
        time_step, _ = eval_driver.run(max_num_steps=14 * batch_size)
        logging.info("eval reward=%.3f" % tf.reduce_mean(time_step.reward))
        self.assertAlmostEqual(
            1.0, float(tf.reduce_mean(time_step.reward)), delta=1e-2)
Example #6
File: play.py  Project: ruizhaogit/alf
def main(_):
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True)
    common.set_global_env(env)
    algorithm = algorithm_ctor()
    policy_trainer.play(FLAGS.root_dir,
                        env,
                        algorithm,
                        checkpoint_name=FLAGS.checkpoint_name,
                        greedy_predict=FLAGS.greedy_predict,
                        random_seed=FLAGS.random_seed,
                        num_episodes=FLAGS.num_episodes,
                        sleep_time_per_step=FLAGS.sleep_time_per_step,
                        record_file=FLAGS.record_file)
    env.pyenv.close()
Example #7
File: play.py  Project: runjerry/alf
def main(_):
    seed = common.set_random_seed(FLAGS.random_seed,
                                  not FLAGS.use_tf_functions)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    env.reset()
    common.set_global_env(env)
    algorithm = algorithm_ctor(observation_spec=env.observation_spec(),
                               action_spec=env.action_spec())
    policy_trainer.play(FLAGS.root_dir,
                        env,
                        algorithm,
                        checkpoint_name=FLAGS.checkpoint_name,
                        epsilon_greedy=FLAGS.epsilon_greedy,
                        num_episodes=FLAGS.num_episodes,
                        sleep_time_per_step=FLAGS.sleep_time_per_step,
                        record_file=FLAGS.record_file,
                        use_tf_functions=FLAGS.use_tf_functions)
    env.pyenv.close()
Example #8
    def test_off_policy_algorithm(self, algorithm_ctor, use_rollout_state,
                                  sync_driver):
        logging.info("{} {}".format(algorithm_ctor.__name__, sync_driver))

        batch_size = 128
        if use_rollout_state:
            steps_per_episode = 5
            mini_batch_length = 8
            unroll_length = 8
            env_class = RNNPolicyUnittestEnv
        else:
            steps_per_episode = 12
            mini_batch_length = 2
            unroll_length = 12
            env_class = PolicyUnittestEnv
        env = TFPyEnvironment(
            env_class(
                batch_size,
                steps_per_episode,
                action_type=ActionType.Continuous))

        eval_env = TFPyEnvironment(
            env_class(
                batch_size,
                steps_per_episode,
                action_type=ActionType.Continuous))

        common.set_global_env(env)
        algorithm = algorithm_ctor()
        algorithm.set_summary_settings(summarize_grads_and_vars=True)
        algorithm.use_rollout_state = use_rollout_state

        if sync_driver:
            driver = SyncOffPolicyDriver(env, algorithm)
        else:
            driver = AsyncOffPolicyDriver([env],
                                          algorithm,
                                          num_actor_queues=1,
                                          unroll_length=unroll_length,
                                          learn_queue_cap=1,
                                          actor_queue_cap=1)
        eval_driver = OnPolicyDriver(eval_env, algorithm, training=False)

        eval_env.reset()
        driver.start()
        if sync_driver:
            time_step = driver.get_initial_time_step()
            policy_state = driver.get_initial_policy_state()
            for i in range(5):
                time_step, policy_state = driver.run(
                    max_num_steps=batch_size * steps_per_episode,
                    time_step=time_step,
                    policy_state=policy_state)

        for i in range(500):
            if sync_driver:
                time_step, policy_state = driver.run(
                    max_num_steps=batch_size * mini_batch_length * 2,
                    time_step=time_step,
                    policy_state=policy_state)
                whole_replay_buffer_training = False
                clear_replay_buffer = False
            else:
                driver.run_async()
                whole_replay_buffer_training = True
                clear_replay_buffer = True

            driver.algorithm.train(
                mini_batch_size=128,
                mini_batch_length=mini_batch_length,
                whole_replay_buffer_training=whole_replay_buffer_training,
                clear_replay_buffer=clear_replay_buffer)
            eval_env.reset()
            eval_time_step, _ = eval_driver.run(
                max_num_steps=(steps_per_episode - 1) * batch_size)
            logging.log_every_n_seconds(
                logging.INFO,
                "%d reward=%f" %
                (i, float(tf.reduce_mean(eval_time_step.reward))),
                n_seconds=1)
        driver.stop()

        self.assertAlmostEqual(
            1.0, float(tf.reduce_mean(eval_time_step.reward)), delta=2e-1)
Example #9
    def __init__(self, config: TrainerConfig):
        """

        Args:
            config (TrainerConfig): configuration used to construct this trainer
        """
        super().__init__(config)

        self._envs = []
        self._num_env_steps = config.num_env_steps
        self._num_iterations = config.num_iterations
        assert (self._num_iterations + self._num_env_steps > 0
                and self._num_iterations * self._num_env_steps == 0), \
            "Must provide #iterations or #env_steps exclusively for training!"
        self._trainer_progress.set_termination_criterion(
            self._num_iterations, self._num_env_steps)

        self._num_eval_episodes = config.num_eval_episodes
        alf.summary.should_summarize_output(config.summarize_output)

        env = self._create_environment(random_seed=self._random_seed)
        logging.info("observation_spec=%s" %
                     pprint.pformat(env.observation_spec()))
        logging.info("action_spec=%s" % pprint.pformat(env.action_spec()))
        common.set_global_env(env)

        data_transformer = create_data_transformer(
            config.data_transformer_ctor, env.observation_spec())
        self._config.data_transformer = data_transformer
        observation_spec = data_transformer.transformed_observation_spec
        common.set_transformed_observation_spec(observation_spec)

        self._algorithm = self._algorithm_ctor(
            observation_spec=observation_spec,
            action_spec=env.action_spec(),
            env=env,
            config=self._config,
            debug_summaries=self._debug_summaries)

        # Create an unwrapped env to expose subprocess gin confs which otherwise
        # will be marked as "inoperative". This env should be created last.
        # DO NOT register this env in self._envs because AsyncOffPolicyTrainer
        # will use all self._envs to init AsyncOffPolicyDriver!
        if self._evaluate or isinstance(
                env,
                alf.environments.parallel_environment.ParallelAlfEnvironment):
            self._unwrapped_env = self._create_environment(
                nonparallel=True,
                random_seed=self._random_seed,
                register=False)
        else:
            self._unwrapped_env = None
        self._eval_env = None
        self._eval_metrics = None
        self._eval_summary_writer = None
        if self._evaluate:
            self._eval_env = self._unwrapped_env
            self._eval_metrics = [
                alf.metrics.AverageReturnMetric(
                    buffer_size=self._num_eval_episodes,
                    reward_shape=self._eval_env.reward_spec().shape),
                alf.metrics.AverageEpisodeLengthMetric(
                    buffer_size=self._num_eval_episodes),
                alf.metrics.AverageEnvInfoMetric(
                    example_env_info=self._eval_env.reset().env_info,
                    batch_size=self._eval_env.batch_size,
                    buffer_size=self._num_eval_episodes)
            ]
            self._eval_summary_writer = alf.summary.create_summary_writer(
                self._eval_dir, flush_secs=config.summaries_flush_secs)
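The assertion at the top of Example #9 encodes that exactly one of num_iterations and num_env_steps must be positive: a positive sum guarantees at least one is set, and a zero product guarantees they are not both set. A standalone restatement of that check (illustrative only, not part of the trainer):

# Illustrative only: the termination-criterion check from Example #9.
def exclusive_termination_criterion(num_iterations, num_env_steps):
    return (num_iterations + num_env_steps > 0
            and num_iterations * num_env_steps == 0)

assert exclusive_termination_criterion(1000, 0)        # iterations only: accepted
assert exclusive_termination_criterion(0, 100000)      # env steps only: accepted
assert not exclusive_termination_criterion(0, 0)       # neither: rejected
assert not exclusive_termination_criterion(1000, 100)  # both: rejected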