Example 1
    def split_forward_pass(self, observations: Observations) -> Tensor:
        """Forward pass for a batch whose items may come from different tasks.

        Invoked by `forward` when the batch holds more than one unique task
        label. Each task's items are sliced out and pushed through `forward`
        as a homogeneous sub-batch, and the per-item outputs are scattered
        back into their original positions before being merged into a single
        batch.

        NOTE: This cannot recurse indefinitely: the inner `forward` calls each
        receive a batch whose items all share one task, so `split_forward_pass`
        is not entered again for them.

        Parameters
        ----------
        observations : Observations
            Observations whose task labels are not necessarily all identical.

        Returns
        -------
        Tensor
            The outputs/logits from every task, merged back into one batch in
            the same item order as `observations`.
        """
        assert observations.task_labels is not None
        assert self.hp.multihead, "Can only use split forward pass with multiple heads."
        # Task labels are guaranteed present past this point.
        labels = observations.task_labels
        if isinstance(labels, Tensor):
            labels = labels.cpu().numpy()

        # Map each task id to the batch indices of its items.
        indices_per_task: Dict[int, np.ndarray] = get_task_indices(labels)

        if len(indices_per_task) == 1:
            # Homogeneous batch after all: no splitting needed, just run the
            # plain forward pass for that single task.
            only_task_id: int = labels[0].item()
            self.setup_for_task(only_task_id)
            return self.forward(observations)

        # One slot per batch item; filled task-by-task, then concatenated so
        # the original ordering is preserved.
        per_item_outputs: List[Batch] = [None] * len(labels)

        for indices in indices_per_task.values():
            # All items in this slice belong to the same task (base case).
            sub_batch = get_slice(observations, indices)
            sub_output = self.forward(sub_batch)

            # Scatter each item's output back to its original batch position.
            for offset, batch_index in enumerate(indices):
                per_item_outputs[batch_index] = get_slice(sub_output, offset)

        # Every slot must have been filled exactly once.
        assert all(output is not None for output in per_item_outputs)
        return concatenate(per_item_outputs)
Example 2
 def concatenate(cls: Type[B], items: List[B], **kwargs) -> B:
     """Concatenate a sequence of `cls` instances into a single one.

     Delegates to the generic `concatenate` function, forwarding any extra
     keyword arguments.
     """
     from sequoia.utils.generic_functions import concatenate
     batch = list(items)
     # Sanity check: only concatenate objects of this (sub)class.
     assert isinstance(batch[0], cls)
     return concatenate(batch, **kwargs)
Example 3
    def __add__(self, other):
        """Implement `self + other` by concatenating the two objects via the
        generic `concatenate` function."""
        from sequoia.utils.generic_functions import concatenate

        combined = concatenate(self, other)
        return combined