def _eval(self):
    """Evaluate the current policy on the evaluation environment.

    The algorithm is switched to eval mode, ``self._eval_env`` is stepped
    with ``_step()`` until ``time_step.is_last()`` has been true
    ``self._num_eval_episodes`` times, the evaluation metrics are written
    through ``self._eval_summary_writer`` and logged, and finally the
    algorithm is switched back to train mode.
    """
    algorithm = self._algorithm
    eval_env = self._eval_env
    algorithm.eval()

    time_step = common.get_initial_time_step(eval_env)
    policy_state = algorithm.get_initial_predict_state(eval_env.batch_size)
    trans_state = algorithm.get_initial_transform_state(eval_env.batch_size)

    finished_episodes = 0
    while finished_episodes < self._num_eval_episodes:
        time_step, policy_step, trans_state = _step(
            algorithm=algorithm,
            env=eval_env,
            time_step=time_step,
            policy_state=policy_state,
            trans_state=trans_state,
            epsilon_greedy=self._config.epsilon_greedy,
            metrics=self._eval_metrics)
        # carry the policy state over to the next step
        policy_state = policy_step.state
        if time_step.is_last():
            finished_episodes += 1

    step_metrics = algorithm.get_step_metrics()
    with alf.summary.push_summary_writer(self._eval_summary_writer):
        for eval_metric in self._eval_metrics:
            eval_metric.gen_summaries(
                train_step=alf.summary.get_global_counter(),
                step_metrics=step_metrics)

    common.log_metrics(self._eval_metrics)
    # evaluation is done; go back to training mode
    algorithm.train()
def _test_off_policy_algorithm(self, root_dir):
    """Train ``MyAlg`` off-policy and check that action 1 is preferred.

    Args:
        root_dir (str): directory passed to ``TrainerConfig`` as ``root_dir``.
    """
    alf.summary.enable_summary()
    config_kwargs = dict(
        root_dir=root_dir,
        unroll_length=5,
        num_envs=1,
        num_updates_per_train_iter=1,
        mini_batch_length=5,
        mini_batch_size=3,
        use_rollout_state=True,
        summarize_grads_and_vars=True,
        summarize_action_distributions=True,
        whole_replay_buffer_training=True)
    config = TrainerConfig(**config_kwargs)

    env = MyEnv(batch_size=3)
    observation_spec = env.observation_spec()
    action_spec = env.action_spec()
    alg = MyAlg(
        observation_spec=observation_spec,
        action_spec=action_spec,
        env=env,
        on_policy=False,
        config=config)

    num_train_iters = 100
    for _ in range(num_train_iters):
        alg.train_iter()

    initial_step = common.get_initial_time_step(env)
    initial_state = alg.get_initial_predict_state(env.batch_size)
    policy_step = alg.rollout_step(initial_step, initial_state)

    # Row i holds the log probability of action i.
    candidate_actions = torch.arange(3).reshape(3, 1)
    logits = policy_step.info.log_prob(candidate_actions)
    print("logits: ", logits)
    for other_action in (0, 2):
        self.assertTrue(torch.all(logits[1, :] > logits[other_action, :]))
def _run(self, coord, unroll_length):
    """Repeatedly unroll and learn until the coordinator asks to stop.

    Args:
        coord: coordinator object providing ``stop_on_exception()`` and
            ``should_stop()``.
        unroll_length: forwarded unchanged to ``self._unroll_and_learn()``.
    """
    with coord.stop_on_exception():
        current_step = common.get_initial_time_step(self._env)
        current_state = self._initial_policy_state
        while True:
            if coord.should_stop():
                break
            current_step, current_state = self._unroll_and_learn(
                current_step, current_state, unroll_length)

    # Whoever stops first, cancel all pending requests
    # (including enqueues and dequeues),
    # so that no thread hangs before calling coord.should_stop()
    self._tfq.close_all()
