def test_new_time_step(sample_data):
    """Check that TimeStep keeps its constructor arguments by identity.

    The test runs in two phases:

    1. Build a TimeStep straight from the fixture and verify that every
       field is the very object that was passed in.
    2. Swap in an env_spec made of Box spaces bounded to [-1, 10], then
       draw the observation, next observation and action from Box spaces
       of the same shapes bounded to [-1000, 1000]. Construction is
       expected to succeed and the sampled arrays must be stored by
       identity.

    Args:
        sample_data (dict): Keyword arguments for TimeStep, supplied by
            the `sample_data` fixture.

    """
    s = TimeStep(**sample_data)
    assert s.env_spec is sample_data['env_spec']
    assert s.observation is sample_data['observation']
    # Checked in the second phase as well; it was missing from this one.
    assert s.next_observation is sample_data['next_observation']
    assert s.action is sample_data['action']
    assert s.reward is sample_data['reward']
    assert s.terminal is sample_data['terminal']
    assert s.env_info is sample_data['env_info']
    assert s.agent_info is sample_data['agent_info']
    del s

    obs_space = akro.Box(low=-1, high=10, shape=(4, 3, 2), dtype=np.float32)
    act_space = akro.Box(low=-1, high=10, shape=(4, 2), dtype=np.float32)
    env_spec = EnvSpec(obs_space, act_space)
    sample_data['env_spec'] = env_spec

    # Same shapes as the env_spec above, but deliberately wider bounds.
    # NOTE(review): presumably only the shape is validated for Box spaces
    # -- confirm against TimeStep.
    obs_space = akro.Box(low=-1000,
                         high=1000,
                         shape=(4, 3, 2),
                         dtype=np.float32)
    act_space = akro.Box(low=-1000, high=1000, shape=(4, 2), dtype=np.float32)
    sample_data['observation'] = obs_space.sample()
    sample_data['next_observation'] = obs_space.sample()
    sample_data['action'] = act_space.sample()

    s = TimeStep(**sample_data)
    assert s.observation is sample_data['observation']
    assert s.next_observation is sample_data['next_observation']
    assert s.action is sample_data['action']
def test_act_env_spec_mismatch_time_step(sample_data):
    """Check that TimeStep rejects an action that does not fit the env_spec.

    Two cases are exercised, each after dropping the last entry along the
    first axis of the action:

    * with the fixture's env_spec, the error must match
      'action must conform to action_space';
    * with an env_spec whose action space is `akro.Discrete(5)`, the error
      must match 'action should have the same dimensionality'.

    Args:
        sample_data (dict): Keyword arguments for TimeStep, supplied by
            the `sample_data` fixture.

    """
    # Case 1: truncated action against the fixture's own env_spec.
    sample_data['action'] = sample_data['action'][:-1]
    with pytest.raises(ValueError,
                       match='action must conform to action_space'):
        TimeStep(**sample_data)

    # Case 2: truncate once more and validate against a Discrete space.
    sample_data['env_spec'] = EnvSpec(
        akro.Box(low=1, high=10, shape=(4, 3, 2), dtype=np.float32),
        akro.Discrete(5))
    sample_data['action'] = sample_data['action'][:-1]
    with pytest.raises(
            ValueError,
            match='action should have the same dimensionality'):
        TimeStep(**sample_data)
def test_next_obs_env_spec_mismatch_time_step(sample_data):
    """Check that TimeStep rejects a next_observation with the wrong shape.

    The next observation is cut down to a single entry along its third
    axis, after which construction must raise a ValueError matching
    'next_observation must conform to observation_space'.

    Args:
        sample_data (dict): Keyword arguments for TimeStep, supplied by
            the `sample_data` fixture.

    """
    clipped_next_obs = sample_data['next_observation'][:, :, :1]
    sample_data['next_observation'] = clipped_next_obs
    with pytest.raises(
            ValueError,
            match='next_observation must conform to observation_space'):
        TimeStep(**sample_data)
def test_obs_env_spec_mismatch_time_step(sample_data):
    """Check that TimeStep rejects an observation that does not fit.

    Two cases are exercised, each after keeping only the first entry along
    the third axis of the observation:

    * with the fixture's env_spec, the error must match
      'observation must conform to observation_space';
    * with an env_spec whose observation space is a (4, 5, 2) `akro.Box`
      and whose action space is `gym.spaces.MultiDiscrete([2, 5])`, the
      error must match 'observation should have the same dimensionality'.

    Args:
        sample_data (dict): Keyword arguments for TimeStep, supplied by
            the `sample_data` fixture.

    """
    # Case 1: clipped observation against the fixture's own env_spec.
    sample_data['observation'] = sample_data['observation'][:, :, :1]
    with pytest.raises(
            ValueError,
            match='observation must conform to observation_space'):
        TimeStep(**sample_data)

    # Case 2: clip again and validate against a differently shaped Box.
    sample_data['env_spec'] = EnvSpec(
        akro.Box(low=1, high=10, shape=(4, 5, 2), dtype=np.float32),
        gym.spaces.MultiDiscrete([2, 5]))
    sample_data['observation'] = sample_data['observation'][:, :, :1]
    with pytest.raises(
            ValueError,
            match='observation should have the same dimensionality'):
        TimeStep(**sample_data)
def test_new_time_step(sample_data):
    """Check that TimeStep exposes the exact objects it was built from.

    Args:
        sample_data (dict): Keyword arguments for TimeStep, supplied by
            the `sample_data` fixture.

    """
    step = TimeStep(**sample_data)
    # Each of these attributes must be the same object (not merely an
    # equal one) as the corresponding constructor argument.
    checked_fields = ('env_spec', 'observation', 'action', 'reward',
                      'terminal', 'env_info', 'agent_info')
    for field in checked_fields:
        assert getattr(step, field) is sample_data[field]
