예제 #1
0
def decode(trig_dict, arg_dict, vocab):
    """
    Largely copy-pasted from what happens in dygie.

    Debatch the trigger and argument score dicts, then decode each
    sentence's triggers and arguments against *vocab*.
    """
    skip_keys = {"loss", "decoded_events"}

    def _debatch(score_dict):
        # Move tensors to CPU and split the batch into per-sentence dicts,
        # dropping fields that are not per-sentence predictions.
        return fields_to_batches(
            {key: val.detach().cpu()
             for key, val in score_dict.items() if key not in skip_keys})

    decoded = []
    # Decode every sentence of the minibatch.
    for trig, arg in zip(_debatch(trig_dict), _debatch(arg_dict)):
        trigger_preds = decode_trigger(trig, vocab)
        arg_preds, arg_preds_scored = decode_arguments(
            arg, trigger_preds, vocab)
        decoded.append({"trigger_dict": trigger_preds,
                        "argument_dict": arg_preds,
                        "argument_dict_with_scores": arg_preds_scored})
    return decoded
예제 #2
0
 def __init__(self, js):
     """Build the document from a json-loaded dict *js*."""
     self._doc_key = js["doc_key"]
     # Split the flat json dict into one entry per sentence; the listed
     # document-level fields are kept out of the per-sentence entries.
     doc_level_fields = ["doc_key", "clusters", "predicted_clusters",
                         "section_starts"]
     entries = fields_to_batches(js, doc_level_fields)
     # Token offset at which each sentence begins: cumulative sums of the
     # sentence lengths, shifted right by one with a leading zero.
     lengths = [len(entry["sentences"]) for entry in entries]
     starts = np.roll(np.cumsum(lengths), 1)
     starts[0] = 0
     self.sentence_starts = starts
     self.sentences = [
         Sentence(entry, start, ix)
         for ix, (entry, start) in enumerate(zip(entries, starts))
     ]
     # Coreference clusters are optional; the attributes are only created
     # when the corresponding field is present in the input.
     if "clusters" in js:
         self.clusters = [
             Cluster(spans, i, self)
             for i, spans in enumerate(js["clusters"])
         ]
     if "predicted_clusters" in js:
         self.predicted_clusters = [
             Cluster(spans, i, self)
             for i, spans in enumerate(js["predicted_clusters"])
         ]
예제 #3
0
파일: events.py 프로젝트: MSLars/mare
    def predict(self, output_dict, document):
        """
        Take the output and convert it into a list of dicts. Each entry is a sentence. Each key is a
        pair of span indices for that sentence, and each value is the relation label on that span
        pair.
        """
        # Debatch: one dict of CPU tensors per sentence in the minibatch.
        per_sentence = fields_to_batches(
            {name: tensor.detach().cpu()
             for name, tensor in output_dict.items()})

        prediction_dicts, predictions = [], []

        # Decode each sentence and pair it with its document sentence.
        for sent_output, sentence in zip(per_sentence, document):
            trig = self._decode_trigger(sent_output)
            args = self._decode_arguments(sent_output, trig)
            prediction_dicts.append(
                {"trigger_dict": trig, "argument_dict": args})
            predictions.append(
                self._assemble_predictions(trig, args, sentence))

        return prediction_dicts, predictions
예제 #4
0
    def from_json(cls, js):
        """
        Read in from json-loaded dict.
        :params js dict
        return Document obj
        """
        # Validate required fields before doing any work.
        cls._check_fields(js)
        doc_key = js["doc_key"]
        dataset = js.get("dataset")
        doc_level_fields = ["doc_key", "dataset", "clusters",
                            "predicted_clusters", "weight"]
        entries = fields_to_batches(js, doc_level_fields)
        # Token offset at which each sentence begins: cumulative sums of the
        # sentence lengths, shifted right with a leading zero.
        lengths = [len(entry["sentences"]) for entry in entries]
        starts = np.roll(np.cumsum(lengths), 1)
        starts[0] = 0
        starts = starts.tolist()
        sentences = [
            Sentence(entry, start, ix)
            for ix, (entry, start) in enumerate(zip(entries, starts))
        ]

        def _build_clusters(field):
            # Coreference clusters for `field`, or None when absent.
            if field not in js:
                return None
            return [Cluster(spans, i, sentences, starts)
                    for i, spans in enumerate(js[field])]

        # Store cofereference annotations.
        clusters = _build_clusters("clusters")
        # TODO(dwadden) Need to treat predicted clusters differently and update sentences
        # appropriately.
        predicted_clusters = _build_clusters("predicted_clusters")

        # Update the sentences with coreference cluster labels.
        sentences = update_sentences_with_clusters(sentences, clusters)

        # Get the loss weight for this document.
        weight = js.get("weight", None)

        return cls(doc_key, dataset, sentences, clusters, predicted_clusters,
                   weight)
예제 #5
0
    def decode(self, output_dict):
        """
        Take the output and convert it into a list of dicts. Each entry is a sentence. Each key is a
        pair of span indices for that sentence, and each value is the relation label on that span
        pair.
        """
        # Debatch: one dict of CPU tensors per sentence in the minibatch.
        per_sentence = fields_to_batches(
            {name: tensor.detach().cpu()
             for name, tensor in output_dict.items()})

        decoded = []

        # Decode triggers and arguments for every sentence.
        for sent_output in per_sentence:
            trig = self._decode_trigger(sent_output)
            args, args_scored = self._decode_arguments(sent_output, trig)
            decoded.append(dict(trigger_dict=trig,
                                argument_dict=args,
                                argument_dict_with_scores=args_scored))

        output_dict["decoded_events"] = decoded
        return output_dict