def test_rl_trainer(self):
    """Train ``MyRLTrainer``, then restore its checkpoint and keep training.

    After each phase the logits of a single rollout step are checked to
    prefer action 1 over actions 0 and 2.
    """

    def rollout_logits(alg, env):
        # One rollout step from the initial time step and initial state.
        first_step = common.get_initial_time_step(env)
        first_state = alg.get_initial_predict_state(env.batch_size)
        return alg.rollout_step(first_step, first_state).info.logits

    def check_action_1_preferred(logits):
        for other_action in (0, 2):
            self.assertTrue(
                torch.all(logits[:, 1] > logits[:, other_action]))

    with tempfile.TemporaryDirectory() as root_dir:
        conf = TrainerConfig(
            algorithm_ctor=MyAlg,
            root_dir=root_dir,
            unroll_length=5,
            num_iterations=100)

        # test train
        trainer = MyRLTrainer(conf)
        self.assertEqual(RLTrainer.progress(), 0)
        trainer.train()
        self.assertEqual(RLTrainer.progress(), 1)

        alg = trainer._algorithm
        env = common.get_env()
        logits = rollout_logits(alg, env)
        print("logits: ", logits)
        check_action_1_preferred(logits)

        # test checkpoint
        conf.num_iterations = 200
        new_trainer = MyRLTrainer(conf)
        new_trainer._restore_checkpoint()
        # 100 of the 200 iterations are already done
        self.assertEqual(RLTrainer.progress(), 0.5)
        # NOTE(review): this still checks the algorithm of the first trainer,
        # exactly as the test always did -- confirm that is intended.
        check_action_1_preferred(rollout_logits(alg, env))

        new_trainer.train()
        self.assertEqual(RLTrainer.progress(), 1)
def unroll(env, algorithm, steps, epsilon_greedy=0.1):
    """Run ``steps`` environment steps using ``algorithm.predict_step()``.

    Args:
        env: the environment to step. It must provide ``batch_size`` and
            ``step()``.
        algorithm: the algorithm providing ``get_initial_predict_state()``,
            ``get_initial_transform_state()``, ``transform_timestep()`` and
            ``predict_step()``.
        steps (int): number of environment steps to run.
        epsilon_greedy (float): forwarded to ``algorithm.predict_step()``.

    Returns:
        The time step returned by the last ``env.step()`` call, or the
        initial time step if the loop body never runs.
    """
    batch_size = env.batch_size
    time_step = common.get_initial_time_step(env)
    policy_state = algorithm.get_initial_predict_state(batch_size)
    trans_state = algorithm.get_initial_transform_state(batch_size)
    # The state used for resetting does not depend on the loop, so it is
    # built once here instead of once per step.
    initial_state = algorithm.get_initial_predict_state(batch_size)

    for _ in range(steps):
        # Use the initial state wherever the time step is a first step.
        policy_state = common.reset_state_if_necessary(
            policy_state, initial_state, time_step.is_first())
        transformed_time_step, trans_state = algorithm.transform_timestep(
            time_step, trans_state)
        policy_step = algorithm.predict_step(
            transformed_time_step,
            policy_state,
            epsilon_greedy=epsilon_greedy)
        time_step = env.step(policy_step.output)
        policy_state = policy_step.state

    return time_step
def test_ac_algorithm(self):
    """Train the algorithm from ``create_algorithm`` and check its policy.

    After training, the log probability of action 1 must exceed that of
    actions 0 and 2, and the global summary counter must equal the number
    of training iterations.
    """
    env = MyEnv(batch_size=3)
    alg1 = create_algorithm(env)

    iter_num = 50
    for _ in range(iter_num):
        alg1.train_iter()

    first_step = common.get_initial_time_step(env)
    first_state = alg1.get_initial_predict_state(env.batch_size)
    policy_step = alg1.rollout_step(first_step, first_state)

    # Row i holds the log probability of action i.
    candidate_actions = torch.arange(3).reshape(3, 1)
    logits = policy_step.info.action_distribution.log_prob(candidate_actions)
    print("logits: ", logits)
    for other_action in (0, 2):
        self.assertTrue(torch.all(logits[1, :] > logits[other_action, :]))

    # global counter is iter_num due to alg1
    self.assertTrue(alf.summary.get_global_counter() == iter_num)
def _run(self, coord, unroll_length):
    """Repeatedly unroll and learn until the coordinator asks to stop.

    ``tf.errors.CancelledError`` and ``tf.errors.OutOfRangeError`` are
    re-raised as they are. Any other exception has its traceback printed
    before being re-raised.

    Args:
        coord: coordinator object providing ``stop_on_exception()`` and
            ``should_stop()``.
        unroll_length: forwarded unchanged to ``self._unroll_and_learn()``.
    """
    with coord.stop_on_exception():
        try:
            current_step = common.get_initial_time_step(
                self._env, first_env_id=self._first_env_id)
            current_state = self._initial_policy_state
            while True:
                if coord.should_stop():
                    break
                current_step, current_state = self._unroll_and_learn(
                    current_step, current_state, unroll_length)
        except (tf.errors.CancelledError, tf.errors.OutOfRangeError) as err:
            # these two are passed on without printing a traceback
            raise err
        except Exception as err:
            traceback.print_exc()
            raise err

    # Whoever stops first, cancel all pending requests
    # (including enqueues and dequeues),
    # so that no thread hangs before calling coord.should_stop()
    self._tfq.close_all()
