Exemple #1
0
    def forward(self, uids, X_dict, Y_dict, task_to_label_dict):
        """Calculate the loss, prob for the batch.

        :param uids: The uids of input data
        :type uids: list
        :param X_dict: The input data
        :type X_dict: dict of tensors
        :param Y_dict: The output data
        :type Y_dict: dict of tensors
        :param task_to_label_dict: The task to label mapping
        :type task_to_label_dict: dict
        :return: The (active) uids, loss, prob and gold labels in the batch of
            all tasks; a task with no active sample is absent from all four
        :rtype: dict, dict, dict, dict
        """

        uid_dict = defaultdict(list)
        loss_dict = defaultdict(list)
        prob_dict = defaultdict(list)
        gold_dict = defaultdict(list)

        output_dict = self.flow(X_dict, task_to_label_dict.keys())

        device = Meta.config["model_config"]["device"]
        ignore_index = Meta.config["learner_config"]["ignore_index"]

        # Calculate loss for each task
        for task_name, label_name in task_to_label_dict.items():
            Y = Y_dict[label_name]

            # Select the active samples. When no ignore_index is configured,
            # every sample is active instead of comparing labels against None.
            if ignore_index is None:
                active = torch.ones(Y.size(0), dtype=torch.bool)
            elif len(Y.size()) == 1:
                active = Y.detach() != ignore_index
            else:
                active = torch.any(Y.detach() != ignore_index, dim=1)

            # Only calculate the loss when active example exists
            if not active.any():
                continue

            # Move the mask once; it is reused for the loss and the probs.
            active_on_device = move_to_device(active, device)

            # ``.cpu()`` keeps the numpy conversion valid if labels are on GPU.
            uid_dict[task_name] = [
                *itertools.compress(uids, active.cpu().numpy())
            ]

            loss_dict[task_name] = self.loss_funcs[task_name](
                output_dict,
                move_to_device(Y, device),
                active_on_device,
            )

            prob_dict[task_name] = (
                self.output_funcs[task_name](output_dict)[active_on_device]
                .cpu()
                .detach()
                .numpy()
            )

            gold_dict[task_name] = Y[active].cpu().numpy()

        return uid_dict, loss_dict, prob_dict, gold_dict
Exemple #2
0
def test_move_to_device(caplog):
    """Unit test of move_to_device.

    Moving to device ``-1`` must return equal tensors and must return dicts,
    lists and tuples as the same container type with the same contents.
    Contents are compared explicitly, because asserting the truthiness of a
    non-empty container can never fail.
    """
    caplog.set_level(logging.INFO)

    assert torch.equal(move_to_device(torch.Tensor([1, 2]), -1),
                       torch.Tensor([1, 2]))

    moved_dict = move_to_device({
        1: torch.tensor([1, 2]),
        2: torch.tensor([3, 4])
    }, -1)
    assert isinstance(moved_dict, dict)
    assert sorted(moved_dict.keys()) == [1, 2]
    assert torch.equal(moved_dict[1], torch.tensor([1, 2]))
    assert torch.equal(moved_dict[2], torch.tensor([3, 4]))

    # Lists and tuples should come back as the same container type.
    for container_type in (list, tuple):
        moved = move_to_device(
            container_type([torch.tensor([1, 2]), torch.tensor([3, 4])]), -1)
        assert isinstance(moved, container_type)
        assert len(moved) == 2
        assert torch.equal(moved[0], torch.tensor([1, 2]))
        assert torch.equal(moved[1], torch.tensor([3, 4]))
