def test_batched(self):
    """Stacked flattened samples must unflatten back to the originals, index by index."""
    raw_samples = [self.space.sample() for _ in range(10)]
    flat = torch.stack(
        [su.flatten(self.space, su.torch_point(self.space, s)) for s in raw_samples],
        dim=0,
    )
    batch = su.unflatten(self.space, flat)
    for idx, original in enumerate(raw_samples):
        # Each torch-ified original must match its slot in the unflattened batch.
        assert self.same(su.torch_point(self.space, original), batch, idx)
    # Re-flattening the whole batch reproduces the stacked tensor.
    assert self.same(su.flatten(self.space, batch), flat)
def test_flatten(self):
    """Check flatdim, the flatten/unflatten round trip, and flattened-space bounds."""
    # su flattens a Discrete to a single value...
    assert su.flatdim(self.space) == 25
    # ...whereas gym one-hot encodes it.
    assert gyms.flatdim(self.space) == 35

    sample = su.torch_point(self.space, self.space.sample())
    roundtrip = su.unflatten(self.space, su.flatten(self.space, sample))
    assert self.same(sample, roundtrip)

    # suppress `UserWarning: WARN: Box bound precision lowered by casting to float32`
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        flat_space = su.flatten_space(self.space)
        assert flat_space.shape == (25,)
        # The maximum comes from Discrete(11)
        assert flat_space.high.max() == 11.0
        assert flat_space.low.min() == -10.0

        gym_flat_space = gyms.flatten_space(self.space)
        assert gym_flat_space.shape == (35,)
        # The maximum comes from Box(-10, 10, (3, 4))
        assert gym_flat_space.high.max() == 10.0
        assert gym_flat_space.low.min() == -10.0
def enforce(
    self,
    sample: Any,
    action_space: gym.spaces.Space,
    teacher: OrderedDict,
    teacher_force_info: Optional[Dict[str, Any]],
    action_name: Optional[str] = None,
):
    """Replace sampled actions with expert actions wherever teacher forcing fires.

    # Parameters
    sample : actions sampled from the model's policy (structured per `action_space`).
    action_space : the gym action space used to flatten/unflatten actions.
    teacher : mapping holding the expert's action policy and a success flag.
    teacher_force_info : if not None, receives the realized teacher-forcing ratio
        under the key `"teacher_ratio/sampled[_<action_name>]"`.
    action_name : optional suffix used when logging into `teacher_force_info`.

    # Returns
    The (possibly expert-overridden) actions, unflattened back to `action_space`.
    """
    actions = su.flatten(action_space, sample)

    assert (
        len(actions.shape) == 3
    ), f"Got flattened actions with shape {actions.shape} (it should be [1 x `samplers` x `flatdims`])"

    if self.num_active_samplers is not None:
        assert actions.shape[1] == self.num_active_samplers

    expert_actions = su.flatten(action_space, teacher[Expert.ACTION_POLICY_LABEL])
    assert (
        expert_actions.shape == actions.shape
    ), f"expert actions shape {expert_actions.shape} doesn't match the model's {actions.shape}"

    # expert_success is 0 if the expert action could not be computed and otherwise equals 1.
    expert_action_exists_mask = teacher[Expert.EXPERT_SUCCESS_LABEL]

    if not self.always_enforce:
        # Bernoulli-sample a forcing decision per sampler, then zero it out
        # wherever the expert action does not exist.
        teacher_forcing_mask = (
            torch.distributions.bernoulli.Bernoulli(
                torch.tensor(self.teacher_forcing(self.approx_steps))
            )
            .sample(expert_action_exists_mask.shape)
            .long()
            .to(actions.device)
        ) * expert_action_exists_mask
    else:
        teacher_forcing_mask = expert_action_exists_mask

    if teacher_force_info is not None:
        teacher_force_info[
            "teacher_ratio/sampled{}".format(
                f"_{action_name}" if action_name is not None else ""
            )
        ] = teacher_forcing_mask.float().mean().item()

    # Broadcast the mask over the trailing flat-action dimensions.
    extended_shape = teacher_forcing_mask.shape + (1,) * (
        len(actions.shape) - len(teacher_forcing_mask.shape)
    )

    # Use a boolean condition: `torch.where` with a uint8 (`.byte()`) condition
    # is deprecated and rejected by recent PyTorch versions.
    actions = torch.where(
        teacher_forcing_mask.bool().view(extended_shape), expert_actions, actions
    )

    return su.unflatten(action_space, actions)
def flatten_output(self, unflattened):
    """Torch-ify and flatten `unflattened` per the observation space, as a CPU numpy array."""
    torchified = su.torch_point(self.observation_space, unflattened)
    flat = su.flatten(self.observation_space, torchified)
    return flat.cpu().numpy()
def test_tolist(self):
    """`su.action_list` recovers per-sampler actions for MultiDiscrete, Tuple, and Dict spaces."""
    # MultiDiscrete: one sampler, two sub-actions.
    space = gyms.MultiDiscrete([3, 3])
    point = su.torch_point(space, space.sample())  # single sampler
    point = point.unsqueeze(0).unsqueeze(0)  # add [step, sampler]
    listed = su.action_list(space, su.flatten(space, point))
    assert len(listed) == 1
    assert len(listed[0]) == 2

    # Tuple of (MultiDiscrete, Discrete).
    space = gyms.Tuple([gyms.MultiDiscrete([3, 3]), gyms.Discrete(2)])
    point = su.torch_point(space, space.sample())  # single sampler
    point = (
        point[0].unsqueeze(0).unsqueeze(0),
        torch.tensor(point[1]).unsqueeze(0).unsqueeze(0),
    )  # add [step, sampler]
    listed = su.action_list(space, su.flatten(space, point))
    assert len(listed) == 1
    assert len(listed[0][0]) == 2
    assert isinstance(listed[0][1], int)

    # Dict with a MultiDiscrete entry and a Discrete entry.
    space = gyms.Dict(
        {"tuple": gyms.MultiDiscrete([3, 3]), "scalar": gyms.Discrete(2)}
    )
    point = su.torch_point(space, space.sample())  # single sampler
    point = OrderedDict(
        [
            ("tuple", point["tuple"].unsqueeze(0).unsqueeze(0)),
            ("scalar", torch.tensor(point["scalar"]).unsqueeze(0).unsqueeze(0)),
        ]
    )  # add [step, sampler]
    listed = su.action_list(space, su.flatten(space, point))
    assert len(listed) == 1
    assert len(listed[0]["tuple"]) == 2
    assert isinstance(listed[0]["scalar"], int)
def get_observation(
    self, env: EnvType, task: SubTaskType, *args: Any, **kwargs: Any
) -> Any:
    """Return the expert-action observation for `task` as a flattened numpy array."""
    # If the task is completed, we needn't (perhaps can't) find the expert
    # action from the (current) terminal state.
    if task.is_done():
        return self._zeroed_observation

    action, expert_was_successful = task.query_expert(**self.expert_args)

    if isinstance(action, int):
        assert isinstance(self.action_space, gym.spaces.Discrete)
        unflattened_action = action
    else:
        # Assume we receive a gym-flattened numpy action
        unflattened_action = gyms.unflatten(self.action_space, action)

    torchified = su.torch_point(
        self.unflattened_observation_space,
        (unflattened_action, expert_was_successful),
    )
    flat = su.flatten(self.unflattened_observation_space, torchified)
    return flat.cpu().numpy()