Пример #1
0
def test_new_time_step(sample_data):
    """Check that TimeStep exposes the exact objects it was built from.

    The first half verifies identity for every checked field of the
    fixture. The second half swaps in a new env spec and feeds in values
    sampled from boxes that have the same shapes as the spec's spaces but
    wider bounds, then verifies identity again for the replaced fields.
    """
    step = TimeStep(**sample_data)
    for field in ('env_spec', 'observation', 'action', 'reward',
                  'step_type', 'env_info', 'agent_info'):
        assert getattr(step, field) is sample_data[field]
    del step

    # Spaces that define the spec: bounds [-1, 10].
    spec_obs_space = akro.Box(low=-1,
                              high=10,
                              shape=(4, 3, 2),
                              dtype=np.float32)
    spec_act_space = akro.Box(low=-1, high=10, shape=(4, 2), dtype=np.float32)
    sample_data['env_spec'] = EnvSpec(spec_obs_space, spec_act_space)

    # Spaces used only for sampling: same shapes, bounds [-1000, 1000].
    wide_obs_space = akro.Box(low=-1000,
                              high=1000,
                              shape=(4, 3, 2),
                              dtype=np.float32)
    wide_act_space = akro.Box(low=-1000,
                              high=1000,
                              shape=(4, 2),
                              dtype=np.float32)
    sample_data['observation'] = wide_obs_space.sample()
    sample_data['next_observation'] = wide_obs_space.sample()
    sample_data['action'] = wide_act_space.sample()
    step = TimeStep(**sample_data)

    for field in ('observation', 'next_observation', 'action'):
        assert getattr(step, field) is sample_data[field]
Пример #2
0
def test_from_env_step_time_step(sample_data):
    """Check that TimeStep.from_env_step rebuilds an equal TimeStep.

    A reference TimeStep is built from the full fixture. The fixture is
    then reduced to EnvStep arguments (no agent info, and the next
    observation stored as the observation), and the TimeStep derived from
    that EnvStep must compare equal to the reference.
    """
    expected = TimeStep(**sample_data)

    # Strip the TimeStep-only entries, remembering their values.
    agent_info = sample_data.pop('agent_info')
    next_obs = sample_data.pop('next_observation')
    prev_obs = sample_data['observation']
    sample_data['observation'] = next_obs

    rebuilt = TimeStep.from_env_step(env_step=EnvStep(**sample_data),
                                     last_observation=prev_obs,
                                     agent_info=agent_info)
    assert expected == rebuilt
Пример #3
0
def test_step_type_property_time_step(sample_data):
    """Check the boolean step-type properties of TimeStep.

    For each step type, a TimeStep is built and every property listed
    next to it must be truthy.
    """
    expectations = (
        (StepType.FIRST, ('first', )),
        (StepType.MID, ('mid', )),
        (StepType.TERMINAL, ('terminal', 'last')),
        (StepType.TIMEOUT, ('timeout', 'last')),
    )
    for step_type, truthy_properties in expectations:
        sample_data['step_type'] = step_type
        step = TimeStep(**sample_data)
        for prop in truthy_properties:
            assert getattr(step, prop)
Пример #4
0
    def step_episode(self):
        """Take a single time-step in the current episode.

        The agent is queried exactly once per step. When the worker is
        deterministic, the action sent to the environment is the ``'mean'``
        entry of the agent info rather than the action returned by the
        agent.

        Returns:
            bool: True iff the episode is done, either due to the environment
            indicating termination or due to reaching `max_episode_length`.

        """
        if self._eps_length < self._max_episode_length:
            a, agent_info = self.agent.get_action(self._prev_obs)
            if self._deterministic:
                # Override the returned action; it must not be re-queried
                # afterwards or this override would be lost.
                a = agent_info['mean']
            es = self.env.step(a)
            self._observations.append(self._prev_obs)
            self._env_steps.append(es)
            for k, v in agent_info.items():
                self._agent_infos[k].append(v)
            self._eps_length += 1

            if self._accum_context:
                s = TimeStep.from_env_step(env_step=es,
                                           last_observation=self._prev_obs,
                                           agent_info=agent_info,
                                           episode_info=self._episode_info)
                self.agent.update_context(s)
            if not es.last:
                self._prev_obs = es.observation
                return False
        # Either the step budget is exhausted or the env reported a last
        # step: record the episode length and the final observation.
        self._lengths.append(self._eps_length)
        self._last_observations.append(self._prev_obs)
        return True
Пример #5
0
def test_act_env_spec_mismatch_time_step(sample_data):
    """Check that TimeStep rejects an action that does not fit the spec.

    The action is truncated by one entry and construction must raise
    ValueError. The spec is then replaced by one with a Discrete action
    space, the action is truncated again, and construction must raise
    ValueError with the dimensionality message.
    """
    sample_data['action'] = sample_data['action'][:-1]
    with pytest.raises(ValueError,
                       match='action must conform to action_space'):
        TimeStep(**sample_data)

    sample_data['env_spec'] = EnvSpec(
        akro.Box(low=1, high=10, shape=(4, 3, 2), dtype=np.float32),
        akro.Discrete(5))

    sample_data['action'] = sample_data['action'][:-1]
    with pytest.raises(ValueError,
                       match='action should have the same dimensionality'):
        TimeStep(**sample_data)