Exemple #3
0
    def flow(self, X_dict: Dict[str, Any], task_names: List[str]) -> Dict[str, Any]:
        """Forward based on input and task flow.

        Note:
          We assume that all shared modules from all tasks are based on the
          same input.

        Args:
          X_dict: The input data
          task_names: The task names that needs to forward.

        Returns:
          The output of all forwarded modules, keyed by action name; the
          device-moved input is stored under ``_input_``. Tuple outputs are
          converted to lists, and outputs that are neither lists nor dicts are
          wrapped in a one-element list.

        Raises:
          ValueError: If an action's inputs cannot be collected from the outputs
            computed so far or moved to the module's device.
        """
        default_device = self._get_default_device()

        X_dict = move_to_device(X_dict, default_device)

        output_dict = dict(_input_=X_dict)

        # Call forward for each task
        for task_name in task_names:
            for action in self.task_flows[task_name]:
                # An action whose name already has an output (e.g. one shared
                # with an earlier task) is not run again.
                if action["name"] in output_dict:
                    continue

                if action["inputs"]:
                    try:
                        # Use the module's registered device, else the default.
                        action_module_device = (
                            self.module_device[action["module"]]
                            if action["module"] in self.module_device
                            else default_device
                        )
                        action_inputs = move_to_device(
                            [
                                output_dict[action_name][output_index]
                                for action_name, output_index in action["inputs"]
                            ],
                            action_module_device,
                        )
                    except Exception as e:
                        # Chain the cause so the underlying error stays visible.
                        raise ValueError(f"Unrecognized action {action}.") from e
                    output = self.module_pool[action["module"]].forward(*action_inputs)
                else:
                    # TODO: Handle multiple device with not inputs case
                    output = self.module_pool[action["module"]].forward(output_dict)

                # Normalise outputs so downstream actions can index into them.
                if isinstance(output, tuple):
                    output = list(output)
                if not isinstance(output, (list, dict)):
                    output = [output]
                output_dict[action["name"]] = output

        return output_dict
Exemple #4
0
    def calculate_loss(self, X_dict, Y_dict, task_to_label_dict, data_name, split):
        """Calculate the loss

        :param X_dict: The input data
        :type X_dict: dict of tensors
        :param Y_dict: The output data
        :type Y_dict: dict of tensors
        :param task_to_label_dict: The task to label mapping
        :type task_to_label_dict: dict
        :param data_name: The dataset name
        :type data_name: str
        :param split: The data split
        :type split: str
        :return: The loss and the number of active samples in the batch of all
            tasks, both keyed by ``"<task>/<data_name>/<split>/loss"``; tasks
            without any active sample are omitted
        :rtype: dict, dict
        """

        loss_dict = dict()
        count_dict = dict()

        immediate_output_dict = self.forward(X_dict, task_to_label_dict.keys())

        device = Meta.config["model_config"]["device"]
        ignore_index = Meta.config["learner_config"]["ignore_index"]

        # Calculate loss for each task
        for task_name, label_name in task_to_label_dict.items():
            identifier = "/".join([task_name, data_name, split, "loss"])

            Y = Y_dict[label_name]

            # Select the active samples. When no ignore_index is configured,
            # every sample is active instead of comparing labels against None.
            if ignore_index is None:
                active = torch.ones(Y.size(0), dtype=torch.bool)
            elif len(Y.size()) == 1:
                active = Y.detach() != ignore_index
            else:
                active = torch.any(Y.detach() != ignore_index, dim=1)

            # Only calculate the loss when active example exists
            if active.any():
                count_dict[identifier] = active.sum().item()

                loss_dict[identifier] = self.loss_funcs[task_name](
                    immediate_output_dict,
                    move_to_device(Y, device),
                    move_to_device(active, device),
                )

        return loss_dict, count_dict
Exemple #5
0
    def flow(self, X_dict: Dict[str, Any], task_names: List[str]) -> Dict[str, Any]:
        """Forward based on input and task flow.

        Note:
          We assume that all shared modules from all tasks are based on the
          same input.

        Args:
          X_dict: The input data
          task_names: The task names that needs to forward.

        Returns:
          The output of all forwarded modules, keyed by action name; the
          device-moved input is stored under ``_input_``. Each output is stored
          exactly as returned by its module.

        Raises:
          ValueError: If an action's inputs cannot be collected from the outputs
            computed so far or moved to the module's device.
        """
        default_device = self._get_default_device()

        X_dict = move_to_device(X_dict, default_device)

        output_dict = dict(_input_=X_dict)

        # Call forward for each task
        for task_name in task_names:
            for action in self.task_flows[task_name]:
                # An action whose name already has an output (e.g. one shared
                # with an earlier task) is not run again.
                if action.name in output_dict:
                    continue

                if action.inputs:
                    try:
                        # Use the module's registered device, else the default.
                        action_module_device = (
                            self.module_device[action.module]
                            if action.module in self.module_device
                            else default_device
                        )
                        action_inputs = move_to_device(
                            [
                                self._get_data_from_output_dict(output_dict, _input)
                                for _input in action.inputs
                            ],
                            action_module_device,
                        )
                    except Exception as e:
                        # Chain the cause so the underlying error stays visible.
                        raise ValueError(f"Unrecognized action {action}.") from e
                    output = self.module_pool[action.module].forward(*action_inputs)
                else:
                    # TODO: Handle multiple device with not inputs case
                    output = self.module_pool[action.module].forward(output_dict)
                output_dict[action.name] = output

        return output_dict