def test_trac_algorithm(self):
    """Train ``TracAlgorithm`` and check that action 1 becomes most likely."""
    config = TrainerConfig(root_dir="dummy", unroll_length=5)
    env = MyEnv(batch_size=3)
    observation_spec = env.observation_spec()
    action_spec = env.action_spec()
    alg = TracAlgorithm(
        observation_spec=observation_spec,
        action_spec=action_spec,
        ac_algorithm_cls=create_ac_algorithm,
        env=env,
        config=config)

    num_train_iters = 50
    for _ in range(num_train_iters):
        alg.train_iter()

    first_step = common.get_initial_time_step(env)
    first_state = alg.get_initial_predict_state(env.batch_size)
    policy_step = alg.rollout_step(first_step, first_state)

    # Row i holds the log probability of action i.
    candidate_actions = torch.arange(3).reshape(3, 1)
    logits = policy_step.info.action_distribution.log_prob(candidate_actions)
    print("logits: ", logits)

    # action 1 gets the most reward. So its probability should be higher
    # than other actions after training.
    for other_action in (0, 2):
        self.assertTrue(torch.all(logits[1, :] > logits[other_action, :]))
def test_on_policy_algorithm(self):
    """Train ``MyAlg`` on-policy and check that action 1 is preferred."""
    # root_dir is not used. We have to give it a value because
    # it is a required argument of TrainerConfig.
    config = TrainerConfig(
        root_dir='/tmp/rl_algorithm_test', unroll_length=5, num_envs=1)
    env = MyEnv(batch_size=3)
    observation_spec = env.observation_spec()
    action_spec = env.action_spec()
    alg = MyAlg(
        observation_spec=observation_spec,
        action_spec=action_spec,
        env=env,
        config=config,
        on_policy=True,
        debug_summaries=True)

    num_train_iters = 100
    for _ in range(num_train_iters):
        alg.train_iter()

    first_step = common.get_initial_time_step(env)
    first_state = alg.get_initial_predict_state(env.batch_size)
    policy_step = alg.rollout_step(first_step, first_state)

    # Row i holds the log probability of action i.
    candidate_actions = torch.arange(3).reshape(3, 1)
    logits = policy_step.info.log_prob(candidate_actions)
    print("logits: ", logits)
    for other_action in (0, 2):
        self.assertTrue(torch.all(logits[1, :] > logits[other_action, :]))
def get_initial_time_step(self):
    """Return ``common.get_initial_time_step()`` applied to ``self._env``."""
    initial_time_step = common.get_initial_time_step(self._env)
    return initial_time_step