Пример #6
0
def test_next_obs_env_spec_mismatch_time_step(sample_data):
    """Check that TimeStep rejects a mis-shaped next_observation.

    The last axis of the next observation is cut down to a single entry
    and construction must raise ValueError.
    """
    truncated = sample_data['next_observation'][:, :, :1]
    sample_data['next_observation'] = truncated
    with pytest.raises(
            ValueError,
            match='next_observation must conform to observation_space'):
        TimeStep(**sample_data)
Пример #7
0
def test_obs_env_spec_mismatch_time_step(sample_data):
    """Check that TimeStep rejects an observation that does not fit the spec.

    The last axis of the observation is cut down to a single entry and
    construction must raise ValueError. The spec is then replaced by one
    with a (4, 5, 2) observation box and a MultiDiscrete action space,
    the observation is sliced again, and construction must raise
    ValueError with the dimensionality message.
    """
    sample_data['observation'] = sample_data['observation'][:, :, :1]
    with pytest.raises(ValueError,
                       match='observation must conform to observation_space'):
        TimeStep(**sample_data)

    sample_data['env_spec'] = EnvSpec(
        akro.Box(low=1, high=10, shape=(4, 5, 2), dtype=np.float32),
        gym.spaces.MultiDiscrete([2, 5]))

    sample_data['observation'] = sample_data['observation'][:, :, :1]
    with pytest.raises(
            ValueError,
            match='observation should have the same dimensionality'):
        TimeStep(**sample_data)
Пример #8
0
def test_new_time_step(sample_data):
    """Check that TimeStep exposes the exact objects it was built from."""
    step = TimeStep(**sample_data)
    for field in ('env_spec', 'observation', 'action', 'reward', 'terminal',
                  'env_info', 'agent_info'):
        assert getattr(step, field) is sample_data[field]
 def test_update_context(self):
     """Check that update_context accumulates one context row per call.

     The same all-ones TimeStep is fed to the module several times; the
     resulting context must equal an all-ones tensor with one row per
     update and ``encoder_input_dim`` columns.
     """
     step = TimeStep(env_spec=self.env_spec,
                     observation=np.ones(self.obs_dim),
                     next_observation=np.ones(self.obs_dim),
                     action=np.ones(self.action_dim),
                     reward=1.0,
                     terminal=False,
                     env_info={},
                     agent_info={})
     n_updates = 10
     for _ in range(n_updates):
         self.module.update_context(step)

     matches = torch.eq(self.module.context,
                        torch.ones(n_updates, self.encoder_input_dim))
     assert torch.all(matches)
Пример #10
0
    def step_rollout(self):
        """Take a single time-step in the current rollout.

        Queries the agent for an action (using the ``'mean'`` entry of the
        agent info when the worker is deterministic), steps the
        environment, and records the observation, reward, action, infos
        and step type. When context accumulation is enabled, the
        transition is also passed to the agent via ``update_context``.

        Returns:
            bool: True iff the path is done, either due to the environment
            indicating termination or due to reaching `max_episode_length`.

        """
        if self._path_length < self._max_episode_length:
            a, agent_info = self.agent.get_action(self._prev_obs)
            if self._deterministic:
                a = agent_info['mean']
            next_o, r, d, env_info = self.env.step(a)
            self._observations.append(self._prev_obs)
            self._rewards.append(r)
            self._actions.append(a)
            for k, v in agent_info.items():
                self._agent_infos[k].append(v)
            for k, v in env_info.items():
                self._env_infos[k].append(v)
            self._path_length += 1
            # Temp solution: step_type should be extracted from the TimeStep
            # returned from env.step(). Until the env returns a TimeStep,
            # derive it from the done flag, and use the same value for the
            # recorded step types and for the TimeStep given to the agent.
            step_type = StepType.TERMINAL if d else StepType.MID
            self._step_types.append(step_type)

            if self._accum_context:
                # NOTE(review): the environment itself is passed as
                # env_spec - confirm it provides what TimeStep expects.
                s = TimeStep(env_spec=self.env,
                             observation=self._prev_obs,
                             next_observation=next_o,
                             action=a,
                             reward=float(r),
                             env_info=env_info,
                             agent_info=agent_info,
                             step_type=step_type)
                self.agent.update_context(s)
            if not d:
                self._prev_obs = next_o
                return False
        # Either the step budget is exhausted or the env reported done:
        # record the path length and the final observation.
        self._lengths.append(self._path_length)
        self._last_observations.append(self._prev_obs)
        return True