Exemple #6
0
    def flow(self, X_dict: Dict[str, Any],
             task_names: List[str]) -> Dict[str, Any]:
        r"""Forward based on input and task flow.

        Note:
          We assume that all shared modules from all tasks are based on the
          same input.

        Args:
          X_dict(dict): The input data
          task_names(list): The task names that needs to forward.

        Returns:
          dict: The output of all forwarded modules, keyed by action name (the
          device-moved input is stored under ``_input_``). Every module output
          is stored as a list: tuples are converted and other values wrapped.

        Raises:
          ValueError: If an action's inputs cannot be collected from the
          outputs computed so far.

        """

        X_dict = move_to_device(X_dict, Meta.config["model_config"]["device"])

        output_dict = dict(_input_=X_dict)

        # Call forward for each task
        for task_name in task_names:
            for action in self.task_flows[task_name]:
                # An action whose name already has an output (e.g. one shared
                # with an earlier task) is not run again.
                if action["name"] in output_dict:
                    continue

                if action["inputs"]:
                    try:
                        action_inputs = [
                            output_dict[action_name][output_index] for
                            action_name, output_index in action["inputs"]
                        ]
                    except Exception as e:
                        # Chain the cause so the underlying error stays visible.
                        raise ValueError(
                            f"Unrecognized action {action}.") from e
                    output = self.module_pool[action["module"]].forward(
                        *action_inputs)
                else:
                    output = self.module_pool[action["module"]].forward(
                        output_dict)

                # Normalise outputs to lists so they can be indexed later.
                if isinstance(output, tuple):
                    output = list(output)
                if not isinstance(output, list):
                    output = [output]
                output_dict[action["name"]] = output

        return output_dict
Exemple #7
0
    def forward(self, X_dict, task_names):
        """Forward based on input and task
            Note: We assume that all shared modules from all tasks are based on
            the same input.

        :param X_dict: The input data
        :type X_dict: dict of tensor
        :param task_names: The task names that needs to forward
        :type task_names: list of str
        :return: The output of all forwarded modules, keyed by action name, with
            the device-moved input under ``_input_``; every module output is
            stored as a list (tuples converted, other values wrapped)
        :rtype: dict
        :raises ValueError: If an action's inputs cannot be collected from the
            outputs computed so far
        """

        X_dict = move_to_device(X_dict, Meta.config["model_config"]["device"])

        immediate_output_dict = dict()
        immediate_output_dict["_input_"] = X_dict

        # Call forward for each task
        for task_name in task_names:
            task_flow = self.task_flows[task_name]

            for action in task_flow:
                # An action whose name already has an output (e.g. one shared
                # with an earlier task) is not run again.
                if action["name"] in immediate_output_dict:
                    continue

                if action["inputs"]:
                    try:
                        action_inputs = [
                            immediate_output_dict[action_name][output_index]
                            for action_name, output_index in action["inputs"]
                        ]
                    except Exception as e:
                        # Chain the cause so the underlying error stays visible.
                        raise ValueError(f"Unrecognized action {action}.") from e
                    output = self.module_pool[action["module"]].forward(*action_inputs)
                else:
                    output = self.module_pool[action["module"]].forward(
                        immediate_output_dict
                    )

                # Normalise outputs to lists so they can be indexed later.
                if isinstance(output, tuple):
                    output = list(output)
                if not isinstance(output, list):
                    output = [output]
                immediate_output_dict[action["name"]] = output

        return immediate_output_dict