def rollout_generator(env,
                      agent,
                      *,
                      max_path_length=np.inf,
                      animated=False,
                      speedup=1,
                      deterministic=False):
    """Sample a single rollout of the agent in the environment.

    Both the agent and the environment are reset when iteration starts.
    Iteration ends after `max_path_length` steps or right after the step
    for which the environment reports termination, whichever comes first.

    Args:
        env(gym.Env): Environment to perform actions in.
        agent(Policy): Agent used to select actions.
        max_path_length(int): If the rollout reaches this many timesteps,
            it is terminated.
        animated(bool): If true, render the environment after each step.
        speedup(float): Factor by which to decrease the wait time between
            rendered steps. Only relevant, if animated == true.
        deterministic (bool): If true, use the mean action returned by the
            stochastic policy instead of sampling from the returned action
            distribution.

    Yields:
        MetaRL.TimeStep: One step per environment transition. The reward
            is always converted to a Python float.

    """
    # Base delay, in seconds, between rendered frames.
    frame_interval = 0.05

    agent.reset()
    path_length = 0
    last_observation = env.reset()
    if animated:
        env.render()
    while path_length < max_path_length:
        a, agent_info = agent.get_action(last_observation)
        if deterministic:
            a = agent_info['mean']
        next_o, r, d, env_info = env.step(a)
        # Always hand TimeStep a float reward. Previously only rewards
        # equal to 0 were converted, so e.g. an integer reward of 1 was
        # passed through unconverted.
        yield TimeStep(env_spec=env.spec,
                       observation=last_observation,
                       action=a,
                       reward=float(r),
                       next_observation=next_o,
                       terminal=d,
                       env_info=env_info,
                       agent_info=agent_info)
        last_observation = next_o
        path_length += 1
        if d:
            # The terminal state is not rendered.
            break
        if animated:
            env.render()
            time.sleep(frame_interval / speedup)
def step_rollout(self):
    """Take a single time-step in the current rollout.

    While the path is shorter than `max_path_length`, the agent picks an
    action for the previous observation, the environment is stepped, and
    the transition is appended to the worker's buffers. When context
    accumulation is enabled, the transition is also passed to the agent
    through `update_context`.

    Returns:
        bool: True iff the path is done, either due to the environment
            indicating termination or due to reaching `max_path_length`.

    """
    path_open = self._path_length < self._max_path_length
    if path_open:
        action, agent_info = self.agent.get_action(self._prev_obs)
        if self._deterministic:
            action = agent_info['mean']
        next_obs, reward, done, env_info = self.env.step(action)

        # Record the transition in the per-path buffers.
        self._observations.append(self._prev_obs)
        self._rewards.append(reward)
        self._actions.append(action)
        for key, value in agent_info.items():
            self._agent_infos[key].append(value)
        for key, value in env_info.items():
            self._env_infos[key].append(value)
        self._path_length += 1
        self._terminals.append(done)

        if self._accum_context:
            # NOTE(review): the environment itself, not `self.env.spec`,
            # is passed as env_spec here -- confirm this is intended.
            step = TimeStep(env_spec=self.env,
                            observation=self._prev_obs,
                            next_observation=next_obs,
                            action=action,
                            reward=float(reward),
                            terminal=done,
                            env_info=env_info,
                            agent_info=agent_info)
            self.agent.update_context(step)

        if not done:
            self._prev_obs = next_obs
            return False

    # Path finished: either the step limit was hit or the environment
    # terminated. On termination `_prev_obs` has not been advanced, so
    # the observation stored here is the one the last action was taken
    # from.
    self._lengths.append(self._path_length)
    self._last_observations.append(self._prev_obs)
    return True
def test_env_info_dtype_mismatch_time_step(sample_data):
    """Check that TimeStep rejects an env_info given as a list.

    Construction must raise a ValueError matching 'env_info must be type'.

    Args:
        sample_data (dict): Keyword arguments for TimeStep, supplied by
            the `sample_data` fixture.

    """
    sample_data['env_info'] = []
    with pytest.raises(ValueError, match='env_info must be type'):
        TimeStep(**sample_data)
def test_terminal_dtype_mismatch_time_step(sample_data):
    """Check that TimeStep rejects a terminal flag given as a list.

    Construction must raise a ValueError matching
    'terminal must be dtype bool'.

    Args:
        sample_data (dict): Keyword arguments for TimeStep, supplied by
            the `sample_data` fixture.

    """
    sample_data['terminal'] = []
    with pytest.raises(ValueError, match='terminal must be dtype bool'):
        TimeStep(**sample_data)
def test_reward_dtype_mismatch_time_step(sample_data):
    """Check that TimeStep rejects a reward given as a list.

    Construction must raise a ValueError matching 'reward must be type'.

    Args:
        sample_data (dict): Keyword arguments for TimeStep, supplied by
            the `sample_data` fixture.

    """
    sample_data['reward'] = []
    with pytest.raises(ValueError, match='reward must be type'):
        TimeStep(**sample_data)
def test_act_env_spec_mismatch_time_step(sample_data):
    """Check that TimeStep rejects an action that does not fit the env_spec.

    The last entry along the first axis of the action is dropped, after
    which construction must raise a ValueError matching
    'action must conform to action_space'.

    Args:
        sample_data (dict): Keyword arguments for TimeStep, supplied by
            the `sample_data` fixture.

    """
    shortened_action = sample_data['action'][:-1]
    sample_data['action'] = shortened_action
    with pytest.raises(ValueError,
                       match='action must conform to action_space'):
        TimeStep(**sample_data)