Пример #11
0
    def step_rollout(self):
        """Advance the current rollout by one environment step.

        Queries the agent for an action (using the ``'mean'`` entry of the
        agent info when the worker is deterministic), zeroes NaN entries
        of the action in place, steps the environment, and records the
        transition. When context accumulation is enabled, the transition
        is also passed to the agent via ``update_context``.

        Returns:
            bool: True iff the path is done, either due to the environment
            indicating termination or due to reaching `max_path_length`.

        """
        path_finished = True
        if self._path_length < self._max_path_length:
            action, agent_info = self.agent.get_action(self._prev_obs)
            if self._deterministic:
                action = agent_info['mean']
            # A `time.sleep(.02)` was once placed here, tagged as a fix for
            # "mujoco_py.builder.MujocoException: Unknown warning type
            # Time = 0.0000.Check for NaN in simulation."; it is disabled.
            # NaN is the only value unequal to itself, so this mask selects
            # exactly the NaN entries and overwrites them in place.
            action[action != action] = 0
            next_obs, reward, done, env_info = self.env.step(action)

            self._observations.append(self._prev_obs)
            self._rewards.append(reward)
            self._actions.append(action)
            for name, value in agent_info.items():
                self._agent_infos[name].append(value)
            for name, value in env_info.items():
                self._env_infos[name].append(value)
            self._path_length += 1
            self._terminals.append(done)

            if self._accum_context:
                self.agent.update_context(
                    TimeStep(env_spec=self.env,
                             observation=self._prev_obs,
                             next_observation=next_obs,
                             action=action,
                             reward=float(reward),
                             terminal=done,
                             env_info=env_info,
                             agent_info=agent_info))
            if not done:
                self._prev_obs = next_obs
                path_finished = False

        if path_finished:
            self._lengths.append(self._path_length)
            self._last_observations.append(self._prev_obs)
        return path_finished
Пример #12
0
    def step_rollout(self):
        """Advance the current rollout by one environment step.

        Queries the agent for an action (using the ``'mean'`` entry of the
        agent info when the worker is deterministic), steps the
        environment, and records the transition. When context
        accumulation is enabled, the transition is also passed to the
        agent via ``update_context``.

        Returns:
            bool: True iff the path is done, either due to the environment
            indicating termination or due to reaching `max_episode_length`.

        """
        if self._path_length < self._max_episode_length:
            action, agent_info = self.agent.get_action(self._prev_obs)
            if self._deterministic:
                action = agent_info['mean']
            next_obs, reward, done, env_info = self.env.step(action)

            # Record the transition.
            self._observations.append(self._prev_obs)
            self._rewards.append(reward)
            self._actions.append(action)
            for key, value in agent_info.items():
                self._agent_infos[key].append(value)
            for key, value in env_info.items():
                self._env_infos[key].append(value)
            self._path_length += 1
            self._terminals.append(done)

            if self._accum_context:
                transition = TimeStep(env_spec=self.env,
                                      observation=self._prev_obs,
                                      next_observation=next_obs,
                                      action=action,
                                      reward=float(reward),
                                      terminal=done,
                                      env_info=env_info,
                                      agent_info=agent_info)
                self.agent.update_context(transition)

            if not done:
                # The path continues from the new observation.
                self._prev_obs = next_obs
                return False

        # Either the step budget is exhausted or the env reported done.
        self._lengths.append(self._path_length)
        self._last_observations.append(self._prev_obs)
        return True
Пример #13
0
def test_step_type_dtype_mismatch_time_step(sample_data):
    """Check that TimeStep rejects a list given as step_type."""
    sample_data['step_type'] = []
    with pytest.raises(ValueError, match='step_type must be dtype'):
        TimeStep(**sample_data)
Пример #14
0
def test_env_info_dtype_mismatch_time_step(sample_data):
    """Check that TimeStep rejects a list given as env_info."""
    sample_data['env_info'] = []
    with pytest.raises(ValueError, match='env_info must be type'):
        TimeStep(**sample_data)
Пример #15
0
def test_reward_dtype_mismatch_time_step(sample_data):
    """Check that TimeStep rejects a list given as reward."""
    sample_data['reward'] = []
    with pytest.raises(ValueError, match='reward must be type'):
        TimeStep(**sample_data)
Пример #16
0
def test_terminal_dtype_mismatch_time_step(sample_data):
    """Check that TimeStep rejects a list given as terminal."""
    sample_data['terminal'] = []
    with pytest.raises(ValueError, match='terminal must be dtype bool'):
        TimeStep(**sample_data)
Пример #17
0
def test_act_env_spec_mismatch_time_step(sample_data):
    """Check that TimeStep rejects an action that does not fit the spec.

    The action is truncated by one entry and construction must raise
    ValueError.
    """
    truncated_action = sample_data['action'][:-1]
    sample_data['action'] = truncated_action
    with pytest.raises(ValueError,
                       match='action must conform to action_space'):
        TimeStep(**sample_data)