Exemple #8
0
    def forward(  # type: ignore
        self,
        uids: List[str],
        X_dict: Dict[str, Any],
        Y_dict: Dict[str, Tensor],
        task_to_label_dict: Dict[str, str],
        return_probs=True,
        return_action_outputs=False,
    ) -> Union[
        Tuple[
            Dict[str, List[str]],
            Dict[str, Tensor],
            Dict[str, Union[ndarray, List[ndarray]]],
            Dict[str, Union[ndarray, List[ndarray]]],
            Dict[str, Dict[str, Union[ndarray, List]]],
        ],
        Tuple[
            Dict[str, List[str]],
            Dict[str, Tensor],
            Dict[str, Union[ndarray, List[ndarray]]],
            Dict[str, Union[ndarray, List[ndarray]]],
        ],
    ]:
        """Forward function.

        Args:
          uids: The uids of input data.
          X_dict: The input data.
          Y_dict: The output data. When None, only probs (and action outputs)
            are computed, and loss and gold are None for every task.
          task_to_label_dict: The task to label mapping.
          return_probs: Whether return probs or not, defaults to True.
          return_action_outputs: Whether return action_outputs or not,
          defaults to False.

        Returns:
          The (active) uids, loss, prob, gold, action_output (optional) in the batch of
          all tasks. With labels, only samples whose label differs from the
          configured ignore_index are kept, and tasks with no such sample are
          omitted.
        """
        uid_dict: Dict[str, List[str]] = defaultdict(list)
        loss_dict: Dict[str, Tensor] = defaultdict(Tensor)
        gold_dict: Dict[str, Union[ndarray, List[ndarray]]] = defaultdict(list)
        prob_dict: Dict[str, Union[ndarray, List[ndarray]]] = defaultdict(list)
        out_dict: Dict[str, Dict[str, Union[ndarray, List]]] = defaultdict(
            lambda: defaultdict(list)
        )

        task_names = (
            list(task_to_label_dict.keys())
            if isinstance(task_to_label_dict, dict)
            else list(task_to_label_dict)
        )

        output_dict = self.flow(X_dict, task_names)

        if Y_dict is not None:
            device = Meta.config["model_config"]["device"]
            ignore_index = Meta.config["learner_config"]["ignore_index"]

            # Calculate logit and loss for each task
            for task_name, label_name in task_to_label_dict.items():
                Y = Y_dict[label_name]

                # Select the active samples
                if ignore_index is not None:
                    if len(Y.size()) == 1:
                        active = Y.detach() != ignore_index
                    else:
                        active = torch.any(Y.detach() != ignore_index, dim=1)
                else:
                    active = torch.BoolTensor([True] * Y.size()[0])  # type: ignore

                # Only calculate the loss when active example exists
                if not active.any():
                    continue

                # Move the mask to the model device once and reuse it for the
                # loss, the probs and the action outputs.
                active_on_device = move_to_device(active, device)

                # ``.cpu()`` keeps the numpy conversion valid if labels are on GPU.
                uid_dict[task_name] = [
                    *itertools.compress(uids, active.cpu().numpy())
                ]

                loss_dict[task_name] = self.loss_funcs[task_name](
                    output_dict,
                    move_to_device(Y, device),
                    active_on_device,
                )

                if return_probs:
                    prob_dict[task_name] = (
                        self.output_funcs[task_name](output_dict)[active_on_device]
                        .cpu()
                        .detach()
                        .numpy()
                    )
                else:
                    prob_dict[task_name] = None

                gold_dict[task_name] = Y[active].cpu().numpy()

                if (
                    return_action_outputs
                    and self.action_outputs[task_name] is not None
                ):
                    for action_name, output_index in self.action_outputs[task_name]:
                        out_dict[task_name][f"{action_name}_{output_index}"] = (
                            output_dict[action_name][output_index][active_on_device]
                            .cpu()
                            .detach()
                            .numpy()
                        )
        else:
            # Calculate logit for each task
            for task_name in task_to_label_dict:
                uid_dict[task_name] = uids
                if return_probs:
                    prob_dict[task_name] = (
                        self.output_funcs[task_name](output_dict).cpu().detach().numpy()
                    )
                else:
                    prob_dict[task_name] = None

                if return_action_outputs and self.action_outputs[task_name] is not None:
                    for action_name, output_index in self.action_outputs[task_name]:
                        out_dict[task_name][f"{action_name}_{output_index}"] = (
                            output_dict[action_name][output_index]
                            .cpu()
                            .detach()
                            .numpy()
                        )
                loss_dict[task_name] = None
                gold_dict[task_name] = None

        if return_action_outputs:
            return uid_dict, loss_dict, prob_dict, gold_dict, out_dict
        else:
            return uid_dict, loss_dict, prob_dict, gold_dict
