Пример #1
0
    def test_batched(self):
        """Flatten several torch-ified samples, stack them into a batch, and
        check that unflattening the batch recovers every sample in order."""
        drawn = [self.space.sample() for _ in range(10)]
        flat_rows = []
        for point in drawn:
            torch_sample = su.torch_point(self.space, point)
            flat_rows.append(su.flatten(self.space, torch_sample))
        batch = torch.stack(flat_rows, dim=0)
        recovered = su.unflatten(self.space, batch)
        for idx, point in enumerate(drawn):
            # Each torch-ified original must match the recovered entry at its
            # batch index.
            assert self.same(su.torch_point(self.space, point), recovered, idx)

        # Re-flattening the recovered batch must reproduce the stacked tensor.
        assert self.same(su.flatten(self.space, recovered), batch)
Пример #2
0
    def test_flatten(self):
        """Compare flat dimensions and bounds between this project's space
        utilities (Discrete -> single value) and gym's (Discrete -> one-hot),
        and check a flatten/unflatten round trip."""
        assert su.flatdim(self.space) == 25
        assert gyms.flatdim(self.space) == 35

        original = su.torch_point(self.space, self.space.sample())
        roundtrip = su.unflatten(self.space, su.flatten(self.space, original))
        assert self.same(original, roundtrip)

        # Silence `UserWarning: WARN: Box bound precision lowered by casting
        # to float32` raised while building the flattened spaces.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            flat_space = su.flatten_space(self.space)
            assert flat_space.shape == (25, )
            # Discrete(11) contributes the largest upper bound here.
            assert flat_space.high.max() == 11.0
            assert flat_space.low.min() == -10.0

            gym_flat_space = gyms.flatten_space(self.space)
            assert gym_flat_space.shape == (35, )
            # With one-hot Discretes, the Box(-10, 10, (3, 4)) dominates.
            assert gym_flat_space.high.max() == 10.0
            assert gym_flat_space.low.min() == -10.0
Пример #3
0
    def enforce(
        self,
        sample: Any,
        action_space: gym.spaces.Space,
        teacher: OrderedDict,
        teacher_force_info: Optional[Dict[str, Any]],
        action_name: Optional[str] = None,
    ):
        """Apply teacher forcing: replace a subset of the model's sampled
        actions with the expert's actions.

        # Parameters
        sample : actions sampled by the model, in `action_space` layout.
        action_space : space describing how to flatten/unflatten actions.
        teacher : mapping holding the expert's action policy and a success
            flag (keyed by `Expert.ACTION_POLICY_LABEL` /
            `Expert.EXPERT_SUCCESS_LABEL`).
        teacher_force_info : if not None, the realized teacher-forcing ratio
            is recorded into it (for logging).
        action_name : optional suffix used in the logged metric name.

        # Returns
        Actions in `action_space` layout where, per the teacher-forcing mask,
        entries are taken from the expert instead of the model.
        """
        actions = su.flatten(action_space, sample)

        assert (
            len(actions.shape) == 3
        ), f"Got flattened actions with shape {actions.shape} (it should be [1 x `samplers` x `flatdims`])"

        if self.num_active_samplers is not None:
            assert actions.shape[1] == self.num_active_samplers

        expert_actions = su.flatten(action_space,
                                    teacher[Expert.ACTION_POLICY_LABEL])
        assert (
            expert_actions.shape == actions.shape
        ), f"expert actions shape {expert_actions.shape} doesn't match the model's {actions.shape}"

        # expert_success is 0 if the expert action could not be computed and otherwise equals 1.
        expert_action_exists_mask = teacher[Expert.EXPERT_SUCCESS_LABEL]

        if not self.always_enforce:
            # Sample a Bernoulli mask with the current (annealed) teacher
            # forcing probability; forcing only applies where the expert
            # action actually exists.
            teacher_forcing_mask = (torch.distributions.bernoulli.Bernoulli(
                torch.tensor(self.teacher_forcing(self.approx_steps))).sample(
                    expert_action_exists_mask.shape).long().to(
                        actions.device)) * expert_action_exists_mask
        else:
            teacher_forcing_mask = expert_action_exists_mask

        if teacher_force_info is not None:
            teacher_force_info["teacher_ratio/sampled{}".format(
                f"_{action_name}" if action_name is not None else "")] = (
                    teacher_forcing_mask.float().mean().item())

        # Broadcast the mask over the trailing action dims.
        extended_shape = teacher_forcing_mask.shape + (1, ) * (
            len(actions.shape) - len(teacher_forcing_mask.shape))

        # `torch.where` requires a bool condition; `.byte()` conditions are
        # deprecated since torch 1.2 and rejected by recent releases.
        actions = torch.where(teacher_forcing_mask.bool().view(extended_shape),
                              expert_actions, actions)

        return su.unflatten(action_space, actions)
