# Imports required by the excerpts below (module paths assumed to match
# the HanLP code base):
from collections import defaultdict
from typing import List, Optional, Union

import torch

from hanlp_common.constant import BOS, EOS, IDX
from hanlp_common.document import Document
from hanlp_common.util import reorder
from hanlp.common.dataset import PadSequenceDataLoader
from hanlp.components.mtl.tasks import Task
from hanlp.utils.torch_util import pick_tensor_for_each_token


def merge_pos_into_con(doc: Document):
    """Merge POS tags into the pre-terminals of constituency trees whose
    labels are the placeholder ``_``."""
    flat = isinstance(doc['pos'][0], str)
    if flat:
        # A single sentence: wrap every field in a list so the loop below
        # can treat flat and batched inputs uniformly
        doc = Document((k, [v]) for k, v in doc.items())
    for tree, tags in zip(doc['con'], doc['pos']):
        offset = 0
        # Pre-terminals (subtrees of height 2) correspond one-to-one to tokens
        for subtree in tree.subtrees(lambda t: t.height() == 2):
            tag = subtree.label()
            if tag == '_':
                subtree.set_label(tags[offset])
            offset += 1
    if flat:
        # Unwrap the singleton lists created above
        doc = doc.squeeze()
    return doc
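# A minimal, self-contained sketch of the merging step above. It uses
# nltk.tree.Tree as a stand-in for HanLP's constituency tree class
# (assumption: both expose the same NLTK-style subtrees()/height()/label()/
# set_label() API), and _demo_merge_pos_into_con is a hypothetical name used
# only for this sketch. Pre-terminals are visited left to right, so ``tags``
# lines up with the tokens.
def _demo_merge_pos_into_con():
    from nltk.tree import Tree

    tree = Tree.fromstring('(TOP (NP (_ I)) (VP (_ love) (NP (_ HanLP))))')
    tags = ['PRP', 'VBP', 'NNP']
    offset = 0
    for subtree in tree.subtrees(lambda t: t.height() == 2):
        if subtree.label() == '_':
            subtree.set_label(tags[offset])
        offset += 1
    print(tree)  # (TOP (NP (PRP I)) (VP (VBP love) (NP (NNP HanLP))))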
def predict(self,
            data: Union[str, List[str]],
            batch_size: int = None,
            tasks: Optional[Union[str, List[str]]] = None,
            skip_tasks: Optional[Union[str, List[str]]] = None,
            resolved_tasks=None,
            **kwargs) -> Document:
    """Predict on data.

    Args:
        data: A sentence or a list of sentences.
        batch_size: Decoding batch size.
        tasks: The tasks to predict.
        skip_tasks: The tasks to skip.
        resolved_tasks: The resolved tasks to override ``tasks`` and ``skip_tasks``.
        **kwargs: Not used.

    Returns:
        A :class:`~hanlp_common.document.Document`.
    """
    doc = Document()
    if not data:
        return doc

    target_tasks = resolved_tasks or self.resolve_tasks(tasks, skip_tasks)
    flatten_target_tasks = [self.tasks[t] for group in target_tasks for t in group]
    cls_is_bos = any(x.cls_is_bos for x in flatten_target_tasks)
    sep_is_eos = any(x.sep_is_eos for x in flatten_target_tasks)
    # Now build the dataloaders and execute tasks
    first_task_name: str = list(target_tasks[0])[0]
    first_task: Task = self.tasks[first_task_name]
    encoder_transform, transform = self.build_transform(first_task)
    # Override the tokenizer config of the first task
    encoder_transform.sep_is_eos = sep_is_eos
    encoder_transform.cls_is_bos = cls_is_bos
    average_subwords = self.model.encoder.average_subwords
    flat = first_task.input_is_flat(data)
    if flat:
        data = [data]
    device = self.device
    samples = first_task.build_samples(data, cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
    dataloader = first_task.build_dataloader(samples, transform=transform, device=device)
    results = defaultdict(list)
    order = []
    for batch in dataloader:
        order.extend(batch[IDX])
        # Run the first task, letting it build the initial batch for its successors
        output_dict = self.predict_task(first_task, first_task_name, batch, results,
                                        run_transform=True, cls_is_bos=cls_is_bos,
                                        sep_is_eos=sep_is_eos)
        # Run each task group in order
        for group_id, group in enumerate(target_tasks):
            # We could parallelize this in the future
            for task_name in group:
                if task_name == first_task_name:
                    continue
                output_dict = self.predict_task(self.tasks[task_name], task_name, batch,
                                                results, output_dict, run_transform=True,
                                                cls_is_bos=cls_is_bos, sep_is_eos=sep_is_eos)
            if group_id == 0:
                # Somewhat hard-coded: if the first task is a tokenizer, convert the
                # hidden states and mask from subword level to token level
                if first_task_name.startswith('tok'):
                    spans = []
                    tokens = []
                    for span_per_sent, token_per_sent in zip(
                            output_dict[first_task_name]['prediction'],
                            results[first_task_name][-len(batch[IDX]):]):
                        if cls_is_bos:
                            span_per_sent = [(-1, 0)] + span_per_sent
                            token_per_sent = [BOS] + token_per_sent
                        if sep_is_eos:
                            span_per_sent = span_per_sent + [(span_per_sent[-1][0] + 1,
                                                              span_per_sent[-1][1] + 1)]
                            token_per_sent = token_per_sent + [EOS]
                        # Token offsets start at 0 while subword position 0 is [CLS],
                        # so shift every subword index by one
                        if average_subwords:
                            span_per_sent = [list(range(x[0] + 1, x[1] + 1)) for x in span_per_sent]
                        else:
                            span_per_sent = [x[0] + 1 for x in span_per_sent]
                        spans.append(span_per_sent)
                        tokens.append(token_per_sent)
                    spans = PadSequenceDataLoader.pad_data(spans, 0, torch.long, device=device)
                    output_dict['hidden'] = pick_tensor_for_each_token(output_dict['hidden'], spans,
                                                                       average_subwords)
                    batch['token_token_span'] = spans
                    batch['token'] = tokens
                    # noinspection PyTypeChecker
                    batch['token_length'] = torch.tensor([len(x) for x in tokens],
                                                         dtype=torch.long, device=device)
                    batch.pop('mask', None)

    # Put results into doc in the order of tasks
    for k in self.config.task_names:
        v = results.get(k, None)
        if v is None:
            continue
        doc[k] = reorder(v, order)

    # Allow each task to perform finalization on the document
    for group in target_tasks:
        for task_name in group:
            task = self.tasks[task_name]
            task.finalize_document(doc, task_name)

    # If there is no tok in doc, use the raw input as tok
    if not any(k.startswith('tok') for k in doc):
        doc['tok'] = data

    if flat:
        for k, v in list(doc.items()):
            doc[k] = v[0]

    # If there is only one field, don't bother to wrap it
    # if len(doc) == 1:
    #     return list(doc.values())[0]
    return doc
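# Illustrative sketch of the subword-to-token pooling that
# pick_tensor_for_each_token performs above, reduced to a single sentence
# (assumptions: the real helper operates on padded batches, and _pool_tokens
# is a hypothetical name used only for this sketch). Each token owns either
# one subword index or a list of subword indices; averaging mean-pools the
# list, otherwise a single representative subword row is gathered.
def _pool_tokens(hidden: torch.Tensor, spans, average_subwords: bool) -> torch.Tensor:
    # hidden: [num_subwords, dim] hidden states for one sentence
    if average_subwords:
        # spans: a list of subword indices per token -> mean-pool each list
        return torch.stack([hidden[idx].mean(dim=0) for idx in spans])
    # spans: a single subword index per token -> gather those rows
    return hidden[torch.tensor(spans)]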
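# Usage sketch for predict (assumption: the pretrained identifier below; any
# model under hanlp.pretrained.mtl is loaded the same way):
#
#     import hanlp
#
#     mtl = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH)
#     doc = mtl.predict('商品和服务')                  # a single flat sentence
#     doc = mtl.predict(['商品和服务', '晓美焰来到北京'])  # or a batch of sentences
#     doc = mtl.predict('商品和服务', tasks='tok')      # restrict to selected tasks
#     print(doc)  # a Document mapping task names to per-sentence outputs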