def test_flatten(self):
    # We flatten Discrete to 1 value
    assert su.flatdim(self.space) == 25
    # gym flattens Discrete to one-hot
    assert gyms.flatdim(self.space) == 35

    asample = su.torch_point(self.space, self.space.sample())
    flattened = su.flatten(self.space, asample)
    unflattened = su.unflatten(self.space, flattened)
    assert self.same(asample, unflattened)

    # suppress `UserWarning: WARN: Box bound precision lowered by casting to float32`
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        flattened_space = su.flatten_space(self.space)
        assert flattened_space.shape == (25,)
        # The maximum comes from Discrete(11)
        assert flattened_space.high.max() == 11.0
        assert flattened_space.low.min() == -10.0

        gym_flattened_space = gyms.flatten_space(self.space)
        assert gym_flattened_space.shape == (35,)
        # The maximum comes from Box(-10, 10, (3, 4))
        assert gym_flattened_space.high.max() == 10.0
        assert gym_flattened_space.low.min() == -10.0
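
# Hedged illustration (plain `gym` only, not part of the test suite above) of
# why the two flatdims differ: gym one-hot encodes a Discrete space, so
# Discrete(11) contributes 11 flat dimensions, whereas the single-value
# encoding used by `su` contributes just 1. Exact dtype of the one-hot may
# vary by gym version.
import gym.spaces as gyms

_d = gyms.Discrete(11)
assert gyms.flatdim(_d) == 11          # one-hot: 11 dims
assert gyms.flatten(_d, 3)[3] == 1.0   # a single 1 at the sampled index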
def __init__(
    self,
    distr: Distr,
    obs: Dict[str, Any],
    action_space: gym.spaces.Space,
    num_active_samplers: Optional[int],
    approx_steps: Optional[int],
    teacher_forcing: Optional[TeacherForcingAnnealingType],
    tracking_info: Optional[Dict[str, Any]],
    always_enforce: bool = False,
):
    self.distr = distr
    self.is_sequential = isinstance(self.distr, SequentialDistr)

    # action_space is a gym.spaces.Dict for SequentialDistr, or any gym.Space for other Distr
    self.action_space = action_space
    self.num_active_samplers = num_active_samplers
    self.approx_steps = approx_steps
    self.teacher_forcing = teacher_forcing
    self.tracking_info = tracking_info
    self.always_enforce = always_enforce

    assert (
        "expert_action" in obs
    ), "When using teacher forcing, obs must contain an `expert_action` uuid"

    obs_space = Expert.flagged_space(
        self.action_space, use_dict_as_groups=self.is_sequential
    )
    self.expert = su.unflatten(obs_space, obs["expert_action"])
def _zeroed_observation(self) -> Union[OrderedDict, Tuple]:
    # AllenAct-style flattened space (to easily generate an all-zeroes action as an array)
    flat_space = su.flatten_space(self.observation_space)
    # torch point to correctly unflatten `Discrete` for zeroed output
    flat_zeroed = su.torch_point(flat_space, np.zeros_like(flat_space.sample()))
    # unflatten zeroed output and convert to numpy
    return su.numpy_point(
        self.observation_space, su.unflatten(self.observation_space, flat_zeroed)
    )
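
# Hedged usage sketch of the zeroing pattern above on a toy Dict space
# (`su` is `allenact.utils.spaces_utils`, as in the method itself; the space
# below is hypothetical and only for illustration).
import numpy as np
import gym
import allenact.utils.spaces_utils as su

_space = gym.spaces.Dict(
    {"pos": gym.spaces.Box(-1.0, 1.0, (2,)), "mode": gym.spaces.Discrete(4)}
)
_flat_space = su.flatten_space(_space)
_flat_zeroed = su.torch_point(_flat_space, np.zeros_like(_flat_space.sample()))
_zeroed = su.numpy_point(_space, su.unflatten(_space, _flat_zeroed))
# expected: an OrderedDict with "mode" == 0 and "pos" == array([0., 0.])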
def test_batched(self):
    samples = [self.space.sample() for _ in range(10)]
    flattened = [
        su.flatten(self.space, su.torch_point(self.space, sample))
        for sample in samples
    ]
    stacked = torch.stack(flattened, dim=0)
    unflattened = su.unflatten(self.space, stacked)
    for bidx, refsample in enumerate(samples):
        # Compare each torch-ified sample to the corresponding unflattened from the stack
        assert self.same(su.torch_point(self.space, refsample), unflattened, bidx)
    assert self.same(su.flatten(self.space, unflattened), stacked)
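
# Hedged sketch of the batched round-trip exercised above, on a simple Box
# space: flatten per-sample, stack along a new leading batch dim, then
# unflatten the whole stack at once (`su` is `allenact.utils.spaces_utils`;
# the test above is what establishes that `su.unflatten` accepts a batch dim).
import torch
import gym
import allenact.utils.spaces_utils as su

_space = gym.spaces.Box(-1.0, 1.0, (3,))
_samples = [_space.sample() for _ in range(4)]
_stacked = torch.stack(
    [su.flatten(_space, su.torch_point(_space, s)) for s in _samples], dim=0
)  # shape [4, 3]
_unflattened = su.unflatten(_space, _stacked)  # batched unflatten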
def enforce(
    self,
    sample: Any,
    action_space: gym.spaces.Space,
    teacher: OrderedDict,
    teacher_force_info: Optional[Dict[str, Any]],
    action_name: Optional[str] = None,
):
    actions = su.flatten(action_space, sample)

    assert (
        len(actions.shape) == 3
    ), f"Got flattened actions with shape {actions.shape} (it should be [1 x `samplers` x `flatdims`])"

    if self.num_active_samplers is not None:
        assert actions.shape[1] == self.num_active_samplers

    expert_actions = su.flatten(action_space, teacher[Expert.ACTION_POLICY_LABEL])
    assert (
        expert_actions.shape == actions.shape
    ), f"expert actions shape {expert_actions.shape} doesn't match the model's {actions.shape}"

    # expert_success is 0 if the expert action could not be computed and otherwise equals 1.
    expert_action_exists_mask = teacher[Expert.EXPERT_SUCCESS_LABEL]

    if not self.always_enforce:
        teacher_forcing_mask = (
            torch.distributions.bernoulli.Bernoulli(
                torch.tensor(self.teacher_forcing(self.approx_steps))
            )
            .sample(expert_action_exists_mask.shape)
            .long()
            .to(actions.device)
        ) * expert_action_exists_mask
    else:
        teacher_forcing_mask = expert_action_exists_mask

    if teacher_force_info is not None:
        teacher_force_info[
            "teacher_ratio/sampled{}".format(
                f"_{action_name}" if action_name is not None else ""
            )
        ] = teacher_forcing_mask.float().mean().item()

    extended_shape = teacher_forcing_mask.shape + (1,) * (
        len(actions.shape) - len(teacher_forcing_mask.shape)
    )

    actions = torch.where(
        teacher_forcing_mask.byte().view(extended_shape), expert_actions, actions
    )

    return su.unflatten(action_space, actions)
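
# Plain-torch sketch of the masking logic in `enforce` (illustrative tensors,
# not the library's API): sample a Bernoulli mask at the current
# teacher-forcing ratio, zero it wherever the expert failed, then take the
# expert action where the mask is set.
import torch

_p = 0.75                                    # teacher-forcing ratio
_actions = torch.tensor([[0, 1, 2, 3]])      # model actions, [1 x samplers]
_expert = torch.tensor([[3, 3, 3, 3]])       # expert actions
_expert_ok = torch.tensor([[1, 1, 0, 1]])    # 0 where the expert failed

_mask = (
    torch.distributions.bernoulli.Bernoulli(torch.tensor(_p))
    .sample(_expert_ok.shape)
    .long()
) * _expert_ok
_forced = torch.where(_mask.bool(), _expert, _actions)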
def loss(  # type: ignore
    self,
    step_count: int,
    batch: ObservationType,
    actor_critic_output: ActorCriticOutput[Distr],
    *args,
    **kwargs,
):
    """Computes the imitation loss.

    # Parameters

    batch : A batch of data corresponding to the information collected when rolling out
        (possibly many) agents over a fixed number of steps. In particular this batch should
        have the same format as that returned by `RolloutStorage.recurrent_generator`. Here
        `batch["observations"]` must contain `"expert_action"` observations or
        `"expert_policy"` observations. See `ExpertActionSensor` (or `ExpertPolicySensor`)
        for an example of a sensor producing such observations.
    actor_critic_output : The output of calling an ActorCriticModel on the observations in `batch`.
    args : Extra args. Ignored.
    kwargs : Extra kwargs. Ignored.

    # Returns

    A (0-dimensional) torch.FloatTensor corresponding to the computed loss. `.backward()`
    will be called on this tensor in order to compute a gradient update to the
    ActorCriticModel's parameters.
    """
    observations = cast(Dict[str, torch.Tensor], batch["observations"])

    losses = OrderedDict()

    should_report_loss = False

    if "expert_action" in observations:
        if self.expert_sensor is None or not self.expert_sensor.use_groups:
            expert_actions_and_mask = observations["expert_action"]

            assert expert_actions_and_mask.shape[-1] == 2
            expert_actions_and_mask_reshaped = expert_actions_and_mask.view(-1, 2)

            expert_actions = expert_actions_and_mask_reshaped[:, 0].view(
                *expert_actions_and_mask.shape[:-1], 1
            )
            expert_actions_masks = (
                expert_actions_and_mask_reshaped[:, 1]
                .float()
                .view(*expert_actions_and_mask.shape[:-1], 1)
            )

            total_loss, expert_successes = self.group_loss(
                cast(CategoricalDistr, actor_critic_output.distributions),
                expert_actions,
                expert_actions_masks,
            )

            should_report_loss = expert_successes.item() != 0
        else:
            expert_actions = su.unflatten(
                self.expert_sensor.observation_space, observations["expert_action"]
            )

            total_loss = 0

            ready_actions = OrderedDict()

            for group_name, cd in zip(
                self.expert_sensor.group_spaces,
                cast(
                    SequentialDistr, actor_critic_output.distributions
                ).conditional_distrs,
            ):
                assert group_name == cd.action_group_name

                cd.reset()
                cd.condition_on_input(**ready_actions)

                expert_action = expert_actions[group_name][
                    AbstractExpertSensor.ACTION_POLICY_LABEL
                ]
                expert_action_masks = expert_actions[group_name][
                    AbstractExpertSensor.EXPERT_SUCCESS_LABEL
                ]

                ready_actions[group_name] = expert_action

                current_loss, expert_successes = self.group_loss(
                    cd,
                    expert_action,
                    expert_action_masks,
                )

                should_report_loss = (
                    expert_successes.item() != 0 or should_report_loss
                )

                cd.reset()

                if expert_successes.item() != 0:
                    losses[group_name + "_cross_entropy"] = current_loss.item()
                    total_loss = total_loss + current_loss
    elif "expert_policy" in observations:
        if self.expert_sensor is None or not self.expert_sensor.use_groups:
            assert isinstance(
                actor_critic_output.distributions, CategoricalDistr
            ), "This implementation currently only supports `CategoricalDistr`"

            expert_policies = cast(Dict[str, torch.Tensor], batch["observations"])[
                "expert_policy"
            ][..., :-1]
            expert_actions_masks = cast(
                Dict[str, torch.Tensor], batch["observations"]
            )["expert_policy"][..., -1:]

            expert_successes = expert_actions_masks.sum()
            if expert_successes.item() > 0:
                should_report_loss = True

            log_probs = cast(
                CategoricalDistr, actor_critic_output.distributions
            ).log_probs_tensor

            # Add dimensions to `expert_actions_masks` on the right to allow for masking
            # if necessary.
            len_diff = len(log_probs.shape) - len(expert_actions_masks.shape)
            assert len_diff >= 0
            expert_actions_masks = expert_actions_masks.view(
                *expert_actions_masks.shape, *((1,) * len_diff)
            )

            total_loss = (
                -(log_probs * expert_policies) * expert_actions_masks
            ).sum() / torch.clamp(expert_successes, min=1)
        else:
            raise NotImplementedError(
                "This implementation currently only supports `CategoricalDistr`"
            )
    else:
        raise NotImplementedError(
            "Imitation loss requires either `expert_action` or `expert_policy`"
            " sensor to be active."
        )

    return (
        total_loss,
        {"expert_cross_entropy": total_loss.item(), **losses}
        if should_report_loss
        else {},
    )
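
# Minimal torch sketch of the masked policy cross-entropy computed above
# (hypothetical tensors: `_log_probs` would come from the model's
# CategoricalDistr, `_expert_policy` and `_mask` from the expert sensor).
import torch

_log_probs = torch.log(torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]))
_expert_policy = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
_mask = torch.tensor([[1.0], [0.0]])  # expert failed on the second step
_successes = _mask.sum()
_loss = (
    -(_log_probs * _expert_policy) * _mask
).sum() / torch.clamp(_successes, min=1)
# here: -log(0.7), the cross entropy on the single successful step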
def pick_prev_actions_step(self, step: int) -> ActionType:
    return su.unflatten(self.action_space, self.prev_actions[step : step + 1])
def recurrent_generator(
    self,
    advantages: torch.Tensor,
    adv_mean: torch.Tensor,
    adv_std: torch.Tensor,
    num_mini_batch: int,
):
    normalized_advantages = (advantages - adv_mean) / (adv_std + 1e-5)

    num_samplers = self.rewards.shape[1]
    assert num_samplers >= num_mini_batch, (
        "The number of task samplers ({}) "
        "must be greater than or equal to the number of "
        "mini batches ({}).".format(num_samplers, num_mini_batch)
    )

    inds = np.round(
        np.linspace(0, num_samplers, num_mini_batch + 1, endpoint=True)
    ).astype(np.int32)
    pairs = list(zip(inds[:-1], inds[1:]))
    random.shuffle(pairs)

    for start_ind, end_ind in pairs:
        cur_samplers = list(range(start_ind, end_ind))

        memory_batch = self.memory.step_squeeze(0).sampler_select(cur_samplers)
        observations_batch = self.unflatten_observations(
            self.observations.slice(dim=0, stop=-1).sampler_select(cur_samplers)
        )

        actions_batch = []
        prev_actions_batch = []
        value_preds_batch = []
        return_batch = []
        masks_batch = []
        old_action_log_probs_batch = []
        adv_targ = []
        norm_adv_targ = []

        for ind in cur_samplers:
            actions_batch.append(self.actions[:, ind])
            prev_actions_batch.append(self.prev_actions[:-1, ind])
            value_preds_batch.append(self.value_preds[:-1, ind])
            return_batch.append(self.returns[:-1, ind])
            masks_batch.append(self.masks[:-1, ind])
            old_action_log_probs_batch.append(self.action_log_probs[:, ind])

            adv_targ.append(advantages[:, ind])
            norm_adv_targ.append(normalized_advantages[:, ind])

        actions_batch = torch.stack(actions_batch, 1)  # type:ignore
        prev_actions_batch = torch.stack(prev_actions_batch, 1)  # type:ignore
        value_preds_batch = torch.stack(value_preds_batch, 1)  # type:ignore
        return_batch = torch.stack(return_batch, 1)  # type:ignore
        masks_batch = torch.stack(masks_batch, 1)  # type:ignore
        old_action_log_probs_batch = torch.stack(  # type:ignore
            old_action_log_probs_batch, 1
        )
        adv_targ = torch.stack(adv_targ, 1)  # type:ignore
        norm_adv_targ = torch.stack(norm_adv_targ, 1)  # type:ignore

        yield {
            "observations": observations_batch,
            "memory": memory_batch,
            "actions": su.unflatten(self.action_space, actions_batch),
            "prev_actions": su.unflatten(self.action_space, prev_actions_batch),
            "values": value_preds_batch,
            "returns": return_batch,
            "masks": masks_batch,
            "old_action_log_probs": old_action_log_probs_batch,
            "adv_targ": adv_targ,
            "norm_adv_targ": norm_adv_targ,
        }
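
# Hedged sketch of how sampler indices get partitioned into mini-batches in
# `recurrent_generator` (numbers are illustrative): `np.linspace` spreads any
# remainder so consecutive batch sizes differ by at most one.
import numpy as np

_num_samplers, _num_mini_batch = 7, 3
_inds = np.round(
    np.linspace(0, _num_samplers, _num_mini_batch + 1, endpoint=True)
).astype(np.int32)
_pairs = list(zip(_inds[:-1], _inds[1:]))  # [(0, 2), (2, 5), (5, 7)]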