Пример #4
0
 def flatten_output(self, unflattened):
     """Torch-ify `unflattened` over the observation space, flatten it, and
     return the result as a CPU numpy array."""
     torch_obs = su.torch_point(self.observation_space, unflattened)
     flat = su.flatten(self.observation_space, torch_obs)
     return flat.cpu().numpy()
Пример #5
0
    def test_tolist(self):
        """`su.action_list` should turn flattened actions back into nested
        Python lists/ints for MultiDiscrete, Tuple, and Dict spaces."""
        # Plain MultiDiscrete.
        space = gyms.MultiDiscrete([3, 3])
        sampled = su.torch_point(space, space.sample())  # single sampler
        sampled = sampled.unsqueeze(0).unsqueeze(0)  # add [step, sampler]
        alist = su.action_list(space, su.flatten(space, sampled))
        assert len(alist) == 1
        assert len(alist[0]) == 2

        # Tuple combining MultiDiscrete and Discrete.
        space = gyms.Tuple([gyms.MultiDiscrete([3, 3]), gyms.Discrete(2)])
        sampled = su.torch_point(space, space.sample())  # single sampler
        sampled = (
            sampled[0].unsqueeze(0).unsqueeze(0),
            torch.tensor(sampled[1]).unsqueeze(0).unsqueeze(0),
        )  # add [step, sampler]
        alist = su.action_list(space, su.flatten(space, sampled))
        assert len(alist) == 1
        assert len(alist[0][0]) == 2
        assert isinstance(alist[0][1], int)

        # Dict holding the same two subspaces.
        space = gyms.Dict({
            "tuple": gyms.MultiDiscrete([3, 3]),
            "scalar": gyms.Discrete(2)
        })
        sampled = su.torch_point(space, space.sample())  # single sampler
        sampled = OrderedDict([
            ("tuple", sampled["tuple"].unsqueeze(0).unsqueeze(0)),
            ("scalar",
             torch.tensor(sampled["scalar"]).unsqueeze(0).unsqueeze(0)),
        ])
        alist = su.action_list(space, su.flatten(space, sampled))
        assert len(alist) == 1
        assert len(alist[0]["tuple"]) == 2
        assert isinstance(alist[0]["scalar"], int)
Пример #6
0
    def get_observation(self, env: EnvType, task: SubTaskType, *args: Any,
                        **kwargs: Any) -> Any:
        """Return the expert's action (plus its success flag) for `task`,
        flattened over the unflattened observation space, as a numpy array.

        Once the task is done there is no (current) terminal-state expert
        action to query, so a zeroed observation is returned instead.
        """
        if task.is_done():
            return self._zeroed_observation

        action, expert_was_successful = task.query_expert(**self.expert_args)

        if not isinstance(action, int):
            # Assume we receive a gym-flattened numpy action and undo it.
            unflattened_action = gyms.unflatten(self.action_space, action)
        else:
            assert isinstance(self.action_space, gym.spaces.Discrete)
            unflattened_action = action

        torch_pair = su.torch_point(
            self.unflattened_observation_space,
            (unflattened_action, expert_was_successful),
        )
        return su.flatten(self.unflattened_observation_space,
                          torch_pair).cpu().numpy()