Exemple #9
0
def build_slice_tasks(
    task: EmmentalTask,
    slice_func_dict: Dict[str, Callable],
    slice_scorer: Optional[Scorer] = None,
    slice_distribution: Optional[Dict[str, Tensor]] = None,
    dropout: float = 0.0,
    slice_ind_head_module: Optional[nn.Module] = None,
    sep_slice_ind_feature: bool = False,
) -> List[EmmentalTask]:
    """Build slice tasks based on slicing functions.

      We assume the original task flow contains feature extractor and predictor head.
      - The predictor head action should be the last action
      - The feature extractor action should be input of the predictor head action

      For each slicing this function will create two corresponding tasks
      - A slice indicator task to learn whether the data sample is in the slice or not.
      - A slice predictor task that is only learned on the data samples in that slice

      All slice tasks are based on feature extractor module and a slice attention
      module will combine all slice task head to make the final predictions.

    Args:
      task: Task to do slicing learning. Its predictor head module is removed
        from ``task.module_pool`` in place and reused as the master task head.
      slice_func_dict: Slicing functions; only the keys (slice names) are used.
      slice_scorer: Slice scorer for the indicator tasks. Anything that is not a
        Scorer is replaced by ``Scorer(metrics=["f1"])``. Defaults to None.
      slice_distribution: Class weights keyed by slice task name, passed as
        ``weight`` to ``utils.ce_loss``. Defaults to None (no weighting).
      dropout: Dropout applied before each indicator head, defaults to 0.0.
      slice_ind_head_module: Slice indicator head module, deep-copied for each
        slice. Defaults to ``nn.Linear(feature_size, 2)``.
      sep_slice_ind_feature: Whether indicator heads read output index 1 of the
        feature extractor instead of index 0, defaults to False.

    Returns:
      List of tasks: all slice indicator tasks, then all slice predictor tasks,
      then the master task (which keeps the original task's name).
    """
    # A None default avoids sharing one mutable dict between calls.
    if slice_distribution is None:
        slice_distribution = {}

    # Collect task predictor module info
    base_task_predictor_action = task.task_flow[-1]
    base_task_predictor_module = task.module_pool[base_task_predictor_action.module]
    if isinstance(base_task_predictor_module, nn.DataParallel):
        base_task_predictor_module = base_task_predictor_module.module

    task_feature_size = base_task_predictor_module.in_features
    task_cardinality = base_task_predictor_module.out_features

    # Remove the predictor head module and action.
    # NOTE: base_task_module_pool aliases task.module_pool, so the head is also
    # removed from the caller's task.
    base_task_module_pool = task.module_pool
    del base_task_module_pool[base_task_predictor_action.module]  # type: ignore

    base_task_task_flow = task.task_flow[:-1]

    tasks = []
    slice_module_pool = nn.ModuleDict()
    for module_name, module in task.module_pool.items():
        slice_module_pool[module_name] = module
    slice_actions = list(base_task_task_flow)

    if slice_ind_head_module is None:
        slice_ind_head_module = nn.Linear(task_feature_size, 2)

    assert isinstance(slice_ind_head_module, nn.Module)

    if slice_scorer is None or not isinstance(slice_scorer, Scorer):
        slice_scorer = Scorer(metrics=["f1"])

    # Create slice indicator tasks.
    # (Note: indicator only has two classes, e.g, in the slice or out)
    for slice_name in slice_func_dict.keys():
        # Create task name
        ind_task_name = f"{task.name}_slice:ind_{slice_name}"

        # Create ind module
        ind_head_module_name = f"{ind_task_name}_head"
        ind_head_module = copy.deepcopy(slice_ind_head_module)

        ind_head_dropout_module_name = f"{task.name}_slice:dropout_{slice_name}"
        ind_head_dropout_module = nn.Dropout(p=dropout)

        # Create module_pool
        ind_module_pool = nn.ModuleDict(dict(base_task_module_pool.items()))
        ind_module_pool[ind_head_dropout_module_name] = ind_head_dropout_module
        ind_module_pool[ind_head_module_name] = ind_head_module

        assert len(base_task_predictor_action.inputs) == 1

        ind_head_dropout_module_input_name = base_task_predictor_action.inputs[0][0]
        ind_head_dropout_module_input_idx = 1 if sep_slice_ind_feature else 0

        # Create task_flow: base flow + dropout -> indicator head
        ind_task_flow = list(base_task_task_flow)
        ind_task_flow.extend(
            _build_slice_head_actions(
                ind_head_dropout_module_name,
                [(ind_head_dropout_module_input_name, ind_head_dropout_module_input_idx)],
                ind_head_module_name,
                ind_head_module_name,
            )
        )

        # Add slice specific module to slice_module_pool.
        # NOTE(review): the dropout module is not added here although the master
        # flow references it through slice_actions; presumably the model-level
        # module pool shared with the indicator tasks supplies it — confirm.
        slice_module_pool[ind_head_module_name] = ind_head_module
        slice_actions.extend(
            _build_slice_head_actions(
                ind_head_dropout_module_name,
                [(ind_head_dropout_module_input_name, ind_head_dropout_module_input_idx)],
                ind_head_module_name,
                ind_head_module_name,
            )
        )

        # Loss function
        loss = _build_slice_loss(ind_head_module_name, ind_task_name, slice_distribution)

        tasks.append(
            EmmentalTask(
                name=ind_task_name,
                module_pool=ind_module_pool,
                task_flow=ind_task_flow,
                loss_func=loss,
                output_func=partial(utils.output, ind_head_module_name),
                scorer=slice_scorer,
            )
        )

    # Create slice predictor tasks

    # Create share predictor for all slice predictor
    shared_pred_head_module_name = f"{task.name}_slice:shared_pred"
    shared_pred_head_module = nn.Linear(task_feature_size, task_cardinality)

    # Add slice specific module to slice_module_pool
    slice_module_pool[shared_pred_head_module_name] = shared_pred_head_module

    for slice_name in slice_func_dict.keys():
        # Create task name
        pred_task_name = f"{task.name}_slice:pred_{slice_name}"

        # Create pred module
        pred_head_module_name = f"{pred_task_name}_head"
        pred_transform_module_name = f"{task.name}_slice:transform_{slice_name}"
        pred_transform_module = nn.Linear(task_feature_size, task_feature_size)

        # Create module_pool
        pred_module_pool = nn.ModuleDict(dict(base_task_module_pool.items()))
        pred_module_pool[pred_transform_module_name] = pred_transform_module
        pred_module_pool[shared_pred_head_module_name] = shared_pred_head_module

        # Create task_flow: base flow + transform -> shared predictor head
        pred_task_flow = list(base_task_task_flow)
        pred_task_flow.extend(
            _build_slice_head_actions(
                pred_transform_module_name,
                base_task_predictor_action.inputs,
                pred_head_module_name,
                shared_pred_head_module_name,
            )
        )

        # Add slice specific module to slice_module_pool
        slice_module_pool[pred_transform_module_name] = pred_transform_module
        slice_actions.extend(
            _build_slice_head_actions(
                pred_transform_module_name,
                base_task_predictor_action.inputs,
                pred_head_module_name,
                shared_pred_head_module_name,
            )
        )

        # Loss function
        loss = _build_slice_loss(
            pred_head_module_name, pred_task_name, slice_distribution
        )

        tasks.append(
            EmmentalTask(
                name=pred_task_name,
                module_pool=pred_module_pool,
                task_flow=pred_task_flow,
                loss_func=loss,
                output_func=partial(utils.output, pred_head_module_name),
                scorer=task.scorer,
            )
        )

    # Create master task

    # Create task name
    master_task_name = task.name

    # Create attention module
    master_attention_module_name = f"{master_task_name}_attention"
    master_attention_module = SliceAttentionModule(
        slice_ind_key="_slice:ind_",
        slice_pred_key="_slice:pred_",
        slice_pred_feat_key="_slice:transform_",
    )

    # Create module pool; the original predictor head becomes the master head.
    master_head_module_name = f"{master_task_name}_head"
    master_head_module = base_task_predictor_module

    master_module_pool = slice_module_pool
    master_module_pool[master_attention_module_name] = master_attention_module
    master_module_pool[master_head_module_name] = master_head_module

    # Create task_flow
    master_task_flow = slice_actions + [
        Action(
            name=master_attention_module_name,
            module=master_attention_module_name,
            inputs=[],  # type: ignore
        ),
        Action(
            name=master_head_module_name,
            module=master_head_module_name,
            inputs=[(master_attention_module_name, 0)],
        ),
    ]

    tasks.append(
        EmmentalTask(
            name=master_task_name,
            module_pool=master_module_pool,
            task_flow=master_task_flow,
            loss_func=partial(utils.ce_loss, master_head_module_name),
            output_func=partial(utils.output, master_head_module_name),
            scorer=task.scorer,
        )
    )

    return tasks


