예제 #1
0
def merge_pos_into_con(doc: Document):
    """Copy part-of-speech tags onto placeholder labels of the constituency trees.

    For every tree in ``doc['con']``, the subtrees whose ``height()`` equals 2
    are visited in iteration order. A visited subtree labelled ``'_'`` gets its
    label replaced by the tag at the same position in the matching entry of
    ``doc['pos']``. Trees are modified in place through ``set_label``.

    Args:
        doc: A document with ``'con'`` and ``'pos'`` fields. When
            ``doc['pos'][0]`` is a ``str`` the document is treated as a single
            sentence: each field is wrapped in a one-element list for
            processing and the result is passed through ``squeeze()``.

    Returns:
        The document after merging. In the single-sentence case this is the
        squeezed copy rather than the object passed in.
    """
    is_single_sentence = isinstance(doc['pos'][0], str)
    if is_single_sentence:
        # Lift every field to a one-element batch so the loop below is uniform.
        doc = Document((key, [value]) for key, value in doc.items())
    for con_tree, pos_tags in zip(doc['con'], doc['pos']):
        height_two_nodes = con_tree.subtrees(lambda node: node.height() == 2)
        # The position counts every height-2 subtree, not only the placeholders.
        for position, node in enumerate(height_two_nodes):
            if node.label() == '_':
                node.set_label(pos_tags[position])
    return doc.squeeze() if is_single_sentence else doc
