def initialize(self):
    """Set up seeding, the environment(s), the algorithm and the driver."""
    tf.random.set_seed(self._random_seed)
    # Run tf.functions eagerly when function compilation is disabled.
    eager = not self._use_tf_functions
    tf.config.experimental_run_functions_eagerly(eager)
    self._env = create_environment()
    if self._evaluate:
        # Evaluation uses its own single (non-batched) environment.
        self._eval_env = create_environment(num_parallel_environments=1)
    self._algorithm = self._algorithm_ctor(
        self._env, debug_summaries=self._debug_summaries)
    self._driver = self.init_driver()
def main(_):
    """Build the environment and algorithm from gin config, then play."""
    seed = common.set_random_seed(FLAGS.random_seed)
    gin.parse_config_files_and_bindings(common.get_gin_file(), FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn

    env = create_environment(nonparallel=True, seed=seed)
    env.reset()
    common.set_global_env(env)

    # Apply the configured data transformer so the algorithm is built
    # against the transformed observation spec.
    config = policy_trainer.TrainerConfig(root_dir="")
    transformer = create_data_transformer(config.data_transformer_ctor,
                                          env.observation_spec())
    config.data_transformer = transformer
    transformed_spec = transformer.transformed_observation_spec
    common.set_transformed_observation_spec(transformed_spec)

    algorithm = algorithm_ctor(
        observation_spec=transformed_spec,
        action_spec=env.action_spec(),
        config=config)

    if FLAGS.ignored_parameter_prefixes:
        ignored_prefixes = FLAGS.ignored_parameter_prefixes.split(",")
    else:
        ignored_prefixes = []

    try:
        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            checkpoint_step=FLAGS.checkpoint_step or "latest",
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            max_episode_length=FLAGS.max_episode_length,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            ignored_parameter_prefixes=ignored_prefixes)
    finally:
        # Always release the environment, even if play() raises.
        env.close()
def _create_environment(self,
                        nonparallel=False,
                        random_seed=None,
                        register=True):
    """Create an environment and (optionally) register it.

    Args:
        nonparallel (bool): forwarded to ``create_environment``.
        random_seed (int|None): seed forwarded to ``create_environment``.
        register (bool): whether to register the new env via
            ``_register_env``.

    Returns:
        The newly created environment.
    """
    environment = create_environment(nonparallel=nonparallel,
                                     seed=random_seed)
    if register:
        self._register_env(environment)
    return environment
def test_curriculum_wrapper(self):
    """Check CurriculumWrapper's env-info specs/values on a 2-task batch.

    Runs 4 parallel CartPole envs over two task variants and verifies,
    before reset, after reset and on every step, that the curriculum
    info fields (count/score/prob) have one scalar entry per task and
    per-env batched values, and that task probabilities sum to ~1 once
    they become non-zero.
    """
    task_names = ['CartPole-v0', 'CartPole-v1']
    env = create_environment(
        env_name=task_names,
        env_load_fn=suite_gym.load,
        num_parallel_environments=4,
        batched_wrappers=(alf_wrappers.CurriculumWrapper, ))

    info_keys = ('curriculum_task_count', 'curriculum_task_score',
                 'curriculum_task_prob')

    def check_spec_lengths():
        # Every curriculum info field reports one entry per task.
        for key in info_keys:
            self.assertEqual(len(env.env_info_spec()[key]), 2)

    def check_spec_entries():
        # Each per-task spec is a scalar TensorSpec.
        for key in info_keys:
            for name in task_names:
                self.assertEqual(env.env_info_spec()[key][name],
                                 alf.TensorSpec(()))

    def check_info_shapes(time_step):
        # Batched env => per-task info tensors have shape (batch_size,).
        for key in info_keys:
            for name in task_names:
                self.assertEqual(time_step.env_info[key][name].shape,
                                 (4, ))

    self.assertTrue(type(env.action_spec()) == alf.BoundedTensorSpec)
    self.assertEqual(env.num_tasks, 2)
    check_spec_lengths()
    check_spec_entries()

    time_step = env.reset()
    check_spec_lengths()
    check_info_shapes(time_step)

    for _ in range(500):
        time_step = env.step(time_step.prev_action)
        self.assertEqual(time_step.env_id, torch.arange(4))
        check_spec_lengths()
        check_info_shapes(time_step)
        # Probabilities are all zero until the first episode completes;
        # afterwards they must sum to (approximately) one.
        sum_probs = sum(time_step.env_info['curriculum_task_prob'].values())
        self.assertTrue(
            torch.all((sum_probs == 0.) | ((sum_probs - 1.).abs() < 1e-3)))