def _build_slice_head_actions(
    feature_name: str, feature_inputs: List, head_name: str, head_module_name: str
) -> List[Action]:
    """Return a new [feature action, head action] pair (head reads feature output 0)."""
    return [
        Action(name=feature_name, module=feature_name, inputs=feature_inputs),
        Action(name=head_name, module=head_module_name, inputs=[(feature_name, 0)]),
    ]


def _build_slice_loss(
    head_module_name: str, task_name: str, slice_distribution: Dict[str, Tensor]
) -> Callable:
    """Return ``utils.ce_loss`` bound to a head, weighted if ``task_name`` has weights."""
    if task_name in slice_distribution:
        return partial(
            utils.ce_loss,
            head_module_name,
            weight=move_to_device(
                slice_distribution[task_name],
                Meta.config["model_config"]["device"],
            ),
        )
    return partial(utils.ce_loss, head_module_name)
Exemple #10
0
    def forward(
        self,
        uids: List[str],
        X_dict: Dict[str, Any],
        Y_dict: Dict[str, Tensor],
        task_to_label_dict: Dict[str, str],
    ) -> Tuple[Dict[str, List[str]], Dict[str, ndarray], Dict[str, ndarray],
               Dict[str, ndarray]]:
        r"""Forward function.

        Args:
          uids(list): The uids of input data.
          X_dict(dict): The input data.
          Y_dict(dict): The output data.
          task_to_label_dict(dict): The task to label mapping.

        Returns:
          tuple: The (active) uids, loss, prob and gold labels in the batch of
          all tasks; tasks without any active sample are omitted.

        """

        uid_dict: Dict[str, List[str]] = defaultdict(list)
        loss_dict: Dict[str, ndarray] = defaultdict(float)
        gold_dict: Dict[str, ndarray] = defaultdict(list)
        prob_dict: Dict[str, ndarray] = defaultdict(list)

        output_dict = self.flow(X_dict, list(task_to_label_dict.keys()))

        # Calculate loss for each task
        for task_name, label_name in task_to_label_dict.items():
            Y = Y_dict[label_name]

            # Select the active samples
            if Meta.config["learner_config"]["ignore_index"] is not None:
                if len(Y.size()) == 1:
                    active = Y.detach(
                    ) != Meta.config["learner_config"]["ignore_index"]
                else:
                    active = torch.any(
                        Y.detach() !=
                        Meta.config["learner_config"]["ignore_index"],
                        dim=1,
                    )
            else:
                # Bool mask: indexing with uint8 (ByteTensor) masks is deprecated.
                active = torch.ones(Y.size()[0], dtype=torch.bool)

            # Only calculate the loss when active example exists
            if not active.any():
                continue

            # Move the mask once; it is reused for the loss and the probs.
            active_on_device = move_to_device(
                active, Meta.config["model_config"]["device"])

            # ``.cpu()`` keeps the numpy conversion valid if labels are on GPU.
            uid_dict[task_name] = [
                *itertools.compress(uids, active.cpu().numpy())
            ]

            loss_dict[task_name] = self.loss_funcs[task_name](
                output_dict,
                move_to_device(Y, Meta.config["model_config"]["device"]),
                active_on_device,
            )

            prob_dict[task_name] = (
                self.output_funcs[task_name](output_dict)[active_on_device]
                .cpu().detach().numpy())

            gold_dict[task_name] = Y[active].cpu().numpy()

        return uid_dict, loss_dict, prob_dict, gold_dict