def unroll(self, unroll_length):
    r"""Unroll ``unroll_length`` steps using the current policy.

    Because the ``self._env`` is a batched environment. The total number of
    environment steps is ``self._env.batch_size * unroll_length``.

    The time step, policy state and transform state reached at the end of
    the unroll are cached on ``self`` so that the next call continues from
    where this one stopped. The time spent in ``self._env.step()`` and in
    ``self.observe_for_replay()`` is accumulated over the unroll and
    written as the summary scalars ``"time/unroll_env_step"`` and
    ``"time/unroll_store_exp"``.

    Args:
        unroll_length (int): number of steps to unroll.
    Returns:
        Experience: The stacked experience with shape :math:`[T, B, \ldots]`
        for each of its members.
    """
    # Lazily create whichever cached state is still missing.
    if self._current_time_step is None:
        self._current_time_step = common.get_initial_time_step(self._env)
    if self._current_policy_state is None:
        self._current_policy_state = self.get_initial_rollout_state(
            self._env.batch_size)
    if self._current_transform_state is None:
        self._current_transform_state = self.get_initial_transform_state(
            self._env.batch_size)

    # Continue from where the previous unroll stopped.
    time_step = self._current_time_step
    policy_state = self._current_policy_state
    trans_state = self._current_transform_state

    experience_list = []
    # Built once; used in the loop wherever ``time_step.is_first()`` holds.
    initial_state = self.get_initial_rollout_state(self._env.batch_size)

    # Accumulated wall-clock seconds, reported as summaries after the loop.
    env_step_time = 0.
    store_exp_time = 0.
    for _ in range(unroll_length):
        policy_state = common.reset_state_if_necessary(
            policy_state, initial_state, time_step.is_first())
        transformed_time_step, trans_state = self.transform_timestep(
            time_step, trans_state)
        # save the untransformed time step in case that sub-algorithms need
        # to store it in replay buffers
        transformed_time_step = transformed_time_step._replace(
            untransformed=time_step)
        policy_step = self.rollout_step(transformed_time_step, policy_state)
        # release the reference to ``time_step``
        transformed_time_step = transformed_time_step._replace(
            untransformed=())

        # The action sent to the environment is detached from the policy
        # output.
        action = common.detach(policy_step.output)

        t0 = time.time()
        next_time_step = self._env.step(action)
        env_step_time += time.time() - t0

        # Metrics observe the time step the action was computed from, not
        # ``next_time_step``.
        self.observe_for_metrics(time_step.cpu())

        # The "one_time" replayer gets the transformed time step; any other
        # replayer type gets the untransformed one moved to CPU.
        if self._exp_replayer_type == "one_time":
            exp = make_experience(transformed_time_step, policy_step,
                                  policy_state)
        else:
            exp = make_experience(time_step.cpu(), policy_step, policy_state)

        t0 = time.time()
        self.observe_for_replay(exp)
        store_exp_time += time.time() - t0

        # ``state`` is the policy state that was fed into ``rollout_step``,
        # not the state it returned. ``rollout_info`` is converted with
        # ``distributions_to_params`` here and converted back after stacking
        # (presumably because distributions cannot be stacked directly --
        # TODO confirm).
        exp_for_training = Experience(
            action=action,
            reward=transformed_time_step.reward,
            discount=transformed_time_step.discount,
            step_type=transformed_time_step.step_type,
            state=policy_state,
            prev_action=transformed_time_step.prev_action,
            observation=transformed_time_step.observation,
            rollout_info=dist_utils.distributions_to_params(
                policy_step.info),
            env_id=transformed_time_step.env_id)

        experience_list.append(exp_for_training)
        time_step = next_time_step
        policy_state = policy_step.state

    alf.summary.scalar("time/unroll_env_step", env_step_time)
    alf.summary.scalar("time/unroll_store_exp", store_exp_time)
    # Stack the per-step experiences into one nest.
    experience = alf.nest.utils.stack_nests(experience_list)
    experience = experience._replace(
        rollout_info=dist_utils.params_to_distributions(
            experience.rollout_info, self._rollout_info_spec))
    self._current_time_step = time_step
    # Need to detach so that the graph from this unroll is disconnected from
    # the next unroll. Otherwise backward() will report error for on-policy
    # training after the next unroll.
    self._current_policy_state = common.detach(policy_state)
    self._current_transform_state = common.detach(trans_state)

    return experience
