Beispiel #1
0
    def test_batched(self):
        """Flatten ten samples, stack them, and check the batched unflatten.

        Each reference sample (converted with `su.torch_point`) is compared
        via `self.same` against the unflattened batch at its batch index, and
        re-flattening the unflattened batch is compared against the stack.
        """
        num_samples = 10

        # Draw every sample up front, before any conversion takes place.
        references = []
        for _ in range(num_samples):
            references.append(self.space.sample())

        flat_rows = []
        for reference in references:
            as_torch = su.torch_point(self.space, reference)
            flat_rows.append(su.flatten(self.space, as_torch))

        batch = torch.stack(flat_rows, dim=0)
        restored = su.unflatten(self.space, batch)

        # Compare each torch-ified sample to the corresponding entry of the
        # unflattened stack (selected through the batch index).
        for row, reference in enumerate(references):
            expected = su.torch_point(self.space, reference)
            assert self.same(expected, restored, row)

        # Flattening the restored batch must give back the stacked tensor.
        reflattened = su.flatten(self.space, restored)
        assert self.same(reflattened, batch)
Beispiel #2
0
    def test_flatten(self):
        """Check flat sizes, the flatten/unflatten round trip, and flat bounds.

        The `su` helpers and gym's own helpers are both exercised; they are
        expected to disagree on the flat size because `su` flattens Discrete
        to a single value while gym flattens it to one-hot.
        """
        # We flatten Discrete to 1 value
        su_size = su.flatdim(self.space)
        assert su_size == 25
        # gym flattens Discrete to one-hot
        gym_size = gyms.flatdim(self.space)
        assert gym_size == 35

        # Round trip: torch point -> flat -> unflattened must match the point.
        point = su.torch_point(self.space, self.space.sample())
        round_trip = su.unflatten(self.space, su.flatten(self.space, point))
        assert self.same(point, round_trip)

        # suppress `UserWarning: WARN: Box bound precision lowered by casting to float32`
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # (flatten function, expected flat size, expected largest upper bound)
            expectations = (
                # The maximum comes from Discrete(11)
                (su.flatten_space, 25, 11.0),
                # The maximum comes from Box(-10, 10, (3, 4))
                (gyms.flatten_space, 35, 10.0),
            )
            for flatten_space, flat_size, upper in expectations:
                flat = flatten_space(self.space)
                assert flat.shape == (flat_size, )
                assert flat.high.max() == upper
                assert flat.low.min() == -10.0
Beispiel #3
0
    def test_conversion(self):
        """Pass a sample through `su.torch_point` and back through `su.numpy_point`.

        The result must compare as the same as the original sample under
        `self.same`.
        """
        original = self.space.sample()
        round_trip = su.numpy_point(
            self.space,
            su.torch_point(self.space, original),
        )
        assert self.same(round_trip, original)
Beispiel #4
0
 def flatten_output(self, unflattened):
     """Flatten `unflattened` against `self.observation_space`.

     The value is first converted with `su.torch_point`, then flattened with
     `su.flatten`; the result is moved to CPU and returned as a numpy array.
     """
     as_torch = su.torch_point(self.observation_space, unflattened)
     flat = su.flatten(self.observation_space, as_torch)
     on_cpu = flat.cpu()
     return on_cpu.numpy()
Beispiel #5
0
 def _zeroed_observation(self) -> Union[OrderedDict, Tuple]:
     """Build a zero-filled observation structured like `self.observation_space`.

     A zero array is created in the flattened space, unflattened back into
     the observation space, and converted with `su.numpy_point`.
     """
     # AllenAct-style flattened space (to easily generate an all-zeroes action as an array)
     flat_space = su.flatten_space(self.observation_space)

     # The sample only serves as a template for `np.zeros_like`; its values
     # are discarded.
     template = flat_space.sample()
     zeros = np.zeros_like(template)

     # torch point to correctly unflatten `Discrete` for zeroed output
     flat_zeroed = su.torch_point(flat_space, zeros)

     # unflatten zeroed output and convert to numpy
     unflattened = su.unflatten(self.observation_space, flat_zeroed)
     return su.numpy_point(self.observation_space, unflattened)
