def play_env(self, env=None, nsteps=100, base_env_name=None, batch_size=5,
             reward_range=None):
  """Creates `EnvProblem` with the given arguments and plays it randomly.

  Args:
    env: optional env.
    nsteps: plays the env randomly for nsteps.
    base_env_name: passed to EnvProblem's init.
    batch_size: passed to EnvProblem's init.
    reward_range: passed to EnvProblem's init.

  Returns:
    tuple of env_problem, number of trajectories done, number of trajectories
    done in the last step.
  """
  if env is None:
    env = env_problem.EnvProblem(base_env_name=base_env_name,
                                 batch_size=batch_size,
                                 reward_range=reward_range)
    # Usually done by a registered subclass, we do this manually in the test.
    env.name = base_env_name

  # Reset all environments.
  env.reset()

  # Play for some steps to generate data.
  num_dones = 0
  num_dones_in_last_step = 0
  for _ in range(nsteps):
    # Sample actions.
    actions = np.stack([env.action_space.sample() for _ in range(batch_size)])
    # Step through it.
    _, _, dones, _ = env.step(actions)
    # Get the indices where we are done ...
    done_indices = env_problem_utils.done_indices(dones)
    # ... and reset those.
    env.reset(indices=done_indices)
    # Count the number of dones we got, in this step and overall.
    num_dones_in_last_step = sum(dones)
    num_dones += num_dones_in_last_step

  return env, num_dones, num_dones_in_last_step
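# A minimal usage sketch for `play_env` (illustrative, not part of the
# original tests); it reuses the "KellyCoinflip-v0" env name from
# `test_interaction_with_env` below.
def test_play_env_sketch(self):
  env, num_dones, _ = self.play_env(base_env_name="KellyCoinflip-v0",
                                    batch_size=2,
                                    nsteps=200,
                                    reward_range=(-1, 1))
  # Every done counted by `play_env` should correspond to a completed
  # trajectory recorded by the env.
  self.assertEqual(num_dones, len(env.trajectories.completed_trajectories))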
def test_registration_and_interaction_with_env_problem(self):
  batch_size = 5
  # This ensures that registration has occurred.
  ep = registry.env_problem("tic_tac_toe_env_problem", batch_size=batch_size)
  ep.reset()

  num_done, num_lost, num_won, num_draw = 0, 0, 0, 0
  nsteps = 100
  for _ in range(nsteps):
    actions = np.stack([ep.action_space.sample() for _ in range(batch_size)])
    obs, rewards, dones, infos = ep.step(actions)

    # Assert that things are happening batchwise.
    self.assertEqual(batch_size, len(obs))
    self.assertEqual(batch_size, len(rewards))
    self.assertEqual(batch_size, len(dones))
    self.assertEqual(batch_size, len(infos))

    done_indices = env_problem_utils.done_indices(dones)
    ep.reset(done_indices)
    num_done += sum(dones)
    for r, d in zip(rewards, dones):
      if not d:
        continue
      if r == -1:
        num_lost += 1
      elif r == 0:
        num_draw += 1
      elif r == 1:
        num_won += 1
      else:
        raise ValueError("reward should be -1, 0, 1 but is {}".format(r))

  # Assert that at least something got done; without that the next assert is
  # meaningless.
  self.assertGreater(num_done, 0)

  # Assert that things are consistent.
  self.assertEqual(num_done, num_won + num_lost + num_draw)
def play_env_problem(env, policy_fn):
  """Plays an EnvProblem using a given policy function."""
  trajectories = [trajectory.Trajectory() for _ in range(env.batch_size)]

  observations = env.reset()
  for (traj, observation) in zip(trajectories, observations):
    traj.add_time_step(observation=observation)

  done_so_far = np.array([False] * env.batch_size)
  while not np.all(done_so_far):
    padded_observations, _ = env.trajectories.observations_np(
        len_history_for_policy=None)
    actions = policy_fn(padded_observations)
    (observations, rewards, dones, _) = env.step(actions)
    for (traj, observation, action, reward, done) in zip(
        trajectories, observations, actions, rewards, dones):
      if not traj.done:
        traj.change_last_time_step(action=action)
        traj.add_time_step(observation=observation, raw_reward=reward,
                           done=done)
    env.reset(indices=env_problem_utils.done_indices(dones))
    done_so_far = np.logical_or(done_so_far, dones)
  return trajectories
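# A minimal usage sketch for `play_env_problem` (illustrative): the random
# policy below is a hypothetical stand-in, since `policy_fn` only has to map
# padded observations to a batch of actions.
def play_env_problem_with_random_policy(env):
  def random_policy_fn(padded_observations):
    del padded_observations  # A random policy ignores the observations.
    return np.stack(
        [env.action_space.sample() for _ in range(env.batch_size)])

  trajectories = play_env_problem(env, random_policy_fn)
  # The loop above exits only once every trajectory has seen done=True.
  assert all(traj.done for traj in trajectories)
  return trajectories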
def test_registration_and_interaction_with_env_problem(self):
  batch_size = 5
  # This ensures that registration has occurred.
  ep = registry.env_problem("reacher_env_problem", batch_size=batch_size)
  ep.reset()

  num_done = 0
  nsteps = 100
  for _ in range(nsteps):
    actions = np.stack([ep.action_space.sample() for _ in range(batch_size)])
    obs, rewards, dones, infos = ep.step(actions)

    # Assert that things are happening batchwise.
    self.assertEqual(batch_size, len(obs))
    self.assertEqual(batch_size, len(rewards))
    self.assertEqual(batch_size, len(dones))
    self.assertEqual(batch_size, len(infos))

    done_indices = env_problem_utils.done_indices(dones)
    ep.reset(done_indices)
    num_done += sum(dones)

  # Assert that at least something got done.
  self.assertGreater(num_done, 0)
def test_interaction_with_env(self):
  batch_size = 5
  reward_range = (-1, 1)
  ep = env_problem.EnvProblem(base_env_name="KellyCoinflip-v0",
                              batch_size=batch_size,
                              reward_range=reward_range)

  # Resets all environments.
  ep.reset()

  # Let's play a few steps.
  nsteps = 100
  num_trajectories_completed = 0
  num_timesteps_completed = 0
  # If batch_done_at_step[i] = j then it means that the i^th env last got done
  # at step = j.
  batch_done_at_step = np.full(batch_size, -1)
  for i in range(nsteps):
    # Sample batch_size actions from the action space and stack them (since
    # that is the expected type).
    actions = np.stack([ep.action_space.sample() for _ in range(batch_size)])
    _, _, dones, _ = ep.step(actions)

    # Do the book-keeping on number of trajectories completed and expect that
    # it matches ep's completed number.
    num_done = sum(dones)
    num_trajectories_completed += num_done
    self.assertEqual(num_trajectories_completed,
                     len(ep.trajectories.completed_trajectories))

    # Get the indices where we are done ...
    done_indices = env_problem_utils.done_indices(dones)

    # ... and reset those.
    ep.reset(indices=done_indices)

    # If nothing got done, go on to the next step.
    if done_indices.size == 0:  # i.e. this is an empty array.
      continue

    # See when these indices were last done and calculate how many time-steps
    # each one took to get done. Note that `i + 1 - batch_done_at_step` counts
    # the initial observation too: an env first done at i = 4 contributes
    # 4 + 1 - (-1) = 6 time-steps, matching `num_time_steps` below.
    num_timesteps_completed += sum(i + 1 - batch_done_at_step[done_indices])
    batch_done_at_step[done_indices] = i

  # This should also match the number of time-steps completed given by ep.
  num_timesteps_completed_ep = sum(
      ct.num_time_steps for ct in ep.trajectories.completed_trajectories)
  self.assertEqual(num_timesteps_completed, num_timesteps_completed_ep)

  # Reset the trajectories.
  ep.trajectories.reset_batch_trajectories()
  self.assertEqual(0, len(ep.trajectories.completed_trajectories))
def collect_trajectories(
    env,
    policy_fn,
    n_trajectories=1,
    n_observations=None,
    max_timestep=None,
    reset=True,
    len_history_for_policy=32,
    boundary=32,
    state=None,
    temperature=1.0,
    rng=None,
    abort_fn=None,
    raw_trajectory=False,
):
  """Collect trajectories with the given policy net and behaviour.

  Args:
    env: An `EnvProblem` instance (a batched env interface).
    policy_fn: Callable
      (observations(B, T+1), actions(B, T+1, C)) -> log-probs(B, T+1, C, A).
    n_trajectories: int, number of trajectories.
    n_observations: int, number of non-terminal observations. NOTE: Exactly
      one of `n_trajectories` and `n_observations` should be None.
    max_timestep: int or None, the index of the maximum time-step at which we
      return the trajectory, None for ending a trajectory only when env
      returns done.
    reset: bool, true if we want to reset the envs. The envs are also reset
      if `max_timestep` is None or <= 0.
    len_history_for_policy: int or None, the maximum history to keep for
      applying the policy on. If None, use the full history.
    boundary: int, pad the sequences to the multiples of this number.
    state: state for `policy_fn`.
    temperature: (float) temperature to sample action from policy_fn.
    rng: jax rng, splittable.
    abort_fn: callable or None. If not None, it is called at every env step;
      if it returns True, trajectory collection is aborted, the env is reset
      and None is returned.
    raw_trajectory: bool, if True a list of trajectory.Trajectory objects is
      returned, otherwise a list of numpy representations of
      `trajectory.Trajectory` is returned.

  Returns:
    A tuple (trajectories, num_done, timing_info, state), where:
      trajectories: list of (observation, action, reward) tuples, where each
        element `i` is a tuple of numpy arrays with shapes as follows:
        observation[i] = (B, T_i + 1), action[i] = (B, T_i),
        reward[i] = (B, T_i).
      num_done: number of trajectories that ended with done (the others may
        just have been truncated).
      timing_info: dict of timing measurements, in milliseconds.
      state: the (possibly updated) state for `policy_fn`.
  """
  assert isinstance(env, env_problem.EnvProblem)

  # We need to reset all environments, if we're coming here the first time.
  if reset or max_timestep is None or max_timestep <= 0:
    env.reset()
  else:
    # Clear completed trajectories held internally.
    env.trajectories.clear_completed_trajectories()

  num_done_trajectories = 0

  # The stopping criterion, returns True if we should stop.
  def should_stop():
    if n_trajectories is not None:
      assert n_observations is None
      return env.trajectories.num_completed_trajectories >= n_trajectories
    assert n_observations is not None
    # The number of non-terminal observations is what we want.
    return (env.trajectories.num_completed_time_steps -
            env.trajectories.num_completed_trajectories) >= n_observations

  policy_application_total_time = 0
  env_actions_total_time = 0
  bare_env_run_time = 0
  while not should_stop():
    # Check if we should abort and return nothing.
    if abort_fn and abort_fn():
      # We should also reset the environment, since it will have some
      # trajectories (complete and incomplete) that we want to discard.
      env.reset()
      return None, 0, {}, state

    # Get all the observations for all the active trajectories.
    # Shape is (B, T+1) + OBS
    # Bucket on whatever length is needed.
    padded_observations, lengths = env.trajectories.observations_np(
        boundary=boundary, len_history_for_policy=len_history_for_policy)

    B = padded_observations.shape[0]  # pylint: disable=invalid-name

    assert B == env.batch_size
    assert (B,) == lengths.shape

    t1 = time.time()
    log_probs, value_preds, state, rng = policy_fn(
        padded_observations, lengths, state=state, rng=rng)
    policy_application_total_time += (time.time() - t1)

    assert B == log_probs.shape[0]

    actions = tl.gumbel_sample(log_probs, temperature)
    if (isinstance(env.action_space, gym.spaces.Discrete) and
        (actions.shape[1] == 1)):
      actions = onp.squeeze(actions, axis=1)

    # Step through the env.
    t1 = time.time()
    _, _, dones, env_infos = env.step(
        actions,
        infos={
            'log_prob_actions': log_probs,
            'value_predictions': value_preds,
        })
    env_actions_total_time += (time.time() - t1)
    bare_env_run_time += sum(
        info['__bare_env_run_time__'] for info in env_infos)

    # Count the number of done trajectories, the others could just have been
    # truncated.
    num_done_trajectories += onp.sum(dones)

    # Get the indices where we are done ...
    done_idxs = env_problem_utils.done_indices(dones)

    # ... and reset those.
    t1 = time.time()
    if done_idxs.size:
      env.reset(indices=done_idxs)
    env_actions_total_time += (time.time() - t1)

    if max_timestep is None or max_timestep < 1:
      continue

    # Are there any trajectories that have exceeded the time-limit we want?
    lengths = env.trajectories.trajectory_lengths
    exceeded_time_limit_idxs = env_problem_utils.done_indices(
        lengths > max_timestep)

    # If so, reset these as well.
    t1 = time.time()
    if exceeded_time_limit_idxs.size:
      # This just cuts the trajectory, doesn't reset the env, so it continues
      # from where it left off.
      env.truncate(indices=exceeded_time_limit_idxs, num_to_keep=1)
    env_actions_total_time += (time.time() - t1)

  # We have the trajectories we need, return a list of triples:
  # (observations, actions, rewards)
  completed_trajectories = (
      env_problem_utils.get_completed_trajectories_from_env(
          env, env.trajectories.num_completed_trajectories,
          raw_trajectory=raw_trajectory))

  timing_info = {
      'trajectory_collection/policy_application':
          policy_application_total_time,
      'trajectory_collection/env_actions': env_actions_total_time,
      'trajectory_collection/env_actions/bare_env': bare_env_run_time,
  }
  timing_info = {k: round(1000 * v, 2) for k, v in timing_info.items()}

  return completed_trajectories, num_done_trajectories, timing_info, state
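# A minimal sketch of driving `collect_trajectories` (illustrative; `env` is
# assumed to be an EnvProblem and `policy_fn` a callable with the contract
# used above: (padded_observations, lengths, state=..., rng=...) ->
# (log_probs, value_preds, state, rng)).
def collect_n_trajectories(env, policy_fn, rng, n=2):
  trajs, n_done, timing_info, state = collect_trajectories(
      env,
      policy_fn,
      n_trajectories=n,
      max_timestep=100,  # Truncate anything longer than 100 steps.
      len_history_for_policy=None,  # Apply the policy on the full history.
      rng=rng)
  # `n_done` counts the trajectories that actually ended with done; the
  # others may just have been truncated at `max_timestep`.
  return trajs, n_done, timing_info, state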