def play(root_dir,
         env,
         algorithm,
         checkpoint_step="latest",
         epsilon_greedy=0.,
         num_episodes=10,
         max_episode_length=0,
         sleep_time_per_step=0.01,
         record_file=None,
         future_steps=0,
         append_blank_frames=0,
         render=True,
         render_prediction=False,
         ignored_parameter_prefixes=None):
    r"""Play using a checkpoint stored under ``root_dir``.

    The checkpoint is loaded from ``<root_dir>/train/algorithm``.

    The following example record the play of a trained model to a mp4 video:

    .. code-block:: bash

        python -m alf.bin.play \
        --root_dir=~/tmp/bullet_humanoid/ppo2/ppo2-11 \
        --num_episodes=1 \
        --record_file=ppo_bullet_humanoid.mp4

    Args:
        root_dir (str): same as the root_dir used for `train()`
        env (AlfEnvironment): the environment
        algorithm (RLAlgorithm): the training algorithm
        checkpoint_step (int|str): the number of training steps which is used
            to specify the checkpoint to be loaded. If checkpoint_step is
            'latest', the most recent checkpoint named 'latest' will be
            loaded.
        epsilon_greedy (float): a floating value in [0,1], representing the
            chance of action sampling instead of taking argmax. This can
            help prevent a dead loop in some deterministic environment like
            Breakout.
        num_episodes (int): number of episodes to play
        max_episode_length (int): if >0, each episode is limited to so many
            steps.
        sleep_time_per_step (float): sleep so many seconds for each step
        record_file (str): if provided, video will be recorded to a file
            instead of shown on the screen.
        future_steps (int): whether to encode some information from future
            steps into the current frame. If future_steps is larger than
            zero, then the related information (e.g. observation, reward,
            action etc.) will be cached and the encoding of them to video
            frames is deferred to the time when ``future_steps`` of future
            frames are available. This defer mode is potentially useful to
            display for each frame some information that expands beyond a
            single time step to the future. Currently this mode only support
            offline rendering, i.e. rendering and saving the video to
            ``record_file``. If a non-positive value is provided, it is
            treated as not using the defer mode and the plots for displaying
            future information will not be displayed.
        append_blank_frames (int): If >0, will append such number of blank
            frames at the end of the episode in the rendered video file. A
            negative value has the same effects as 0 and no blank frames
            will be appended. This option has no effects when displaying the
            frames on the screen instead of recording to a file.
        render (bool): If False, then this function only evaluates the
            trained model without calling rendering functions. This value
            will be ignored if a ``record_file`` argument is provided.
        render_prediction (bool): If True, when using ``VideoRecorder`` to
            render a video, extra prediction info (returned by
            ``predict_step()``) will also be rendered by the side of video
            frames.
        ignored_parameter_prefixes (list[str]|None): ignore the parameters
            whose name has one of these prefixes in the checkpoint. ``None``
            (the default) means an empty list.
    """
    # A fresh list per call instead of a shared mutable default.
    if ignored_parameter_prefixes is None:
        ignored_parameter_prefixes = []

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    ckpt_dir = os.path.join(train_dir, 'algorithm')

    checkpointer = Checkpointer(ckpt_dir=ckpt_dir, algorithm=algorithm)
    checkpointer.load(
        checkpoint_step,
        ignored_parameter_prefixes=ignored_parameter_prefixes,
        including_optimizer=False,
        including_replay_buffer=False)

    recorder = None
    if record_file is not None:
        recorder = VideoRecorder(
            env,
            future_steps=future_steps,
            append_blank_frames=append_blank_frames,
            render_prediction=render_prediction,
            path=record_file)
    elif render:
        # pybullet_envs need to render() before reset() to enable mode='human'
        env.render(mode='human')

    try:
        env.reset()

        time_step = common.get_initial_time_step(env)
        algorithm.eval()
        policy_state = algorithm.get_initial_predict_state(env.batch_size)
        trans_state = algorithm.get_initial_transform_state(env.batch_size)
        episode_reward = 0.
        episode_length = 0
        episodes = 0
        metrics = [
            alf.metrics.AverageReturnMetric(
                buffer_size=num_episodes,
                reward_shape=env.reward_spec().shape),
            alf.metrics.AverageEpisodeLengthMetric(buffer_size=num_episodes),
        ]
        while episodes < num_episodes:
            time_step, policy_step, trans_state = _step(
                algorithm=algorithm,
                env=env,
                time_step=time_step,
                policy_state=policy_state,
                trans_state=trans_state,
                epsilon_greedy=epsilon_greedy,
                metrics=metrics)
            policy_state = policy_step.state
            episode_length += 1

            # An episode also ends when the optional length limit is hit.
            is_last_step = time_step.is_last() or (
                episode_length >= max_episode_length > 0)
            if recorder:
                recorder.capture_frame(time_step, policy_step, is_last_step)
            elif render:
                env.render(mode='human')
                time.sleep(sleep_time_per_step)

            time_step_reward = time_step.reward.view(-1).float().cpu().numpy()
            episode_reward += time_step_reward

            if is_last_step:
                logging.info("episode_length=%s episode_reward=%s",
                             episode_length, episode_reward)
                episode_reward = 0.
                # Reset to an int so the counter keeps its type across
                # episodes.
                episode_length = 0
                episodes += 1
                # observe the last step
                for m in metrics:
                    m(time_step.cpu())
                time_step = env.reset()

        for m in metrics:
            logging.info(
                "%s: %s", m.name,
                map_structure(
                    lambda x: x.cpu().numpy().item()
                    if x.ndim == 0 else x.cpu().numpy(), m.result()))
    finally:
        # Close the recorder even if playing was interrupted by an error.
        if recorder:
            recorder.close()
    env.reset()