Exemple #11
0
    def forward(  # type: ignore
        self,
        uids: List[str],
        X_dict: Dict[str, Any],
        Y_dict: Dict[str, Tensor],
        task_to_label_dict: Dict[str, str],
        return_loss=True,
        return_probs=True,
        return_action_outputs=False,
    ) -> Union[
        Tuple[
            Dict[str, List[str]],
            Dict[str, Tensor],
            Dict[str, Union[ndarray, List[ndarray]]],
            Dict[str, Union[ndarray, List[ndarray]]],
            Dict[str, Dict[str, Union[ndarray, List]]],
        ],
        Tuple[
            Dict[str, List[str]],
            Dict[str, Tensor],
            Dict[str, Union[ndarray, List[ndarray]]],
            Dict[str, Union[ndarray, List[ndarray]]],
        ],
    ]:
        """Run the task flows on a batch and collect per-task results.

        Args:
          uids: The uids of input data.
          X_dict: The input data.
          Y_dict: The output data; may be None when no task needs labels.
          task_to_label_dict: The task to label mapping.
          return_loss: Whether return loss or not, defaults to True.
          return_probs: Whether return probs or not, defaults to True.
          return_action_outputs: Whether return action_outputs or not,
          defaults to False.

        Returns:
          The uids, loss, prob, gold, action_output (optional) in the batch of
          all tasks. The loss / prob / action_output dicts are None when their
          ``return_*`` flag is False, and gold is None when Y_dict is None.
        """
        has_labels = Y_dict is not None

        uid_dict: Dict[str, List[str]] = defaultdict(list)
        loss_dict: Dict[str, Tensor] = None
        gold_dict: Dict[str, Union[ndarray, List[ndarray]]] = None
        prob_dict: Dict[str, Union[ndarray, List[ndarray]]] = None
        out_dict: Dict[str, Dict[str, Union[ndarray, List]]] = None
        if return_loss:
            loss_dict = defaultdict(Tensor)
        if has_labels:
            gold_dict = defaultdict(list)
        if return_probs:
            prob_dict = defaultdict(list)
        if return_action_outputs:
            out_dict = defaultdict(lambda: defaultdict(list))

        output_dict = self.flow(X_dict, list(task_to_label_dict.keys()))

        for task_name, label_name in task_to_label_dict.items():
            # A task may only lack labels if it declares no label name.
            assert has_labels or label_name is None, (
                f"Task {task_name} has not {label_name} label."
            )

            uid_dict[task_name] = uids
            task_has_gold = has_labels and label_name is not None

            if return_loss and task_name in self.loss_funcs:
                loss_func = self.loss_funcs[task_name]
                if loss_func is not None:
                    if task_has_gold:
                        gold_on_device = move_to_device(
                            Y_dict[label_name],
                            Meta.config["model_config"]["device"],
                        )
                    else:
                        gold_on_device = None
                    loss_dict[task_name] = loss_func(output_dict, gold_on_device)

            if return_probs and task_name in self.output_funcs:
                output_func = self.output_funcs[task_name]
                if output_func is not None:
                    probs = output_func(output_dict)
                    prob_dict[task_name] = probs.cpu().detach().numpy()

            if task_has_gold:
                gold_dict[task_name] = Y_dict[label_name].cpu().numpy()

            if return_action_outputs and task_name in self.action_outputs:
                requested_outputs = self.action_outputs[task_name]
                if requested_outputs is not None:
                    for requested in requested_outputs:
                        # String refs are used verbatim; (name, index) pairs
                        # become "name_index".
                        if isinstance(requested, str):
                            out_key = requested
                        else:
                            out_key = f"{requested[0]}_{requested[1]}"
                        data = self._get_data_from_output_dict(output_dict, requested)
                        out_dict[task_name][out_key] = data.cpu().detach().numpy()

        results = (uid_dict, loss_dict, prob_dict, gold_dict)
        if return_action_outputs:
            return results + (out_dict,)
        return results