Example #1
 def post_outputs(self, predictions, data, order, use_pos, build_data):
     predictions = reorder(predictions, order)
     if build_data:
         data = reorder(data, order)
     outputs = []
     self.predictions_to_human(predictions, outputs, data, use_pos)
     return outputs
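All of these examples share the same batching idiom: the dataloader may sort or bucket samples for efficiency, so each prediction is collected together with its original index (batch[IDX]) and the final list is restored with reorder. A minimal sketch of such a helper, assuming order is a permutation of range(len(outputs)); the library's actual implementation may differ:

def reorder(outputs, order):
    # Undo dataloader sorting: put each output back at its input position.
    restored = [None] * len(outputs)
    for out, idx in zip(outputs, order):
        restored[idx] = out
    return restored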
Example #2
 def predict(self,
             data: Union[str, List[str]],
             beautiful_amr_graph=True,
             **kwargs):
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     dataloader = self.build_dataloader([{
         'text': x
     } for x in data],
                                        **self.config,
                                        device=self.device)
     orders = []
     results = []
     for batch in dataloader:
         graphs = self.predict_amrs(batch)
         graphs = [x[0] for x in graphs]
         if beautiful_amr_graph:
             graphs = [
                 AMRGraph(x.triples, x.top, x.epidata, x.metadata)
                 for x in graphs
             ]
         results.extend(graphs)
         orders.extend(batch[IDX])
     results = reorder(results, orders)
     if flat:
         results = results[0]
     return results
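A hypothetical call against the method above, showing how the flat-input handling makes the return type mirror the input type (parser stands in for an instance of the surrounding class):

# Hypothetical usage: a single string yields one graph, a list yields a list.
graph = parser.predict('The boy wants to believe the girl.')
graphs = parser.predict(['First sentence.', 'Second sentence.'])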
Example #3
 def predict(self,
             data: Union[str, List[str]],
             batch_size: int = None,
             **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     samples = self.build_samples(data)
     dataloader = self.build_dataloader(samples,
                                        device=self.device,
                                        **merge_dict(self.config,
                                                     batch_size=batch_size,
                                                     overwrite=True))
     outputs = []
     orders = []
     for idx, batch in enumerate(dataloader):
         out, mask = self.feed_batch(batch)
         prediction = self.decode_output(out, mask, batch, span_probs=None)
         # prediction = [x[0] for x in prediction]
         outputs.extend(prediction)
         orders.extend(batch[IDX])
     outputs = reorder(outputs, orders)
     if flat:
         return outputs[0]
     return outputs
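Most of these methods call input_is_flat to decide whether the caller passed a single sample or a batch. A plausible minimal version for string inputs; the real helper may also accept pre-tokenized token lists:

def input_is_flat(self, data):
    # A bare string is a single sample; a list of strings is already a batch.
    return isinstance(data, str)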
Example #4
 def predict(self,
             data: Union[str, List[str]],
             batch_size: int = None,
             fmt='dict',
             **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     samples = [{'token': token} for token in data]
     batch_size = batch_size or self.config.batch_size
     dataloader = self.build_dataloader(samples,
                                        batch_size,
                                        False,
                                        self.device,
                                        None,
                                        generate_idx=True)
     outputs = []
     order = []
     for batch in dataloader:
         output_dict = self.feed_batch(batch)
         outputs.extend(output_dict['prediction'])
         order.extend(batch[IDX])
     outputs = reorder(outputs, order)
     if fmt == 'list':
         outputs = self.format_dict_to_results(data, outputs)
     if flat:
         return outputs[0]
     return outputs
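A hypothetical call illustrating the fmt switch above: 'dict' returns the raw per-sample prediction dicts, while 'list' passes them through format_dict_to_results to align with the input tokens:

# Hypothetical usage; tagger stands in for an instance of the class above.
as_dicts = tagger.predict([['I', 'love', 'NLP']], fmt='dict')
as_lists = tagger.predict([['I', 'love', 'NLP']], fmt='list')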
Example #5
    def evaluate_dataloader(self,
                            data: DataLoader,
                            criterion: Callable,
                            metric=None,
                            output=False,
                            ratio_width=None,
                            logger=None,
                            input=None,
                            use_fast=False,
                            **kwargs):
        self.model.eval()
        timer = CountdownTimer(len(data))
        graphs = []
        orders = []
        smatch = 0
        for idx, batch in enumerate(data):
            graphs_per_batch = self.predict_amrs(batch)
            graphs_per_batch = [x[0] for x in graphs_per_batch]
            # Copy metadata from the gold graph
            for gp, gg in zip(graphs_per_batch, batch['amr']):
                metadata = gg.metadata.copy()
                metadata['annotator'] = f'{self.config.transformer}-amr'
                metadata['date'] = str(datetime.datetime.now())
                if 'save-date' in metadata:
                    del metadata['save-date']
                gp.metadata = metadata
            graphs.extend(graphs_per_batch)
            orders.extend(batch[IDX])
            if idx == timer.total - 1:
                graphs = reorder(graphs, orders)
                write_predictions(output, self._tokenizer, graphs)
                try:
                    if use_fast:
                        smatch = compute_smatch(output, input)
                    else:
                        smatch = smatch_eval(output, input, use_fast=False)
                except Exception:
                    # Tolerate evaluation failures; keep the default score of 0.
                    pass
                timer.log(smatch.cstr() if isinstance(smatch, MetricDict) else
                          f'{smatch:.2%}',
                          ratio_percentage=False,
                          logger=logger)
            else:
                timer.log(ratio_percentage=False, logger=logger)

        return smatch
Example #6
 def predict_data(self, data, batch_size, **kwargs):
     samples = self.build_samples(data, **kwargs)
     if not batch_size:
         batch_size = self.config.get('batch_size', 32)
     dataloader = self.build_dataloader(samples, batch_size, False,
                                        self.device)
     outputs = []
     orders = []
     vocab = self.vocabs['tag'].idx_to_token
     for batch in dataloader:
         out, mask = self.feed_batch(batch)
         pred = self.decode_output(out, mask, batch)
         if isinstance(pred, torch.Tensor):
             pred = pred.tolist()
         outputs.extend(self.prediction_to_human(pred, vocab, batch))
         orders.extend(batch[IDX])
     outputs = reorder(outputs, orders)
     return outputs
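prediction_to_human above maps integer tag ids back to label strings through the 'tag' vocabulary. A hedged sketch of what it could look like; the batch['token_length'] key and the exact signature are assumptions, not the library's confirmed API:

def prediction_to_human(self, pred, vocab, batch):
    # Assumed sketch: trim padding with each sample's true length,
    # then map every tag id to its label string.
    for ids, length in zip(pred, batch['token_length'].tolist()):
        yield [vocab[i] for i in ids[:length]]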
Example #7
 def predict(self, data: Union[List[str], List[List[str]]], batch_size: int = None, ret_tokens=True, **kwargs):
     if not data:
         return []
     flat = self.input_is_flat(data)
     if flat:
         data = [data]
     dataloader = self.build_dataloader([{'token': x} for x in data], batch_size, False, self.device)
     predictions = []
     orders = []
     for batch in dataloader:
         output_dict = self.feed_batch(batch)
         token = batch['token']
         prediction = output_dict['prediction']
         self.prediction_to_result(token, prediction, predictions, ret_tokens)
         orders.extend(batch[IDX])
     predictions = reorder(predictions, orders)
     if flat:
         return predictions[0]
     return predictions
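prediction_to_result above is task-specific; a hedged sketch of one plausible shape, where ret_tokens controls whether the surface tokens are attached to each prediction (names and structure are assumptions):

def prediction_to_result(self, token, prediction, predictions, ret_tokens):
    # Assumed sketch: pair each sentence's tokens with its prediction.
    for tokens, pred in zip(token, prediction):
        predictions.append({'token': tokens, 'prediction': pred}
                           if ret_tokens else pred)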
Example #8
def parse_data(model,
               pp: PostProcessor,
               data,
               input_file,
               output_file,
               beam_size=8,
               alpha=0.6,
               max_time_step=100,
               h=None):
    if not output_file:
        output_file = tempfile.NamedTemporaryFile().name
    tot = 0
    levi_graph = model.decoder.levi_graph if hasattr(model,
                                                     'decoder') else False
    with open(output_file, 'w') as fo:
        timer = CountdownTimer(len(data))
        order = []
        outputs = []
        for batch in data:
            order.extend(batch[IDX])
            res = parse_batch(model,
                              batch,
                              beam_size,
                              alpha,
                              max_time_step,
                              h=h)
            outputs.extend(
                list(zip(res['concept'], res['relation'], res['score'])))
            timer.log('Parsing [blink][yellow]...[/yellow][/blink]',
                      ratio_percentage=False)
        outputs = reorder(outputs, order)
        timer = CountdownTimer(len(data))
        for concept, relation, score in outputs:
            fo.write('# ::conc ' + ' '.join(concept) + '\n')
            fo.write('# ::score %.6f\n' % score)
            fo.write(
                pp.postprocess(concept, relation, check_connected=levi_graph) +
                '\n\n')
            tot += 1
            timer.log('Post-processing [blink][yellow]...[/yellow][/blink]',
                      ratio_percentage=False)
    match(output_file, input_file)
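The file written above interleaves comment lines with post-processed AMR blocks. An illustrative, made-up excerpt of the format:

# ::conc want-01 boy believe-01
# ::score 0.987654
(w / want-01
   :ARG0 (b / boy)
   :ARG1 (b2 / believe-01))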
Example #9
def parse_data_(model,
                pp: PostProcessor,
                data,
                beam_size=8,
                alpha=0.6,
                max_time_step=100,
                h=None):
    levi_graph = model.decoder.levi_graph if hasattr(model,
                                                     'decoder') else False
    if levi_graph:
        raise NotImplementedError('Only supports Graph Transducer')
    order = []
    outputs = []
    for batch in data:
        order.extend(batch[IDX])
        res = parse_batch(model, batch, beam_size, alpha, max_time_step, h=h)
        outputs.extend(list(zip(res['concept'], res['relation'],
                                res['score'])))
    outputs = reorder(outputs, order)
    for concept, relation, score in outputs:
        yield pp.to_amr(concept, relation)
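parse_data_ is a generator, but note that it accumulates every batch before the reordering step, so graphs are only yielded once the whole dataset has been parsed. A hypothetical consumption pattern:

# Hypothetical usage: iterate over the reordered AMR graphs.
for amr in parse_data_(model, pp, dataloader):
    print(amr)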
Example #10
    def predict(self,
                data: Union[str, List[str]],
                batch_size: int = None,
                tasks: Optional[Union[str, List[str]]] = None,
                resolve_dependencies=True,
                **kwargs) -> Document:
        doc = Document()
        if not data:
            return doc

        if resolve_dependencies:
            # Now we decide which tasks to perform and their orders
            tasks_in_topological_order = self._tasks_in_topological_order
            task_topological_order = self._task_topological_order
            computation_graph = self._computation_graph
            target_tasks = self._resolve_task_name(tasks)
            if not target_tasks:
                target_tasks = tasks_in_topological_order
            else:
                target_topological_order = defaultdict(set)
                for task_name in target_tasks:
                    if task_name not in computation_graph:
                        continue
                    for dependency in topological_sort(computation_graph,
                                                       task_name):
                        target_topological_order[
                            task_topological_order[dependency]].add(dependency)
                target_tasks = [
                    item[1]
                    for item in sorted(target_topological_order.items())
                ]
        else:
            target_tasks = [set(tasks)] if isinstance(tasks,
                                                      list) else [{tasks}]
        if not target_tasks:
            return Document()
        # Sort target tasks within the same group in a defined order
        target_tasks = [
            sorted(x, key=lambda _x: self.config.task_names.index(_x))
            for x in target_tasks
        ]
        flatten_target_tasks = [
            self.tasks[t] for group in target_tasks for t in group
        ]
        cls_is_bos = any([x.cls_is_bos for x in flatten_target_tasks])
        sep_is_eos = any([x.sep_is_eos for x in flatten_target_tasks])
        # Now build the dataloaders and execute tasks
        first_task_name: str = list(target_tasks[0])[0]
        first_task: Task = self.tasks[first_task_name]
        encoder_transform, transform = self.build_transform(first_task)
        # Override the tokenizer config of the 1st task
        encoder_transform.sep_is_eos = sep_is_eos
        encoder_transform.cls_is_bos = cls_is_bos
        average_subwords = self.model.encoder.average_subwords
        flat = first_task.input_is_flat(data)
        if flat:
            data = [data]
        device = self.device
        samples = first_task.build_samples(data,
                                           cls_is_bos=cls_is_bos,
                                           sep_is_eos=sep_is_eos)
        dataloader = first_task.build_dataloader(samples,
                                                 transform=transform,
                                                 device=device)
        results = defaultdict(list)
        order = []
        for batch in dataloader:
            order.extend(batch[IDX])
            # Run the first task, let it make the initial batch for the successors
            output_dict = self.predict_task(first_task,
                                            first_task_name,
                                            batch,
                                            results,
                                            run_transform=True,
                                            cls_is_bos=cls_is_bos,
                                            sep_is_eos=sep_is_eos)
            # Run each task group in order
            for group_id, group in enumerate(target_tasks):
                # We could parallelize this in the future
                for task_name in group:
                    if task_name == first_task_name:
                        continue
                    output_dict = self.predict_task(self.tasks[task_name],
                                                    task_name,
                                                    batch,
                                                    results,
                                                    output_dict,
                                                    run_transform=True,
                                                    cls_is_bos=cls_is_bos,
                                                    sep_is_eos=sep_is_eos)
                if group_id == 0:
                    # Somewhat hard-coded: if the first task is a tokenizer,
                    # convert the hidden states and mask from subword to token level.
                    if 'token_token_span' not in batch:
                        spans = []
                        tokens = []
                        for span_per_sent, token_per_sent in zip(
                                output_dict[first_task_name]['prediction'],
                                results[first_task_name]):
                            if cls_is_bos:
                                span_per_sent = [(-1, 0)] + span_per_sent
                                token_per_sent = [BOS] + token_per_sent
                            if sep_is_eos:
                                span_per_sent = span_per_sent + [
                                    (span_per_sent[-1][0] + 1,
                                     span_per_sent[-1][1] + 1)
                                ]
                                token_per_sent = token_per_sent + [EOS]
                            # Offsets are 0-based while position 0 is [CLS], hence the +1 below
                            if average_subwords:
                                span_per_sent = [
                                    list(range(x[0] + 1, x[1] + 1))
                                    for x in span_per_sent
                                ]
                            else:
                                span_per_sent = [
                                    x[0] + 1 for x in span_per_sent
                                ]
                            spans.append(span_per_sent)
                            tokens.append(token_per_sent)
                        spans = PadSequenceDataLoader.pad_data(spans,
                                                               0,
                                                               torch.long,
                                                               device=device)
                        output_dict['hidden'] = pick_tensor_for_each_token(
                            output_dict['hidden'], spans, average_subwords)
                        batch['token_token_span'] = spans
                        batch['token'] = tokens
                        # noinspection PyTypeChecker
                        batch['token_length'] = torch.tensor(
                            [len(x) for x in tokens],
                            dtype=torch.long,
                            device=device)
                        batch.pop('mask', None)
        # Put results into doc in the order of tasks
        for k in self.config.task_names:
            v = results.get(k, None)
            if v is None:
                continue
            doc[k] = reorder(v, order)
        # Allow task to perform finalization on document
        for group in target_tasks:
            for task_name in group:
                task = self.tasks[task_name]
                task.finalize_document(doc, task_name)
        # If no tok in doc, use raw input as tok
        if not any(k.startswith('tok') for k in doc):
            doc['tok'] = data
        if flat:
            for k, v in list(doc.items()):
                doc[k] = v[0]
        # If there is only one field, don't bother to wrap it
        # if len(doc) == 1:
        #     return list(doc.values())[0]
        return doc
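A hypothetical call against the multi-task predict above; requested task names are resolved against the computation graph so that prerequisites such as tokenization run first (mtl and the task name are assumptions depending on the loaded pipeline):

# Hypothetical usage; available task names depend on the loaded pipeline.
doc = mtl.predict('HanLP is a multilingual NLP library.', tasks='dep')
print(doc['tok'], doc['dep'])  # a Document keyed by task name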