Example #1
    def initialize(self):
        """Initializes the Trainer."""

        tf.random.set_seed(self._random_seed)
        tf.config.experimental_run_functions_eagerly(
            not self._use_tf_functions)
        self._env = create_environment()
        if self._evaluate:
            self._eval_env = create_environment(num_parallel_environments=1)
        self._algorithm = self._algorithm_ctor(
            self._env, debug_summaries=self._debug_summaries)
        self._driver = self.init_driver()
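
Note: `tf.config.experimental_run_functions_eagerly` used above was later deprecated. A minimal sketch of the non-experimental equivalent (assumes TF >= 2.3; `use_tf_functions` is a stand-in for the trainer's flag):

import tensorflow as tf

use_tf_functions = True  # stand-in for self._use_tf_functions

# Runs every tf.function eagerly when given True; handy for debugging.
tf.config.run_functions_eagerly(not use_tf_functions)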
Example #2
def main(_):
    seed = common.set_random_seed(FLAGS.random_seed)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    env.reset()
    common.set_global_env(env)
    config = policy_trainer.TrainerConfig(root_dir="")
    data_transformer = create_data_transformer(config.data_transformer_ctor,
                                               env.observation_spec())
    config.data_transformer = data_transformer
    observation_spec = data_transformer.transformed_observation_spec
    common.set_transformed_observation_spec(observation_spec)
    algorithm = algorithm_ctor(
        observation_spec=observation_spec,
        action_spec=env.action_spec(),
        config=config)
    try:
        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            ignored_parameter_prefixes=FLAGS.ignored_parameter_prefixes.split(
                ",") if FLAGS.ignored_parameter_prefixes else [])
    finally:
        env.close()
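
Scripts like this one are normally launched through absl. A minimal entry-point sketch (the flag definition is illustrative; the real play.py also defines random_seed, gin_param, num_episodes, etc.):

from absl import app
from absl import flags

flags.DEFINE_string('root_dir', None, 'Root directory of the training run.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)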
Example #3
    def _create_environment(self,
                            nonparallel=False,
                            random_seed=None,
                            register=True):
        """Create and register an env."""
        env = create_environment(nonparallel=nonparallel, seed=random_seed)
        if register:
            self._register_env(env)
        return env
Example #4
    def test_curriculum_wrapper(self):
        task_names = ['CartPole-v0', 'CartPole-v1']
        env = create_environment(
            env_name=task_names,
            env_load_fn=suite_gym.load,
            num_parallel_environments=4,
            batched_wrappers=(alf_wrappers.CurriculumWrapper, ))

        self.assertIs(type(env.action_spec()), alf.BoundedTensorSpec)

        self.assertEqual(env.num_tasks, 2)
        self.assertEqual(len(env.env_info_spec()['curriculum_task_count']), 2)
        self.assertEqual(len(env.env_info_spec()['curriculum_task_score']), 2)
        self.assertEqual(len(env.env_info_spec()['curriculum_task_prob']), 2)
        for i in task_names:
            self.assertEqual(env.env_info_spec()['curriculum_task_count'][i],
                             alf.TensorSpec(()))
            self.assertEqual(env.env_info_spec()['curriculum_task_score'][i],
                             alf.TensorSpec(()))
            self.assertEqual(env.env_info_spec()['curriculum_task_prob'][i],
                             alf.TensorSpec(()))

        time_step = env.reset()
        self.assertEqual(len(env.env_info_spec()['curriculum_task_count']), 2)
        self.assertEqual(len(env.env_info_spec()['curriculum_task_score']), 2)
        self.assertEqual(len(env.env_info_spec()['curriculum_task_prob']), 2)
        for i in task_names:
            self.assertEqual(
                time_step.env_info['curriculum_task_count'][i].shape, (4, ))
            self.assertEqual(
                time_step.env_info['curriculum_task_score'][i].shape, (4, ))
            self.assertEqual(
                time_step.env_info['curriculum_task_prob'][i].shape, (4, ))

        for _ in range(500):
            time_step = env.step(time_step.prev_action)
            self.assertTrue(torch.equal(time_step.env_id, torch.arange(4)))
            self.assertEqual(len(env.env_info_spec()['curriculum_task_count']),
                             2)
            self.assertEqual(len(env.env_info_spec()['curriculum_task_score']),
                             2)
            self.assertEqual(len(env.env_info_spec()['curriculum_task_prob']),
                             2)
            for i in task_names:
                self.assertEqual(
                    time_step.env_info['curriculum_task_count'][i].shape,
                    (4, ))
                self.assertEqual(
                    time_step.env_info['curriculum_task_score'][i].shape,
                    (4, ))
                self.assertEqual(
                    time_step.env_info['curriculum_task_prob'][i].shape, (4, ))
            sum_probs = sum(
                time_step.env_info['curriculum_task_prob'].values())
            self.assertTrue(
                torch.all((sum_probs == 0.) | ((sum_probs - 1.).abs() < 1e-3)))
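
The final assertion accepts, per environment, either a sum of 0 (no curriculum statistics yet) or a sum within 1e-3 of 1. A standalone illustration of that check with made-up values:

import torch

# Task name -> per-environment sampling probability, one value per env.
probs = {'CartPole-v0': torch.tensor([0.5, 0.0, 1.0, 0.3]),
         'CartPole-v1': torch.tensor([0.5, 0.0, 0.0, 0.7])}
sum_probs = sum(probs.values())  # per-env sums: 1.0, 0.0, 1.0, 1.0
assert torch.all((sum_probs == 0.) | ((sum_probs - 1.).abs() < 1e-3))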
Example #5
    def initialize(self):
        """Initializes the Trainer."""

        tf.random.set_seed(self._random_seed)
        tf.config.experimental_run_functions_eagerly(
            not self._use_tf_functions)
        self._env = create_environment()
        if self._evaluate:
            self._eval_env = create_environment(num_parallel_environments=1)
        self._algorithm = self._algorithm_ctor(
            self._env, debug_summaries=self._debug_summaries)
        if self._config.use_rollout_state:
            try:
                tf.nest.assert_same_structure(
                    self._algorithm.train_state_spec,
                    self._algorithm.predict_state_spec)
            except TypeError:
                self._algorithm = TrainStepAdapter(self._algorithm)
        self._driver = self.init_driver()
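
`tf.nest.assert_same_structure` raises `TypeError` when two nests differ in container type, which is what the fallback above catches. A small self-contained illustration:

import tensorflow as tf

try:
    # Same leaf, but a list vs. a tuple: different sequence types.
    tf.nest.assert_same_structure([1], (1,))
except TypeError:
    print('state specs differ in structure type')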
Example #6
    def init_driver(self):
        envs = [self._env]
        for _ in range(1, self._config.num_envs):
            envs.append(create_environment())
        driver = AsyncOffPolicyDriver(
            envs=envs,
            algorithm=self._algorithm,
            use_rollout_state=self._config.use_rollout_state,
            unroll_length=self._unroll_length,
            debug_summaries=self._debug_summaries,
            summarize_grads_and_vars=self._summarize_grads_and_vars)
        return driver
Example #7
def play(root_dir, algorithm_ctor):
    """Play using the latest checkpoint under `train_dir`.

    Args:
        root_dir (str): directory where checkpoints are stored
        algorithm_ctor (Callable): callable that creates an algorithm.
            Its value is bound to `Trainer.algorithm_ctor`, so just
            configure `Trainer.algorithm_ctor` when using gin.
    """
    env = create_environment(num_parallel_environments=1)
    algorithm = algorithm_ctor(env)
    policy_trainer.play(root_dir, env, algorithm)
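
A minimal invocation sketch; `MyAlgorithm` is a hypothetical gin-configurable constructor standing in for whatever `Trainer.algorithm_ctor` is bound to:

# 'MyAlgorithm' is illustrative, not part of the source.
play('/tmp/experiment_root', algorithm_ctor=MyAlgorithm)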
Example #8
def main(_):
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter('TrainerConfig.algorithm_ctor')
    env = create_environment(num_parallel_environments=1)
    algorithm = algorithm_ctor(env)
    policy_trainer.play(FLAGS.root_dir,
                        env,
                        algorithm,
                        checkpoint_name=FLAGS.checkpoint_name,
                        greedy_predict=FLAGS.greedy_predict,
                        random_seed=FLAGS.random_seed,
                        num_episodes=FLAGS.num_episodes,
                        sleep_time_per_step=FLAGS.sleep_time_per_step,
                        record_file=FLAGS.record_file)
Example #9
File: play.py  Project: runjerry/alf
def main(_):
    seed = common.set_random_seed(FLAGS.random_seed,
                                  not FLAGS.use_tf_functions)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    env.reset()
    common.set_global_env(env)
    algorithm = algorithm_ctor(observation_spec=env.observation_spec(),
                               action_spec=env.action_spec())
    policy_trainer.play(FLAGS.root_dir,
                        env,
                        algorithm,
                        checkpoint_name=FLAGS.checkpoint_name,
                        epsilon_greedy=FLAGS.epsilon_greedy,
                        num_episodes=FLAGS.num_episodes,
                        sleep_time_per_step=FLAGS.sleep_time_per_step,
                        record_file=FLAGS.record_file,
                        use_tf_functions=FLAGS.use_tf_functions)
    env.pyenv.close()
Example #10
    def _create_environment(self, nonparallel=False):
        """Create and register an env."""
        env = create_environment(nonparallel=nonparallel)
        self._register_env(env)
        return env