def play(root_dir,
         env,
         algorithm,
         checkpoint_step="latest",
         epsilon_greedy=0.1,
         num_episodes=10,
         max_episode_length=0,
         sleep_time_per_step=0.01,
         record_file=None,
         ignored_parameter_prefixes=None):
    r"""Play using a checkpoint stored under ``root_dir``.

    The checkpoint is loaded from ``<root_dir>/train/algorithm``.

    The following example record the play of a trained model to a mp4 video:

    .. code-block:: bash

        python -m alf.bin.play \
        --root_dir=~/tmp/bullet_humanoid/ppo2/ppo2-11 \
        --num_episodes=1 \
        --record_file=ppo_bullet_humanoid.mp4

    Args:
        root_dir (str): same as the root_dir used for `train()`
        env (AlfEnvironment): the environment
        algorithm (RLAlgorithm): the training algorithm
        checkpoint_step (int|str): the number of training steps which is used
            to specify the checkpoint to be loaded. If checkpoint_step is
            'latest', the most recent checkpoint named 'latest' will be
            loaded.
        epsilon_greedy (float): a floating value in [0,1], representing the
            chance of action sampling instead of taking argmax. This can
            help prevent a dead loop in some deterministic environment like
            Breakout.
        num_episodes (int): number of episodes to play
        max_episode_length (int): if >0, each episode is limited to so many
            steps.
        sleep_time_per_step (float): sleep so many seconds for each step
        record_file (str): if provided, video will be recorded to a file
            instead of shown on the screen.
        ignored_parameter_prefixes (list[str]|None): ignore the parameters
            whose name has one of these prefixes in the checkpoint. ``None``
            (the default) means an empty list.
    """
    # A fresh list per call instead of a shared mutable default.
    if ignored_parameter_prefixes is None:
        ignored_parameter_prefixes = []

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    ckpt_dir = os.path.join(train_dir, 'algorithm')

    checkpointer = Checkpointer(ckpt_dir=ckpt_dir, algorithm=algorithm)
    checkpointer.load(
        checkpoint_step,
        ignored_parameter_prefixes=ignored_parameter_prefixes,
        including_optimizer=False,
        including_replay_buffer=False)

    recorder = None
    if record_file is not None:
        recorder = VideoRecorder(env, path=record_file)
    else:
        # pybullet_envs need to render() before reset() to enable mode='human'
        env.render(mode='human')

    try:
        env.reset()

        time_step = common.get_initial_time_step(env)
        algorithm.eval()
        policy_state = algorithm.get_initial_predict_state(env.batch_size)
        trans_state = algorithm.get_initial_transform_state(env.batch_size)
        episode_reward = 0.
        episode_length = 0
        episodes = 0
        metrics = [
            alf.metrics.AverageReturnMetric(
                buffer_size=num_episodes,
                reward_shape=env.reward_spec().shape),
            alf.metrics.AverageEpisodeLengthMetric(buffer_size=num_episodes),
        ]
        while episodes < num_episodes:
            time_step, policy_state, trans_state, info = _step(
                algorithm=algorithm,
                env=env,
                time_step=time_step,
                policy_state=policy_state,
                trans_state=trans_state,
                epsilon_greedy=epsilon_greedy,
                metrics=metrics)
            episode_length += 1
            if recorder:
                recorder.capture_frame(info)
            else:
                env.render(mode='human')
                time.sleep(sleep_time_per_step)

            time_step_reward = time_step.reward.view(-1).float().cpu().numpy()
            episode_reward += time_step_reward

            # An episode also ends when the optional length limit is hit.
            if time_step.is_last() or episode_length >= max_episode_length > 0:
                logging.info("episode_length=%s episode_reward=%s",
                             episode_length, episode_reward)
                episode_reward = 0.
                # Reset to an int so the counter keeps its type across
                # episodes.
                episode_length = 0
                episodes += 1
                # observe the last step
                for m in metrics:
                    m(time_step.cpu())
                time_step = env.reset()

        for m in metrics:
            logging.info("%s: %f", m.name, m.result())
    finally:
        # Close the recorder even if playing was interrupted by an error.
        if recorder:
            recorder.close()
    env.reset()