예제 #2
0
    def predict(self,
                data: Union[str, List[str]],
                batch_size: int = None,
                tasks: Optional[Union[str, List[str]]] = None,
                skip_tasks: Optional[Union[str, List[str]]] = None,
                resolved_tasks=None,
                **kwargs) -> Document:
        """Predict on data by running the selected tasks batch by batch.

        The first task of the first task group builds the samples and the
        dataloader. For each batch that task is run first, then every other
        selected task is run on the same batch, each receiving the
        ``output_dict`` returned by the previous call.

        Args:
            data: A sentence or a list of sentences.
            batch_size: Decoding batch size. NOTE(review): this argument is not
                referenced anywhere in this method body - confirm whether it
                should be forwarded to ``build_dataloader``.
            tasks: The tasks to predict.
            skip_tasks: The tasks to skip.
            resolved_tasks: The resolved tasks to override ``tasks`` and ``skip_tasks``.
                Only used when truthy; otherwise the tasks are resolved from
                ``tasks`` and ``skip_tasks``.
            **kwargs: Not used.

        Returns:
            A :class:`~hanlp_common.document.Document`. It is empty when
            ``data`` is falsy. When the first task reports the input as flat,
            every field is replaced by its first element before returning.
        """
        doc = Document()
        # Nothing to predict on: return the empty document right away.
        if not data:
            return doc

        # A falsy ``resolved_tasks`` falls back to resolving from tasks/skip_tasks.
        target_tasks = resolved_tasks or self.resolve_tasks(tasks, skip_tasks)
        # Task objects of every selected task, across all groups.
        flatten_target_tasks = [
            self.tasks[t] for group in target_tasks for t in group
        ]
        # Each flag is switched on as soon as one selected task sets it.
        cls_is_bos = any([x.cls_is_bos for x in flatten_target_tasks])
        sep_is_eos = any([x.sep_is_eos for x in flatten_target_tasks])
        # Now build the dataloaders and execute tasks
        first_task_name: str = list(target_tasks[0])[0]
        first_task: Task = self.tasks[first_task_name]
        encoder_transform, transform = self.build_transform(first_task)
        # Override the tokenizer config of the 1st task
        encoder_transform.sep_is_eos = sep_is_eos
        encoder_transform.cls_is_bos = cls_is_bos
        average_subwords = self.model.encoder.average_subwords
        flat = first_task.input_is_flat(data)
        if flat:
            # Wrap the single input in a list; the wrapping is undone on the
            # document fields just before returning.
            data = [data]
        device = self.device
        samples = first_task.build_samples(data,
                                           cls_is_bos=cls_is_bos,
                                           sep_is_eos=sep_is_eos)
        dataloader = first_task.build_dataloader(samples,
                                                 transform=transform,
                                                 device=device)
        # Keyed by task name; handed to ``predict_task``, which presumably
        # appends the predictions of each batch - verify against predict_task.
        results = defaultdict(list)
        # Sample indices in the order the dataloader yields them.
        order = []
        for batch in dataloader:
            order.extend(batch[IDX])
            # Run the first task, let it make the initial batch for the successors
            output_dict = self.predict_task(first_task,
                                            first_task_name,
                                            batch,
                                            results,
                                            run_transform=True,
                                            cls_is_bos=cls_is_bos,
                                            sep_is_eos=sep_is_eos)
            # Run each task group in order
            for group_id, group in enumerate(target_tasks):
                # We could parallelize this in the future
                for task_name in group:
                    # The first task has already been run on this batch above.
                    if task_name == first_task_name:
                        continue
                    output_dict = self.predict_task(self.tasks[task_name],
                                                    task_name,
                                                    batch,
                                                    results,
                                                    output_dict,
                                                    run_transform=True,
                                                    cls_is_bos=cls_is_bos,
                                                    sep_is_eos=sep_is_eos)
                # This conversion happens once per batch, after the whole
                # first group has run and before any later group starts.
                if group_id == 0:
                    # We are kind of hard coding here. If the first task is a tokenizer,
                    # we need to convert the hidden and mask to token level
                    if first_task_name.startswith('tok'):
                        spans = []
                        tokens = []
                        # Pair the predicted spans with the last len(batch[IDX])
                        # entries of the tokenizer results - presumably the
                        # tokens of the current batch; verify against predict_task.
                        for span_per_sent, token_per_sent in zip(
                                output_dict[first_task_name]['prediction'],
                                results[first_task_name][-len(batch[IDX]):]):
                            if cls_is_bos:
                                # (-1, 0) becomes position 0 after the +1 shift below.
                                span_per_sent = [(-1, 0)] + span_per_sent
                                token_per_sent = [BOS] + token_per_sent
                            if sep_is_eos:
                                # The EOS span is the last span moved right by one.
                                span_per_sent = span_per_sent + [
                                    (span_per_sent[-1][0] + 1,
                                     span_per_sent[-1][1] + 1)
                                ]
                                token_per_sent = token_per_sent + [EOS]
                            # The offsets start with 0 while [CLS] is zero
                            if average_subwords:
                                # Keep every shifted position covered by the span.
                                span_per_sent = [
                                    list(range(x[0] + 1, x[1] + 1))
                                    for x in span_per_sent
                                ]
                            else:
                                # Keep only the shifted start position of each span.
                                span_per_sent = [
                                    x[0] + 1 for x in span_per_sent
                                ]
                            spans.append(span_per_sent)
                            tokens.append(token_per_sent)
                        # Presumably pads the ragged spans with 0 into a long
                        # tensor on ``device`` - TODO confirm pad_data semantics.
                        spans = PadSequenceDataLoader.pad_data(spans,
                                                               0,
                                                               torch.long,
                                                               device=device)
                        output_dict['hidden'] = pick_tensor_for_each_token(
                            output_dict['hidden'], spans, average_subwords)
                        # Later tasks read the token-level fields from the batch.
                        batch['token_token_span'] = spans
                        batch['token'] = tokens
                        # noinspection PyTypeChecker
                        batch['token_length'] = torch.tensor(
                            [len(x) for x in tokens],
                            dtype=torch.long,
                            device=device)
                        # Drop 'mask' if present - presumably it no longer
                        # matches the token-level hidden states; confirm.
                        batch.pop('mask', None)
        # Put results into doc in the order of tasks
        for k in self.config.task_names:
            v = results.get(k, None)
            if v is None:
                continue
            # ``reorder`` presumably restores the original input order using
            # the collected sample indices - verify against its definition.
            doc[k] = reorder(v, order)
        # Allow task to perform finalization on document
        for group in target_tasks:
            for task_name in group:
                task = self.tasks[task_name]
                task.finalize_document(doc, task_name)
        # If no tok in doc, use raw input as tok
        if not any(k.startswith('tok') for k in doc):
            doc['tok'] = data
        if flat:
            # Undo the wrapping done for a flat input: keep the first element
            # of every field.
            for k, v in list(doc.items()):
                doc[k] = v[0]
        # If there is only one field, don't bother to wrap it
        # if len(doc) == 1:
        #     return list(doc.values())[0]
        return doc