Example #1
  def mixsent_forward(self, *inputs, **kwargs):
    task_name = kwargs.pop("task_name")
    teacher_prediction = kwargs.pop("teacher_predictions", None)
    if teacher_prediction is None:
      # Teacher pass: encode and decode both source sentences, caching detached per-task predictions.
      results = {}
      arguments_a = self.get_arguments(prefix="a_", kwargs=kwargs)
      features_a = self.encoding(**arguments_a)
      arguments_b = self.get_arguments(prefix="b_", kwargs=kwargs)
      features_b = self.encoding(**arguments_b)
      for task in self.tasks:
        if task.name != task_name:
          if isinstance(features_a, list):
            # Some encoders return a list of feature tensors; use the first entry.
            features_a = features_a[0]
            features_b = features_b[0]
          result_a = self.decoding(**arguments_a, **kwargs, features=features_a, task_name=task.name)
          result_b = self.decoding(**arguments_b, **kwargs, features=features_b, task_name=task.name)
          # Detach the teacher outputs so no gradient flows back through the teacher pass.
          results[task.name] = {
            "a": result_a.detach(),
            "b": result_b.detach()
          }
      return results

    else:
      # Student pass: encode the mixed sentence and compute the distillation loss for each task.
      arguments = self.get_arguments(prefix="c_", kwargs=kwargs)
      features = self.encoding(**arguments)
      losses = {}
      for task in self.tasks:
        if task.name != task_name:
          if isinstance(features, list):
            # Same unwrapping of list-valued encoder outputs as in the teacher pass.
            features = features[0]
          result = self.decoding(**arguments, **kwargs, features=features, task_name=task.name)
          logits, target, mask = self.decoding(
            task_name=task_name,
            student_results=result,
            teacher_results=teacher_prediction[task.name],
            **arguments, **kwargs)
          loss = get_loss(
            task=self.task_dict[task_name],
            logits=logits,
            config=self.config,
            target=target,
            mask=mask,
            **arguments, **kwargs)
          losses[task.name] = loss
      # Sum the per-task losses into a single scalar; `reduce` comes from `functools`.
      return {"loss": reduce(lambda x, y: x + y, losses.values())}
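The method above implies a two-pass protocol: the first call, made without `teacher_predictions`, returns the detached per-task teacher outputs, and a second call feeds them back to obtain the distillation loss. A minimal driver sketch of that protocol, where `model` and `batch` (assumed to carry the a_*/b_* inputs of the two source sentences and the c_* inputs of the mixed sentence) are hypothetical names for illustration:

    # Hypothetical usage sketch, not taken from the repository.
    teacher_predictions = model.mixsent_forward(task_name=MIXSENT_TASK, **batch)
    outputs = model.mixsent_forward(task_name=MIXSENT_TASK,
                                    teacher_predictions=teacher_predictions,
                                    **batch)
    outputs["loss"].backward()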
Example #2
    def siamese_forward(self, *inputs, **kwargs):
        task_name = kwargs.pop("task_name")

        # Encode both sides of the pair with the shared (siamese) encoder.
        a_arguments = self.get_arguments(prefix="a_", kwargs=kwargs)
        b_arguments = self.get_arguments(prefix="b_", kwargs=kwargs)

        a_encoding_results = self.encoding(**a_arguments)
        b_encoding_results = self.encoding(**b_arguments)
        a_features = a_encoding_results["features"]
        b_features = b_encoding_results["features"]

        logits = self.decoding(task_name=task_name,
                               a_features=a_features,
                               b_features=b_features,
                               a_encoding_results=a_encoding_results,
                               b_encoding_results=b_encoding_results,
                               **kwargs)

        # Compute a loss only when gold labels are available in the batch.
        label_ids = kwargs.get("label_ids", None)
        _label_ids = kwargs.get("_label_ids", None)

        outputs_dict = {}

        if label_ids is not None or _label_ids is not None:
            if task_name not in SKIP_LOSS_TASK:
                loss = get_loss(task=self.task_dict[task_name],
                                logits=logits,
                                config=self.config,
                                a_arguments=a_arguments,
                                b_arguments=b_arguments,
                                **kwargs)
                outputs_dict = {"loss": loss, "logits": logits}
            else:
                outputs_dict = {"logits": logits}
        else:
            # No labels were provided: inference only, so detach the logits
            # unless the task is configured to skip loss computation anyway.
            if task_name in SKIP_LOSS_TASK:
                outputs_dict = {"logits": logits}
            else:
                outputs_dict = {"logits": logits.detach()}

        return outputs_dict
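All three examples rely on a `get_arguments` helper that is not shown. A plausible minimal sketch, assuming it pops the encoder inputs carrying a given prefix out of kwargs and returns them under their unprefixed names; the ENCODER_ARGUMENT_NAMES tuple and the exact behaviour are assumptions, and the real helper in the repository may differ:

    ENCODER_ARGUMENT_NAMES = ("input_ids", "attention_mask", "token_type_ids")  # assumed

    def get_arguments(self, prefix, kwargs):
        # Hypothetical helper: pop e.g. "a_input_ids" / "a_attention_mask" from
        # kwargs and return them as {"input_ids": ..., "attention_mask": ...},
        # so the shared encoder always sees its usual argument names.
        return {name: kwargs.pop(prefix + name)
                for name in self.ENCODER_ARGUMENT_NAMES
                if prefix + name in kwargs}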
Example #3
    def forward(self, *inputs, **kwargs):
        task_name = kwargs.pop("task_name")

        outputs_dict = {}

        # The parallel teacher-student task builds its own outputs dict:
        # it iterates over every other task, takes that task's output from
        # the decoder, and stores it in `results`, keyed by task_name with
        # the tensor as the value.
        if task_name == PARALLEL_TEACHER_STUDENT_TASK:
            teacher_prediction = kwargs.pop("teacher_predictions", None)
            if teacher_prediction is None:
                # Teacher pass: encode the teacher view and return its detached features for every other task.
                results = {}
                arguments = self.get_arguments(prefix="a_", kwargs=kwargs)
                encoding_results = self.encoding(**arguments)
                for task in self.tasks:
                    if task.name != task_name:
                        result = encoding_results["features"]
                        results[task.name] = result.detach()
                return results

            else:
                # Student pass: encode the student view and compute the
                # distillation loss against the cached teacher predictions.
                arguments = self.get_arguments(prefix="b_", kwargs=kwargs)
                encoding_results = self.encoding(**arguments)

                losses = {}
                for task in self.tasks:
                    if task.name != task_name:
                        result = encoding_results["features"]
                        logits, target, mask = self.decoding(
                            task_name=task_name,
                            student_results=result,
                            teacher_results=teacher_prediction[task.name],
                            **arguments,
                            **kwargs)
                        loss = get_loss(task=self.task_dict[task_name],
                                        logits=logits,
                                        config=self.config,
                                        target=target,
                                        mask=mask,
                                        **arguments,
                                        **kwargs)
                        losses[task.name] = loss
                return reduce(lambda x, y: x + y, losses.values()), None

        elif task_name in SIAMESE:
            outputs_dict[task_name] = self.siamese_forward(*inputs,
                                                           **kwargs,
                                                           task_name=task_name)

        elif task_name == MIXSENT_TASK:
            outputs_dict[task_name] = self.mixsent_forward(*inputs,
                                                           **kwargs,
                                                           task_name=task_name)

        elif task_name in TRIPLET:
            outputs_dict[task_name] = self.triplet_forward(*inputs,
                                                           **kwargs,
                                                           task_name=task_name)

        # elif task_name in [DEP_PARSING_TASK]:
        #   arguments = self.get_arguments(prefix="", kwargs=kwargs)
        #   features = self.encoding(**arguments)
        #   outputs = self.decoding(**arguments, **kwargs, features=features, task_name=task_name)
        #   outputs_dict[task_name] = outputs

        # elif task_name == ENCODING_TASK:
        #   # For the encoding task we only consider singletons for now.
        #   arguments = self.get_arguments(prefix="", kwargs=kwargs)
        #   features = self.encoder(**arguments)
        #   return {"features": features}

        else:
            # Currently only the one-forward-multitask path is supported here, and
            # its interface is not compatible with dependency parsing or other
            # multi-output tasks.
            # TODO: Restructure the inference to fix this.
            arguments = self.get_arguments(prefix="", kwargs=kwargs)
            encoding_results = self.encoding(**arguments)
            features = encoding_results.pop("features")
            kwargs.pop("extra_args", None)
            # `task_name` may carry several comma-separated tasks; decode each of
            # them on top of the shared encoder features.
            task_names = task_name.split(',')
            for task_name in task_names:
                logits = self.decoding(**arguments,
                                       **kwargs,
                                       features=features,
                                       task_name=task_name,
                                       encoding_results=encoding_results)
                label_ids = kwargs.get("label_ids", None)
                _label_ids = kwargs.get("_label_ids", None)
                if label_ids is not None or _label_ids is not None:
                    if task_name not in SKIP_LOSS_TASK:
                        # if task_name in ["joint_srl"]:
                        #   loss = self.task_dict[task_name].compute_loss()
                        # else:
                        loss = get_loss(task=self.task_dict[task_name],
                                        logits=logits,
                                        config=self.config,
                                        **arguments,
                                        **kwargs)
                        outputs_dict[task_name] = {
                            "loss": loss,
                            "logits": logits
                        }
                    else:
                        outputs_dict[task_name] = {"logits": logits}
                else:
                    if task_name in SKIP_LOSS_TASK:
                        outputs_dict[task_name] = {"logits": logits}
                    else:
                        outputs_dict[task_name] = {"logits": logits.detach()}
                if self.config.output_features:
                    # Expose the shared encoder features alongside the logits.
                    outputs_dict[task_name]["features"] = features
            if self.config.output_attentions:
                outputs_dict[task_name]["attention_map"] = encoding_results[
                    "attention_map"]

        return outputs_dict
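For the default multitask branch at the end, a hedged sketch of how a training step might consume the returned dict; `model`, `optimizer`, `batch`, `label_ids`, and the comma-separated task string are assumptions for illustration, and the sketch presumes every listed task produces a "loss" entry:

    # Hypothetical training-step sketch: decode several tasks on one shared
    # encoder pass and aggregate their losses.
    outputs_dict = model(task_name="pos,ner", label_ids=label_ids, **batch)
    total_loss = sum(task_outputs["loss"] for task_outputs in outputs_dict.values())
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()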