def split_forward_pass(self, observations: Observations) -> Tensor:
    """Perform a forward pass for a batch of observations from different tasks.

    This is called in `forward` when there is more than one unique task label
    in the batch. It calls `forward` once for each task id present in the
    batch, passing it a slice of the batch in which all items are from that
    task.

    NOTE: This cannot cause recursion problems, because the nested call to
    `forward` receives a batch of items that all come from the same task, so
    `split_forward_pass` cannot be called again from there.

    Parameters
    ----------
    observations : Observations
        Observations, in which the task labels might not all be the same.

    Returns
    -------
    Tensor
        The outputs/logits from each task, re-assembled into a single batch,
        with the item ordering from `observations` preserved.
    """
    assert observations.task_labels is not None
    assert self.hp.multihead, "Can only use split forward pass with multiple heads."

    # We have task labels.
    task_labels = observations.task_labels
    if isinstance(task_labels, Tensor):
        task_labels = task_labels.cpu().numpy()

    # Get the indices of the items from each task.
    all_task_indices_dict: Dict[int, np.ndarray] = get_task_indices(task_labels)

    if len(all_task_indices_dict) == 1:
        # No need to split the input, since everything is from the same task.
        task_id: int = task_labels[0].item()
        self.setup_for_task(task_id)
        return self.forward(observations)

    # Placeholder for the predictions for each item in the batch.
    # NOTE: We put the output for each item in this list, then merge them back
    # into a single batch below.
    batch_size = len(task_labels)
    task_outputs: List[Batch] = [None for _ in range(batch_size)]

    for task_id, task_indices in all_task_indices_dict.items():
        # Take a slice of the observations, in which all items come from this task.
        task_observations = get_slice(observations, task_indices)
        # Perform a "normal" forward pass (base case).
        task_output = self.forward(task_observations)
        # Store the outputs for the items from this task in the list.
        for i, index in enumerate(task_indices):
            task_outputs[index] = get_slice(task_output, i)

    # Merge the results back into a single batch.
    assert all(item is not None for item in task_outputs)
    merged_outputs = concatenate(task_outputs)
    return merged_outputs
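# Illustration (hedged sketch, not part of the model): `get_task_indices` is
# assumed to map each unique task id to the positions of its items in the
# batch, along the lines of:
#
#     >>> task_labels = np.array([0, 1, 0, 1])
#     >>> {int(t): np.where(task_labels == t)[0] for t in np.unique(task_labels)}
#     {0: array([0, 2]), 1: array([1, 3])}
#
# `split_forward_pass` then runs one forward pass per key on the matching
# slice and writes each item's output back at its original position, so that
# `merged_outputs[i]` always corresponds to `observations[i]`.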
def output_head_loss(
    self, forward_pass: ForwardPass, actions: Actions, rewards: Rewards
) -> Loss:
    """Ask each output head for its contribution to the loss."""
    observations: IncrementalSetting.Observations = forward_pass.observations
    task_labels = observations.task_labels
    if isinstance(task_labels, Tensor):
        task_labels = task_labels.cpu().numpy()

    batch_size = forward_pass.batch_size
    assert batch_size is not None

    if task_labels is None:
        if self.task_inference_module:
            # TODO: Predict the task ids using some kind of task inference
            # mechanism.
            task_labels = self.task_inference_module(forward_pass)
        else:
            raise NotImplementedError(
                "Multihead model doesn't have access to task labels and "
                "doesn't have a task inference module!"
            )
            # TODO: Maybe use the last trained output head by default?

    # BUG: We get no loss from the output head for the first episode after a
    # task switch.
    # NOTE: The problem is that the `done` in the observation isn't necessarily
    # associated with the task designated by the `task_id` in that observation!
    # That is because of how vectorized environments work: when `done` is True,
    # they reset the env and return the new initial observation, rather than
    # the last observation of the episode that just ended. (See the sketch
    # after this method.)
    if self.previous_task_labels is None:
        self.previous_task_labels = task_labels

    # Default behaviour: use the (only) output head.
    if not self.hp.multihead:
        return self.output_head.get_loss(
            forward_pass,
            actions=actions,
            rewards=rewards,
        )

    # The sum of all the losses from all the output heads.
    total_loss = Loss(self.output_head.name)

    task_switched_in_env = task_labels != self.previous_task_labels
    # TODO: This `done` attribute isn't added in supervised settings.
    episode_ended = getattr(observations, "done", np.zeros(batch_size, dtype=bool))
    # TODO: Remove all this needless conversion between Tensors and ndarrays by
    # making Sequoia more numpy-centric.
    if isinstance(episode_ended, Tensor):
        episode_ended = episode_ended.cpu().numpy()

    done_set_to_false_temporarily_indices = []

    if any(episode_ended & task_switched_in_env):
        # In the environments where an episode ended *and* there was a switch
        # to a different task, we first need to get the corresponding losses
        # from the output heads that were used in those environments.
        if self.batch_size in {None, 1}:
            # If the batch size is 1, this is a little simpler to deal with.
            previous_task: int = self.previous_task_labels[0].item()
            # IDEA:
            from sequoia.methods.models.output_heads.rl import PolicyHead

            previous_output_head = self.output_heads[str(previous_task)]
            assert isinstance(
                previous_output_head, PolicyHead
            ), "TODO: assuming that this only happens in RL currently."
            # We want the loss from that output head, but we don't want to
            # re-compute it below!
            env_index_in_previous_batch = 0
            logger.debug(
                f"Getting a loss from the output head for the previous task "
                f"({previous_task})."
            )
            env_episode_loss = previous_output_head.get_episode_loss(
                env_index_in_previous_batch, done=True
            )
            # Add this end-of-episode loss to the total loss.
            # BUG: This can sometimes (rarely) be None! Need to better
            # understand why this happens.
            if env_episode_loss is None:
                logger.warning(
                    RuntimeWarning(
                        f"BUG: Env {env_index_in_previous_batch} gave back a "
                        f"loss of `None`, when we expected a loss from that "
                        f"output head for task id {previous_task}."
                    )
                )
            else:
                total_loss += env_episode_loss

            # We call `on_episode_end` so the output head can clear the
            # relevant buffers. Note that `get_episode_loss(env_index,
            # done=True)` doesn't clear the buffers, it just calculates a loss.
            previous_output_head.on_episode_end(env_index_in_previous_batch)

            # Set `done` to False for that env, to prevent the output head for
            # the new task from treating the first observation of the episode
            # as the last one.
            observations.done[env_index_in_previous_batch] = False
            # NOTE: Since we modify that entry in-place, the change would
            # persist even after this method returns. We therefore save the
            # indices that we altered, and reset them before returning.
            done_set_to_false_temporarily_indices.append(
                env_index_in_previous_batch
            )
        else:
            raise NotImplementedError(
                "TODO: The BaselineModel doesn't yet support having multiple "
                "different tasks within the same batch in RL."
            )
            # IDEA: Need to somehow pass the indices of the envs to take care
            # of to each output head, so they can create/clear buffers only
            # when needed.

    assert task_labels is not None
    all_task_indices: Dict[int, Tensor] = get_task_indices(task_labels)

    # Get the loss from each output head:
    if len(all_task_indices) == 1:
        # If everything is in the same task (only one key), there is no need
        # to split/merge anything, so it's a bit easier:
        task_id: int = list(all_task_indices.keys())[0]
        with self.switch_output_head(task_id):
            total_loss += self.output_head.get_loss(
                forward_pass,
                actions=actions,
                rewards=rewards,
            )
    else:
        # Split up the batch and get a loss for each sub-task.
        # (This could be done in parallel, but isn't at the moment.)
        # TODO: Not sure if this will play well with DP, DDP, etc.
        for task_id, task_indices in all_task_indices.items():
            forward_pass_slice = get_slice(forward_pass, task_indices)
            actions_slice = get_slice(actions, task_indices)
            rewards_slice = get_slice(rewards, task_indices)

            logger.debug(
                f"Getting the output head loss on "
                f"{len(task_indices) / batch_size:.0%} of the batch, which "
                f"has a task id of '{task_id}'."
            )
            task_output_head = self.output_heads[str(task_id)]
            task_loss = task_output_head.get_loss(
                forward_pass_slice,
                actions=actions_slice,
                rewards=rewards_slice,
            )
            logger.debug(f"Task {task_id} loss: {task_loss}")
            total_loss += task_loss

    self.previous_task_labels = task_labels
    # Reset `done` to True in the envs where we temporarily set it to False
    # above.
    for index in done_set_to_false_temporarily_indices:
        observations.done[index] = True

    return total_loss
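# Sketch of the vectorized-env timing issue handled in `output_head_loss`
# (assumed gym `VectorEnv`-style auto-reset semantics; names are illustrative):
#
#     >>> obs, reward, done, info = vec_env.step(actions)
#     >>> # If done[i] is True, env i was auto-reset: obs[i] is the *first*
#     >>> # observation of the next episode (possibly from a new task), while
#     >>> # done[i] and reward[i] refer to the episode that just ended.
#
# This is why, on a task switch, the end-of-episode loss is requested from the
# *previous* task's output head, and `done` is then temporarily cleared so the
# new task's head doesn't treat the first observation of its episode as the
# last one.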