Beispiel #6
0
    def test_tolist(self):
        """Check the structure returned by `su.action_list` for three spaces.

        Covers a MultiDiscrete space, a Tuple of (MultiDiscrete, Discrete) and
        a Dict with a MultiDiscrete and a Discrete entry. In every case a
        single sample is given leading [step, sampler] dimensions, flattened,
        and handed to `su.action_list`.
        """

        def with_step_and_sampler(tensor):
            # add [step, sampler]
            return tensor.unsqueeze(0).unsqueeze(0)

        # --- MultiDiscrete -------------------------------------------------
        multi_space = gyms.MultiDiscrete([3, 3])
        point = su.torch_point(multi_space, multi_space.sample())  # single sampler
        batched = with_step_and_sampler(point)
        listed = su.action_list(multi_space, su.flatten(multi_space, batched))
        assert len(listed) == 1
        assert len(listed[0]) == 2

        # --- Tuple(MultiDiscrete, Discrete) --------------------------------
        tuple_space = gyms.Tuple([gyms.MultiDiscrete([3, 3]), gyms.Discrete(2)])
        point = su.torch_point(tuple_space, tuple_space.sample())  # single sampler
        batched = (
            with_step_and_sampler(point[0]),
            with_step_and_sampler(torch.tensor(point[1])),
        )
        listed = su.action_list(tuple_space, su.flatten(tuple_space, batched))
        assert len(listed) == 1
        assert len(listed[0][0]) == 2
        assert isinstance(listed[0][1], int)

        # --- Dict(tuple=MultiDiscrete, scalar=Discrete) --------------------
        dict_space = gyms.Dict({
            "tuple": gyms.MultiDiscrete([3, 3]),
            "scalar": gyms.Discrete(2)
        })
        point = su.torch_point(dict_space, dict_space.sample())  # single sampler
        batched = OrderedDict()
        batched["tuple"] = with_step_and_sampler(point["tuple"])
        batched["scalar"] = with_step_and_sampler(torch.tensor(point["scalar"]))
        listed = su.action_list(dict_space, su.flatten(dict_space, batched))
        assert len(listed) == 1
        assert len(listed[0]["tuple"]) == 2
        assert isinstance(listed[0]["scalar"], int)
Beispiel #7
0
    def get_observation(self, env: EnvType, task: SubTaskType, *args: Any,
                        **kwargs: Any) -> Any:
        """Return the flattened (expert action, success flag) observation.

        When the task is done, `self._zeroed_observation` is returned instead
        of querying the expert. Otherwise the expert is queried with
        `self.expert_args`, the action is unflattened if it is not an `int`,
        and the (action, success) pair is converted to torch, flattened
        against `self.unflattened_observation_space`, and returned as a numpy
        array. `env`, `args` and `kwargs` are not used here.
        """
        # If the task is completed, we needn't (perhaps can't) find the expert
        # action from the (current) terminal state.
        # NOTE(review): `_zeroed_observation` is accessed as an attribute, not
        # called; presumably it is a property or cached value -- confirm.
        if task.is_done():
            return self._zeroed_observation

        action, expert_was_successful = task.query_expert(**self.expert_args)

        if not isinstance(action, int):
            # Assume we receive a gym-flattened numpy action
            unflattened_action = gyms.unflatten(self.action_space, action)
        else:
            # A bare int is only meaningful for a Discrete action space.
            assert isinstance(self.action_space, gym.spaces.Discrete)
            unflattened_action = action

        expert_pair = (unflattened_action, expert_was_successful)
        unflattened_torch = su.torch_point(
            self.unflattened_observation_space, expert_pair)
        flattened_torch = su.flatten(
            self.unflattened_observation_space, unflattened_torch)

        on_cpu = flattened_torch.cpu()
        return on_cpu.numpy()