def test_batched(self):
    """Flattened samples stacked into a batch must unflatten back to the originals."""
    originals = [self.space.sample() for _ in range(10)]
    as_torch = [su.torch_point(self.space, s) for s in originals]
    batch = torch.stack([su.flatten(self.space, t) for t in as_torch], dim=0)
    recovered = su.unflatten(self.space, batch)
    for idx, torch_sample in enumerate(as_torch):
        # Each entry of the batched unflattened result must match the
        # corresponding torch-ified input sample.
        assert self.same(torch_sample, recovered, idx)
    # Re-flattening the batched result must reproduce the stacked batch.
    assert self.same(su.flatten(self.space, recovered), batch)
def test_flatten(self):
    """Check flat dimensions, round-tripping, and flattened-space bounds.

    AllenAct's `su` flattens a Discrete sub-space to a single value, whereas
    gym one-hot encodes it — hence the differing flat dimensions (25 vs 35).
    """
    assert su.flatdim(self.space) == 25
    assert gyms.flatdim(self.space) == 35

    sample = su.torch_point(self.space, self.space.sample())
    round_tripped = su.unflatten(self.space, su.flatten(self.space, sample))
    assert self.same(sample, round_tripped)

    # Silence `UserWarning: WARN: Box bound precision lowered by casting to float32`
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        allenact_flat = su.flatten_space(self.space)
        assert allenact_flat.shape == (25,)
        # Upper bound comes from the Discrete(11) sub-space.
        assert allenact_flat.high.max() == 11.0
        assert allenact_flat.low.min() == -10.0

        gym_flat = gyms.flatten_space(self.space)
        assert gym_flat.shape == (35,)
        # Upper bound comes from the Box(-10, 10, (3, 4)) sub-space.
        assert gym_flat.high.max() == 10.0
        assert gym_flat.low.min() == -10.0
def test_conversion(self):
    """numpy_point(torch_point(x)) must round-trip a gym sample unchanged."""
    original = self.space.sample()
    torchified = su.torch_point(self.space, original)
    recovered = su.numpy_point(self.space, torchified)
    assert self.same(recovered, original)
def flatten_output(self, unflattened):
    """Flatten an observation against `self.observation_space`, returned as numpy."""
    torchified = su.torch_point(self.observation_space, unflattened)
    flat = su.flatten(self.observation_space, torchified)
    return flat.cpu().numpy()
def _zeroed_observation(self) -> Union[OrderedDict, Tuple]:
    """Build an all-zero observation shaped like `self.observation_space` (numpy)."""
    # The AllenAct-style flattened space makes an all-zero vector trivial to build.
    flat_space = su.flatten_space(self.observation_space)
    zeros = np.zeros_like(flat_space.sample())
    # Go through a torch point so `Discrete` sub-spaces unflatten correctly.
    zeroed = su.unflatten(
        self.observation_space, su.torch_point(flat_space, zeros)
    )
    return su.numpy_point(self.observation_space, zeroed)
def test_tolist(self):
    """`su.action_list` must recover per-sampler actions for nested spaces."""
    # --- MultiDiscrete ---
    md_space = gyms.MultiDiscrete([3, 3])
    md_action = su.torch_point(md_space, md_space.sample())  # single sampler
    md_action = md_action.unsqueeze(0).unsqueeze(0)  # add [step, sampler] dims
    listed = su.action_list(md_space, su.flatten(md_space, md_action))
    assert len(listed) == 1
    assert len(listed[0]) == 2

    # --- Tuple(MultiDiscrete, Discrete) ---
    tup_space = gyms.Tuple([gyms.MultiDiscrete([3, 3]), gyms.Discrete(2)])
    raw = su.torch_point(tup_space, tup_space.sample())  # single sampler
    tup_action = (
        raw[0].unsqueeze(0).unsqueeze(0),
        torch.tensor(raw[1]).unsqueeze(0).unsqueeze(0),
    )  # add [step, sampler] dims
    listed = su.action_list(tup_space, su.flatten(tup_space, tup_action))
    assert len(listed) == 1
    assert len(listed[0][0]) == 2
    assert isinstance(listed[0][1], int)

    # --- Dict(tuple=MultiDiscrete, scalar=Discrete) ---
    dict_space = gyms.Dict(
        {"tuple": gyms.MultiDiscrete([3, 3]), "scalar": gyms.Discrete(2)}
    )
    raw = su.torch_point(dict_space, dict_space.sample())  # single sampler
    dict_action = OrderedDict(
        [
            ("tuple", raw["tuple"].unsqueeze(0).unsqueeze(0)),
            ("scalar", torch.tensor(raw["scalar"]).unsqueeze(0).unsqueeze(0)),
        ]
    )  # add [step, sampler] dims
    listed = su.action_list(dict_space, su.flatten(dict_space, dict_action))
    assert len(listed) == 1
    assert len(listed[0]["tuple"]) == 2
    assert isinstance(listed[0]["scalar"], int)
def get_observation(
    self, env: EnvType, task: SubTaskType, *args: Any, **kwargs: Any
) -> Any:
    """Return the expert action plus success flag, flattened to a numpy vector."""
    # A completed task has no (and may have an undefined) expert action from
    # its terminal state, so return the zeroed placeholder observation.
    # NOTE(review): `_zeroed_observation` is returned without a call —
    # presumably a cached/lazy property; confirm the decorator on its definition.
    if task.is_done():
        return self._zeroed_observation

    action, expert_was_successful = task.query_expert(**self.expert_args)

    if isinstance(action, int):
        # Integer actions are only valid for Discrete action spaces.
        assert isinstance(self.action_space, gym.spaces.Discrete)
        expert_action = action
    else:
        # Otherwise assume a gym-flattened numpy action and undo that flattening.
        expert_action = gyms.unflatten(self.action_space, action)

    torch_pair = su.torch_point(
        self.unflattened_observation_space,
        (expert_action, expert_was_successful),
    )
    flat = su.flatten(self.unflattened_observation_space, torch_pair)
    return flat.cpu().numpy()