def initialize(self):
    """Initializes the Trainer: envs, algorithm (adapted if needed), driver."""
    tf.random.set_seed(self._random_seed)
    tf.config.experimental_run_functions_eagerly(
        not self._use_tf_functions)
    self._env = create_environment()
    if self._evaluate:
        self._eval_env = create_environment(num_parallel_environments=1)
    self._algorithm = self._algorithm_ctor(
        self._env, debug_summaries=self._debug_summaries)
    if self._config.use_rollout_state:
        # When rollout state is reused for training, train_state_spec and
        # predict_state_spec must have the same nest structure; if they
        # differ (TypeError), adapt the algorithm so training can consume
        # the rollout (predict) state.
        specs_match = True
        try:
            tf.nest.assert_same_structure(self._algorithm.train_state_spec,
                                          self._algorithm.predict_state_spec)
        except TypeError:
            specs_match = False
        if not specs_match:
            self._algorithm = TrainStepAdapter(self._algorithm)
    self._driver = self.init_driver()
def init_driver(self):
    """Create the async off-policy driver over ``num_envs`` environments."""
    # The first env already exists (created in initialize()); make the rest.
    extra_envs = [
        create_environment() for _ in range(self._config.num_envs - 1)
    ]
    return AsyncOffPolicyDriver(
        envs=[self._env] + extra_envs,
        algorithm=self._algorithm,
        use_rollout_state=self._config.use_rollout_state,
        unroll_length=self._unroll_length,
        debug_summaries=self._debug_summaries,
        summarize_grads_and_vars=self._summarize_grads_and_vars)
def play(root_dir, algorithm_ctor):
    """Play using the latest checkpoint under ``train_dir``.

    Args:
        root_dir (str): directory where checkpoints are stored.
        algorithm_ctor (Callable): callable that creates the algorithm.
            Its value is bound to ``Trainer.algorithm_ctor``, so just
            configure ``Trainer.algorithm_ctor`` when using gin.
    """
    environment = create_environment(num_parallel_environments=1)
    policy_trainer.play(root_dir, environment, algorithm_ctor(environment))
def main(_):
    """Parse the gin config, build env and algorithm, then play."""
    gin.parse_config_files_and_bindings(common.get_gin_file(), FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter('TrainerConfig.algorithm_ctor')
    env = create_environment(num_parallel_environments=1)
    policy_trainer.play(
        FLAGS.root_dir,
        env,
        algorithm_ctor(env),
        checkpoint_name=FLAGS.checkpoint_name,
        greedy_predict=FLAGS.greedy_predict,
        random_seed=FLAGS.random_seed,
        num_episodes=FLAGS.num_episodes,
        sleep_time_per_step=FLAGS.sleep_time_per_step,
        record_file=FLAGS.record_file)
def main(_):
    """Build env and algorithm from gin config and play episodes.

    Fix: the environment is now closed in a ``finally`` clause so it is
    released even when ``policy_trainer.play`` raises (matching the
    try/finally pattern used by the sibling play script).
    """
    seed = common.set_random_seed(FLAGS.random_seed,
                                  not FLAGS.use_tf_functions)
    gin_file = common.get_gin_file()
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param)
    algorithm_ctor = gin.query_parameter(
        'TrainerConfig.algorithm_ctor').scoped_configurable_fn
    env = create_environment(nonparallel=True, seed=seed)
    env.reset()
    common.set_global_env(env)
    algorithm = algorithm_ctor(
        observation_spec=env.observation_spec(),
        action_spec=env.action_spec())
    try:
        policy_trainer.play(
            FLAGS.root_dir,
            env,
            algorithm,
            checkpoint_name=FLAGS.checkpoint_name,
            epsilon_greedy=FLAGS.epsilon_greedy,
            num_episodes=FLAGS.num_episodes,
            sleep_time_per_step=FLAGS.sleep_time_per_step,
            record_file=FLAGS.record_file,
            use_tf_functions=FLAGS.use_tf_functions)
    finally:
        # Previously unconditional-but-skipped-on-exception; closing the
        # underlying py environment must happen on every exit path.
        env.pyenv.close()
def _create_environment(self, nonparallel=False, register=True):
    """Create an env and (by default) register it.

    Generalized with a backward-compatible ``register`` flag, matching
    the sibling ``_create_environment`` overload elsewhere in this file,
    so callers can create a throwaway env without registering it.

    Args:
        nonparallel (bool): forwarded to ``create_environment``.
        register (bool): if True (the default, preserving the original
            behavior), register the env via ``_register_env``.

    Returns:
        The newly created environment.
    """
    env = create_environment(nonparallel=nonparallel)
    if register:
        self._register_env(env)
    return env