Example #1
    def split_forward_pass(self, observations: Observations) -> Tensor:
        """Perform a forward pass for a batch of observations from different tasks.

        This is called in `forward` when there is more than one unique task label in the
        batch.
        This will call `forward` for each task id present in the batch, passing it a
        slice of the batch, in which all items are from that task.

        NOTE: This cannot cause infinite recursion, because the nested call to
        `forward` receives a batch of items which all come from the same task, so
        `split_forward_pass` will not be called again.

        Parameters
        ----------
        observations : Observations
            Observations, in which the task labels might not all be the same.

        Returns
        -------
        Tensor
            The outputs/logits from each task, re-assembled into a single batch, with
            the task ordering from `observations` preserved.
        """
        assert observations.task_labels is not None
        assert self.hp.multihead, "Can only use split forward pass with multiple heads."
        # We have task labels.
        task_labels = observations.task_labels
        if isinstance(task_labels, Tensor):
            task_labels = task_labels.cpu().numpy()

        # Get the indices of the items from each task.
        all_task_indices_dict: Dict[int, np.ndarray] = get_task_indices(task_labels)
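        # Illustrative example (assuming `get_task_indices` maps each task id to the
        # indices of the items from that task):
        #   task_labels = [0, 0, 1, 0]  ->  {0: array([0, 1, 3]), 1: array([2])}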

        if len(all_task_indices_dict) == 1:
            # No need to split the input, since everything is from the same task.
            task_id: int = task_labels[0].item()
            self.setup_for_task(task_id)
            return self.forward(observations)

        # Placeholder for the predictions for each item in the batch.
        # NOTE: We put each item in the batch in this list and then stack the results.
        batch_size = len(task_labels)
        task_outputs: List[Batch] = [None for _ in range(batch_size)]

        for task_id, task_indices in all_task_indices_dict.items():
            # Take a slice of the observations, in which all items come from this task.
            task_observations = get_slice(observations, task_indices)
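            # NOTE: `get_slice` is assumed to index each (tensor/array) field of the
            # batch object with `task_indices`, returning a smaller batch of the same type.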
            # Perform a "normal" forward pass (Base case).
            task_output = self.forward(task_observations)

            # Store the outputs for the items from this task in the list.
            for i, index in enumerate(task_indices):
                task_outputs[index] = get_slice(task_output, i)

        # Stack the results.
        assert all(item is not None for item in task_outputs)
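        # NOTE: `concatenate` is assumed to stack each field of the per-item outputs back
        # into a single batched object. Since each output was written back at its original
        # index above, the original batch ordering is preserved.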
        merged_outputs = concatenate(task_outputs)
        return merged_outputs

    def output_head_loss(self, forward_pass: ForwardPass, actions: Actions,
                         rewards: Rewards) -> Loss:
        # Asks each output head for its contribution to the loss.
        observations: IncrementalSetting.Observations = forward_pass.observations
        task_labels = observations.task_labels
        if isinstance(task_labels, Tensor):
            task_labels = task_labels.cpu().numpy()

        batch_size = forward_pass.batch_size
        assert batch_size is not None

        if task_labels is None:
            if self.task_inference_module:
                # TODO: Predict the task ids using some kind of task
                # inference mechanism.
                task_labels = self.task_inference_module(forward_pass)
            else:
                raise NotImplementedError(
                    f"Multihead model doesn't have access to task labels and "
                    f"doesn't have a task inference module!")
                # TODO: Maybe use the last trained output head, by default?
        # BUG: We get no loss from the output head for the first episode after a task
        # switch.
        # NOTE: The problem is that the `done` in the observation isn't necessarily
        # associated with the task designated by the `task_id` in that observation!
        # That is because of how vectorized environments work: when `done` is True they
        # reset the env and return the new initial observation, rather than the last
        # observation of the finished episode.
        if self.previous_task_labels is None:
            self.previous_task_labels = task_labels

        # Default behaviour: use the (only) output head.
        if not self.hp.multihead:
            return self.output_head.get_loss(
                forward_pass,
                actions=actions,
                rewards=rewards,
            )

        # The sum of all the losses from all the output heads.
        total_loss = Loss(self.output_head.name)

        task_switched_in_env = task_labels != self.previous_task_labels
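        # NOTE: element-wise comparison (assuming array-like task labels), giving one
        # boolean per environment in the batch.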
        # TODO: This `done` attribute isn't added in supervised settings.
        episode_ended = getattr(observations, "done",
                                np.zeros(batch_size, dtype=bool))
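        # When `done` is missing (e.g. in supervised settings), default to all-False,
        # i.e. treat the batch as if no episode ended.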
        # TODO: Remove all this useless conversion from Tensors to ndarrays, by making
        # Sequoia more numpy-centric.
        if isinstance(episode_ended, Tensor):
            episode_ended = episode_ended.cpu().numpy()

        # logger.debug(f"Task labels: {task_labels}, task switched in env: {task_switched_in_env}, episode ended: {episode_ended}")
        done_set_to_false_temporarily_indices = []

        if any(episode_ended & task_switched_in_env):
            # In the environments where there was a switch to a different task and
            # where some episodes ended, we first need to get the corresponding
            # output head losses from those environments.
            if self.batch_size in {None, 1}:
                # If the batch size is 1, this is a little bit simpler to deal with.
                previous_task: int = self.previous_task_labels[0].item()
                from sequoia.methods.models.output_heads.rl import PolicyHead

                previous_output_head = self.output_heads[str(previous_task)]
                assert isinstance(
                    previous_output_head, PolicyHead
                ), "TODO: assuming that this only happens in RL for now."
                # We want the loss from that output head, but we don't want to
                # re-compute it below!
                env_index_in_previous_batch = 0
                logger.debug(
                    f"Getting a loss from the output head for task {previous_task}, "
                    f"which was used during the previous task."
                )
                env_episode_loss = previous_output_head.get_episode_loss(
                    env_index_in_previous_batch, done=True)
                # logger.debug(f"Loss from that output head: {env_episode_loss}")
                # Add this end-of-episode loss to the total loss.
                # BUG: This can sometimes (rarely) be None! Need to better understand
                # why this is happening.
                if env_episode_loss is None:
                    logger.warning(
                        RuntimeWarning(
                            f"BUG: Env {env_index_in_previous_batch} gave back a loss "
                            f"of `None`, when we expected a loss from that output head "
                            f"for task id {previous_task}."))
                else:
                    total_loss += env_episode_loss
                # We call on_episode_end so the output head can clear the relevant
                # buffers. Note that get_episode_loss(env_index, done=True) doesn't
                # clear the buffers, it just calculates a loss.
                previous_output_head.on_episode_end(
                    env_index_in_previous_batch)

                # Set `done` to `False` for that env, to prevent the output head for the
                # new task from seeing the first observation in the episode as the last.
                observations.done[env_index_in_previous_batch] = False
                # FIXME: If we modify that entry in-place, the change will persist even
                # after this method returns. Therefore we just save the indices that we
                # altered, and reset them before returning.
                done_set_to_false_temporarily_indices.append(
                    env_index_in_previous_batch)
            else:
                raise NotImplementedError(
                    "TODO: The BaselineModel doesn't yet support having multiple "
                    "different tasks within the same batch in RL. ")
                # IDEA: Need to somehow pass the indices of which env to take care of to
                # each output head, so they can create / clear buffers only when needed.

        assert task_labels is not None
        all_task_indices: Dict[int, np.ndarray] = get_task_indices(task_labels)

        # Get the loss from each output head:
        if len(all_task_indices) == 1:
            # If everything is in the same task (only one key), no need to split/merge
            # stuff, so it's a bit easier:
            task_id: int = list(all_task_indices.keys())[0]

            with self.switch_output_head(task_id):
                # task_output_head = self.output_heads[str(task_id)]
                total_loss += self.output_head.get_loss(
                    forward_pass,
                    actions=actions,
                    rewards=rewards,
                )
        else:
            # Split off the input batch, do a forward pass for each sub-task.
            # (could be done in parallel but whatever.)
            # TODO: Also, not sure if this will play well with DP, DDP, etc.
            for task_id, task_indices in all_task_indices.items():
                # # Make a partial observation without the task labels, so that
                # # super().forward will use the current output head.
                forward_pass_slice = get_slice(forward_pass, task_indices)
                actions_slice = get_slice(actions, task_indices)
                rewards_slice = get_slice(rewards, task_indices)
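                # NOTE: each slice is assumed to contain only the entries for `task_id`;
                # since the per-task losses are summed into `total_loss`, the ordering of
                # tasks here doesn't matter.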

                logger.debug(
                    f"Getting the output head loss for the "
                    f"{len(task_indices)/batch_size:.0%} of the batch that has "
                    f"task_id '{task_id}'.")
                task_output_head = self.output_heads[str(task_id)]
                task_loss = task_output_head.get_loss(
                    forward_pass_slice,
                    actions=actions_slice,
                    rewards=rewards_slice,
                )
                # FIXME: debugging
                # task_output_head_loss.name += f"(task {task_id})"
                logger.debug(f"Task {task_id} loss: {task_loss}")
                total_loss += task_loss

        self.previous_task_labels = task_labels
        # Restore `done` to True for the entries we temporarily set to False above
        # (see the FIXME above).
        for index in done_set_to_false_temporarily_indices:
            observations.done[index] = True

        return total_loss