Example #1
 def post_outputs(self, predictions, data, order, use_pos, build_data):
     # Restore the original sample order before converting the raw
     # predictions into human-readable outputs.
     predictions = reorder(predictions, order)
     if build_data:
         data = reorder(data, order)
     outputs = []
     self.predictions_to_human(predictions, outputs, data, use_pos)
     return outputs
Example #2
 def predict(self,
             data: Union[List[str], List[List[str]]],
             batch_size: int = None,
             ret_tokens=True,
             **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     dataloader = self.build_dataloader([{
         'token': x
     } for x in data], batch_size, False, self.device)
     predictions = []
     orders = []
     for batch in dataloader:
         output_dict = self.feed_batch(batch)
         token = batch['token']
         prediction = output_dict['prediction']
         self.prediction_to_result(token, prediction, predictions,
                                   ret_tokens)
         orders.extend(batch[IDX])
     predictions = reorder(predictions, orders)
     if flat:
         return predictions[0]
     return predictions
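
Every example in this listing shares the same bookkeeping: the dataloader is free to sort or batch samples out of input order, each batch carries the original sample indices under the IDX key, and reorder scatters the collected predictions back into input order at the end. A minimal sketch of such a helper, assuming IDX holds 0-based input positions (the actual HanLP implementation may differ):

    def reorder(items, order):
        # order[i] is the original input position of items[i]; scatter back.
        restored = [None] * len(items)
        for item, original_index in zip(items, order):
            restored[original_index] = item
        return restored
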
Example #3
 def predict(self,
             data: Union[List[str], List[List[str]]],
             batch_size: int = None,
             **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     samples = self.build_samples(data)
     if not batch_size:
         batch_size = self.config.batch_size
     dataloader = self.build_dataloader(samples,
                                        device=self.devices[0],
                                        shuffle=False,
                                        **merge_dict(self.config,
                                                     batch_size=batch_size,
                                                     overwrite=True,
                                                     **kwargs))
     order = []
     outputs = []
     for batch in dataloader:
         out, mask = self.feed_batch(batch)
         self.decode_output(out, mask, batch)
         outputs.extend(self.prediction_to_human(out, batch))
         order.extend(batch[IDX])
     outputs = reorder(outputs, order)
     if flat:
         return outputs[0]
     return outputs
Example #4
 def predict(self,
             data: Union[str, List[str]],
             batch_size: int = None,
             **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     samples = self.build_samples(data)
     dataloader = self.build_dataloader(samples,
                                        device=self.device,
                                        **merge_dict(self.config,
                                                     batch_size=batch_size,
                                                     overwrite=True))
     outputs = []
     orders = []
     for batch in dataloader:
         out, mask = self.feed_batch(batch)
         prediction = self.decode_output(out, mask, batch, span_probs=None)
         # prediction = [x[0] for x in prediction]
         outputs.extend(prediction)
         orders.extend(batch[IDX])
     outputs = reorder(outputs, orders)
     if flat:
         return outputs[0]
     return outputs
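
Examples #2 through #5 open with the same idiom: detect a single-sentence call, wrap it into a one-element batch, then unwrap the lone result before returning. For a component whose flat input is a raw string, input_is_flat could be as simple as the sketch below; the real check is defined per component and may also accept a pre-tokenized sentence:

    def input_is_flat(data):
        # Assumption: one raw sentence is a str, a batch is a list of str.
        return isinstance(data, str)
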
Example #5
 def predict(self,
             data: Union[str, List[str]],
             batch_size: int = None,
             fmt='dict',
             **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     samples = []
     for token in data:
         sample = dict()
         sample['token'] = token
         samples.append(sample)
     batch_size = batch_size or self.config.batch_size
     dataloader = self.build_dataloader(samples,
                                        batch_size,
                                        False,
                                        self.device,
                                        None,
                                        generate_idx=True)
     outputs = []
     order = []
     for batch in dataloader:
         output_dict = self.feed_batch(batch)
         outputs.extend(output_dict['prediction'])
         order.extend(batch[IDX])
     outputs = reorder(outputs, order)
     if fmt == 'list':
         outputs = self.format_dict_to_results(data, outputs)
     if flat:
         return outputs[0]
     return outputs
Example #6
 def evaluate_dataloader(self,
                         data: DataLoader,
                         logger: logging.Logger,
                         metric=None,
                         output=False,
                         **kwargs):
     self.model.eval()
     timer = CountdownTimer(len(data))
     total_loss = 0
     metric.reset()
     if output:
         predictions = []
         orders = []
         samples = []
     for batch in data:
         output_dict = self.feed_batch(batch)
         prediction = self.decode(output_dict)
         metric(prediction, batch['similarity'])
         if output:
             predictions.extend(prediction.tolist())
             orders.extend(batch[IDX])
             samples.extend(list(zip(batch['sent_a'], batch['sent_b'])))
         loss = output_dict['loss']
         total_loss += loss.item()
         timer.log(self.report_metrics(total_loss / (timer.current + 1),
                                       metric),
                   ratio_percentage=None,
                   logger=logger)
         del loss
     if output:
         predictions = reorder(predictions, orders)
         samples = reorder(samples, orders)
         with open(output, 'w') as out:
             for s, p in zip(samples, predictions):
                 out.write('\t'.join(s + (str(p), )))
                 out.write('\n')
     return total_loss / timer.total
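
When output is given as a path, the loop above writes one tab-separated row per sample (sent_a, sent_b, predicted score), restored to input order. A minimal reader for that dump; the filename is a placeholder:

    # Read back the TSV written by evaluate_dataloader above.
    with open('dev_predictions.tsv', encoding='utf-8') as f:  # hypothetical path
        for line in f:
            sent_a, sent_b, score = line.rstrip('\n').split('\t')
            print(sent_a, sent_b, float(score))
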
Example #7
 def predict(self, data: Union[str, List[str]], batch_size: int = None, **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     dataloader = self.build_dataloader(self.build_samples(data), batch_size, device=self.device, **kwargs)
     results = []
     order = []
     for batch in dataloader:
         pred, mask = self.feed_batch(batch)
         prediction = self.decode_output(pred, mask, batch)
         results.extend(self.prediction_to_result(prediction, batch))
         order.extend(batch[IDX])
     results = reorder(results, order)
     if flat:
         return results[0]
     return results
Example #8
 def predict(self, data: Union[str, List[str]], beautiful_amr_graph=True, **kwargs):
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     dataloader = self.build_dataloader([{'text': x} for x in data], **self.config, device=self.device)
     orders = []
     results = []
     for batch in dataloader:
         graphs = self.predict_amrs(batch)
         graphs = [x[0] for x in graphs]
         if beautiful_amr_graph:
             graphs = [AMRGraph(x.triples, x.top, x.epidata, x.metadata) for x in graphs]
         results.extend(graphs)
         orders.extend(batch[IDX])
     results = reorder(results, orders)
     if flat:
         results = results[0]
     return results
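
A hypothetical call against this AMR component; the handle and its loading are assumptions. A single sentence comes back as one graph, a list of sentences as a list of graphs:

    amr = ...  # a loaded AMR parser (assumption)
    graph = amr.predict('The boy wants to believe the girl.')
    print(graph)  # a penman-style AMR graph, e.g. (w / want-01 :ARG0 (b / boy) ...)
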
Example #9
 def predict_data(self, data, batch_size, **kwargs):
     samples = self.build_samples(data, **kwargs)
     if not batch_size:
         batch_size = self.config.get('batch_size', 32)
     dataloader = self.build_dataloader(samples, batch_size, False,
                                        self.device)
     outputs = []
     orders = []
     vocab = self.vocabs['tag'].idx_to_token
     for batch in dataloader:
         out, mask = self.feed_batch(batch)
         pred = self.decode_output(out, mask, batch)
         if isinstance(pred, torch.Tensor):
             pred = pred.tolist()
         outputs.extend(self.prediction_to_human(pred, vocab, batch))
         orders.extend(batch[IDX])
     outputs = reorder(outputs, orders)
     return outputs
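
Here prediction_to_human maps integer tag ids back to strings through the tag vocabulary. A minimal sketch of that mapping, assuming pred is one list of tag ids per sentence and the batch carries a token_length tensor (both names are assumptions):

    def prediction_to_human(pred, vocab, batch):
        # Trim padding with the real sentence lengths, then look up tag strings.
        for ids, length in zip(pred, batch['token_length'].tolist()):
            yield [vocab[i] for i in ids[:length]]
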
Example #10
File: mlm.py, Project: lei1993/HanLP
 def predict(self,
             masked_sents: Union[str, List[str]],
             batch_size=32,
             topk=10,
             **kwargs):
     flat = self.input_is_flat(masked_sents)
     if flat:
         masked_sents = [masked_sents]
     dataloader = self.build_dataloader(masked_sents,
                                        **self.config,
                                        device=self.device,
                                        batch_size=batch_size)
     orders = []
     results = []
     for batch in dataloader:
         input_ids = batch['token_input_ids']
         outputs = self.model(input_ids=input_ids,
                              attention_mask=batch['token_attention_mask'])
         mask = input_ids == self.tokenizer.mask_token_id
         if mask.any():
             num_masks = mask.sum(dim=-1).tolist()
             masked_logits = outputs.logits[mask]
             masked_logits[:, self.tokenizer.all_special_ids] = -math.inf
             probs, indices = torch.nn.functional.softmax(masked_logits,
                                                          dim=-1).topk(topk)
             br = []
             for p, index in zip(probs.tolist(), indices.tolist()):
                 br.append(
                     dict(
                         zip(self.tokenizer.convert_ids_to_tokens(index),
                             p)))
             offset = 0
             for n in num_masks:
                 results.append(br[offset:offset + n])
                 offset += n
         else:
             results.extend([[]] * input_ids.size(0))
         orders.extend(batch[IDX])
     results = reorder(results, orders)
     if flat:
         results = results[0]
     return results
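
A hypothetical call against this masked-language-model component; the handle and how it is loaded are assumptions. Each [MASK] in the input yields one dict mapping the topk candidate tokens to their probabilities, and a flat (single-sentence) input returns the per-mask list directly:

    mlm = ...  # a loaded instance of this component, e.g. via hanlp.load(...)
    mlm.predict('巴黎是[MASK]国的首都。', topk=5)
    # -> [{'法': 0.98, ...}]  # one token->probability dict per [MASK]
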
Example #11
    def predict(self,
                data: Union[List[str], List[List[str]]],
                batch_size: int = None,
                **kwargs) -> Union[float, List[float]]:
        """ Predict the similarity between sentence pairs.

        Args:
            data: Sentence pairs.
            batch_size: The number of samples in a batch.
            **kwargs: Not used.

        Returns:
            Similarities between sentences.
        """
        if not data:
            return []
        flat = isinstance(data[0], str)
        if flat:
            data = [data]
        dataloader = self.build_dataloader([{
            'sent_a': x[0],
            'sent_b': x[1]
        } for x in data],
                                           batch_size=batch_size
                                           or self.config.batch_size,
                                           device=self.device)
        orders = []
        predictions = []
        for batch in dataloader:
            output_dict = self.feed_batch(batch)
            prediction = self.decode(output_dict)
            predictions.extend(prediction.tolist())
            orders.extend(batch[IDX])
        predictions = reorder(predictions, orders)
        if flat:
            return predictions[0]
        return predictions
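
A hypothetical call, assuming sts is a trained instance of this component. Because the flat check is isinstance(data[0], str), one sentence pair returns a single float while a list of pairs returns a list of floats:

    sts = ...  # a trained sentence-pair similarity component (assumption)
    sts.predict(['看图猜一电影名', '看图猜电影'])  # one pair -> a float, e.g. 0.97
    sts.predict([['早上好', '晚上好'], ['今天下雨', '明天放晴']])  # -> list of floats
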
Example #12
    def evaluate_dataloader(self, data: DataLoader, criterion: Callable, metric=None, output=False, ratio_width=None,
                            logger=None, input=None, use_fast=False,
                            **kwargs):
        self.model.eval()
        timer = CountdownTimer(len(data))
        graphs = []
        orders = []
        smatch = 0
        for idx, batch in enumerate(data):
            graphs_per_batch = self.predict_amrs(batch)
            graphs_per_batch = [x[0] for x in graphs_per_batch]
            # Copy metadata from the gold graph
            for gp, gg in zip(graphs_per_batch, batch['amr']):
                metadata = gg.metadata.copy()
                metadata['annotator'] = f'{self.config.transformer}-amr'
                metadata['date'] = str(datetime.datetime.now())
                if 'save-date' in metadata:
                    del metadata['save-date']
                gp.metadata = metadata
            graphs.extend(graphs_per_batch)
            orders.extend(batch[IDX])
            if idx == timer.total - 1:
                graphs = reorder(graphs, orders)
                write_predictions(output, self._tokenizer, graphs)
                try:
                    if use_fast:
                        smatch = compute_smatch(output, input)
                    else:
                        smatch = smatch_eval(output, input, use_fast=False)
                except Exception:
                    # Smatch evaluation is best-effort; keep the initial 0 on failure.
                    pass
                timer.log(smatch.cstr() if isinstance(smatch, MetricDict) else f'{smatch:.2%}', ratio_percentage=False,
                          logger=logger)
            else:
                timer.log(ratio_percentage=False, logger=logger)

        return smatch
Example #13
    def predict(self,
                data: Union[str, List[str]],
                batch_size: int = None,
                tasks: Optional[Union[str, List[str]]] = None,
                skip_tasks: Optional[Union[str, List[str]]] = None,
                resolved_tasks=None,
                **kwargs) -> Document:
        """Predict on data.

        Args:
            data: A sentence or a list of sentences.
            batch_size: Decoding batch size.
            tasks: The tasks to predict.
            skip_tasks: The tasks to skip.
            resolved_tasks: The resolved tasks to override ``tasks`` and ``skip_tasks``.
            **kwargs: Not used.

        Returns:
            A :class:`~hanlp_common.document.Document`.
        """
        doc = Document()
        if not data:
            return doc

        target_tasks = resolved_tasks or self.resolve_tasks(tasks, skip_tasks)
        flatten_target_tasks = [
            self.tasks[t] for group in target_tasks for t in group
        ]
        cls_is_bos = any([x.cls_is_bos for x in flatten_target_tasks])
        sep_is_eos = any([x.sep_is_eos for x in flatten_target_tasks])
        # Now build the dataloaders and execute tasks
        first_task_name: str = list(target_tasks[0])[0]
        first_task: Task = self.tasks[first_task_name]
        encoder_transform, transform = self.build_transform(first_task)
        # Override the tokenizer config of the 1st task
        encoder_transform.sep_is_eos = sep_is_eos
        encoder_transform.cls_is_bos = cls_is_bos
        average_subwords = self.model.encoder.average_subwords
        flat = first_task.input_is_flat(data)
        if flat:
            data = [data]
        device = self.device
        samples = first_task.build_samples(data,
                                           cls_is_bos=cls_is_bos,
                                           sep_is_eos=sep_is_eos)
        dataloader = first_task.build_dataloader(samples,
                                                 transform=transform,
                                                 device=device)
        results = defaultdict(list)
        order = []
        for batch in dataloader:
            order.extend(batch[IDX])
            # Run the first task, let it make the initial batch for the successors
            output_dict = self.predict_task(first_task,
                                            first_task_name,
                                            batch,
                                            results,
                                            run_transform=True,
                                            cls_is_bos=cls_is_bos,
                                            sep_is_eos=sep_is_eos)
            # Run each task group in order
            for group_id, group in enumerate(target_tasks):
                # We could parallelize this in the future
                for task_name in group:
                    if task_name == first_task_name:
                        continue
                    output_dict = self.predict_task(self.tasks[task_name],
                                                    task_name,
                                                    batch,
                                                    results,
                                                    output_dict,
                                                    run_transform=True,
                                                    cls_is_bos=cls_is_bos,
                                                    sep_is_eos=sep_is_eos)
                if group_id == 0:
                    # We are kind of hard coding here. If the first task is a tokenizer,
                    # we need to convert the hidden and mask to token level
                    if first_task_name.startswith('tok'):
                        spans = []
                        tokens = []
                        for span_per_sent, token_per_sent in zip(
                                output_dict[first_task_name]['prediction'],
                                results[first_task_name][-len(batch[IDX]):]):
                            if cls_is_bos:
                                span_per_sent = [(-1, 0)] + span_per_sent
                                token_per_sent = [BOS] + token_per_sent
                            if sep_is_eos:
                                span_per_sent = span_per_sent + [
                                    (span_per_sent[-1][0] + 1,
                                     span_per_sent[-1][1] + 1)
                                ]
                                token_per_sent = token_per_sent + [EOS]
                            # Subword spans are 0-based, but index 0 of the
                            # sequence is [CLS], so shift every offset by one
                            if average_subwords:
                                span_per_sent = [
                                    list(range(x[0] + 1, x[1] + 1))
                                    for x in span_per_sent
                                ]
                            else:
                                span_per_sent = [
                                    x[0] + 1 for x in span_per_sent
                                ]
                            spans.append(span_per_sent)
                            tokens.append(token_per_sent)
                        spans = PadSequenceDataLoader.pad_data(spans,
                                                               0,
                                                               torch.long,
                                                               device=device)
                        output_dict['hidden'] = pick_tensor_for_each_token(
                            output_dict['hidden'], spans, average_subwords)
                        batch['token_token_span'] = spans
                        batch['token'] = tokens
                        # noinspection PyTypeChecker
                        batch['token_length'] = torch.tensor(
                            [len(x) for x in tokens],
                            dtype=torch.long,
                            device=device)
                        batch.pop('mask', None)
        # Put results into doc in the order of tasks
        for k in self.config.task_names:
            v = results.get(k, None)
            if v is None:
                continue
            doc[k] = reorder(v, order)
        # Allow task to perform finalization on document
        for group in target_tasks:
            for task_name in group:
                task = self.tasks[task_name]
                task.finalize_document(doc, task_name)
        # If no tok in doc, use raw input as tok
        if not any(k.startswith('tok') for k in doc):
            doc['tok'] = data
        if flat:
            for k, v in list(doc.items()):
                doc[k] = v[0]
        # If there is only one field, don't bother to wrap it
        # if len(doc) == 1:
        #     return list(doc.values())[0]
        return doc
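
A hypothetical end-to-end call against this multi-task predict, assuming a pipeline loaded through hanlp.load; the identifier below is one of HanLP's published MTL constants, but treat the exact choice as an assumption:

    import hanlp

    # Model identifier is an assumption; any hanlp.pretrained.mtl constant works here.
    mtl = hanlp.load(hanlp.pretrained.mtl.CLOSE_TOK_POS_NER_SRL_DEP_SDP_CON_ELECTRA_SMALL_ZH)
    doc = mtl('晓美焰来到北京立方庭参观自然语义科技公司。', tasks='tok')
    print(doc)  # a hanlp_common Document keyed by task name, e.g. 'tok/fine'
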