Example #1
0
    def run_taskB_on_seqs(self, dataset, collection: Collection, *args,
                          **kargs):
        """Predict relations with the sequence model and append them to
        the matching sentences of ``collection``.

        Does nothing when no sequence model is loaded.
        """
        model = self.taskB_seq_model
        if model is None:
            return

        progress = tqdm(
            dataset.shallow_dataloader(),
            total=len(dataset),
            desc="Relations (Sequence)",
        )
        with torch.no_grad():
            for features, index, (sid, head_id, tokens_ids) in progress:
                raw = model((features, index))
                decoded = model.decode(raw)
                predicted = [dataset.labels[tag] for tag in decoded]

                sentence = collection.sentences[sid]
                origin = sentence.keyphrases[head_id]
                for tok_id, rel_label in zip(tokens_ids, predicted):
                    # Skip padding positions and pairs predicted unrelated.
                    if rel_label is None or tok_id < 0:
                        continue

                    destination = sentence.keyphrases[tok_id]
                    sentence.relations.append(
                        Relation(sentence, origin.id, destination.id,
                                 rel_label))
Example #2
0
    def process_output_sentence(self, sentence, prediction,
                                token_pairs) -> List[Relation]:
        """Turn per-pair model predictions into ``Relation`` objects.

        ``prediction`` is a nested sequence of label indices which is
        flattened and mapped through ``self.idx2rel``; ``token_pairs`` is
        the parallel sequence of (token-dict, token-dict) pairs, each dict
        carrying at least ``'span'`` and usually ``'text'``.

        Returns the list of relations whose tag is a real label and whose
        tokens map to two distinct keyphrases of ``sentence``.
        """
        # Hoist the special-token filter out of the loop.
        special_tokens = {'[PAD]', '[SEP]', '[CLS]'}
        predicted_tags = [
            self.idx2rel.get(p) for pred in prediction for p in pred
        ]
        list_of_relations = []
        for pair, tag in zip(token_pairs, predicted_tags):
            # BUG FIX: `idx2rel.get` returns None for unknown indices, and
            # None passed the old `tag != 'none'` check, producing relations
            # with a None label. Treat None like the explicit 'none' tag.
            if tag is None or tag == 'none':
                continue

            token1 = pair[0].get('text')
            span1 = pair[0]['span']
            token2 = pair[1].get('text')
            span2 = pair[1]['span']

            # BUG FIX: `.get('text')` may return None, which used to crash
            # on `.startswith`. Also skip identical tokens, WordPiece
            # continuations ('##...') and BERT special tokens.
            if (token1 is None or token2 is None or token1 == token2
                    or token1.startswith('##') or token2.startswith('##')
                    or token1 in special_tokens or token2 in special_tokens):
                continue

            id_kph1 = findKeyphraseId(sentence, span1)
            id_kph2 = findKeyphraseId(sentence, span2)
            # If the NER module didn't identify an entity, the id is None.
            # Tokens inside the same entity share an id -> drop that case.
            if id_kph1 and id_kph2 and id_kph1 != id_kph2:
                list_of_relations.append(
                    Relation(sentence=sentence,
                             origin=id_kph1,
                             destination=id_kph2,
                             label=tag))
        return list_of_relations
Example #3
0
    def run_taskB_on_pairs(self, dataset, collection: Collection, *args,
                           **kargs):
        """Predict one relation label per candidate entity pair and add
        every non-null prediction to its sentence.

        No-op when the pair model is not loaded.
        """
        model = self.taskB_pair_model
        if model is None:
            return

        progress = tqdm(
            dataset.shallow_dataloader(),
            total=len(dataset),
            desc="Relations (Pairs)",
        )
        with torch.no_grad():
            for *features, (sid, s_id, d_id) in progress:
                scores = model(features).squeeze(0)
                best = scores.argmax(dim=-1)
                label = dataset.labels[best.item()]

                # A None label encodes "no relation" for this pair.
                if label is None:
                    continue

                sentence = collection.sentences[sid]
                rel_origin = sentence.keyphrases[s_id.item()].id
                rel_destination = sentence.keyphrases[d_id.item()].id

                sentence.relations.append(
                    Relation(sentence, rel_origin, rel_destination, label))
Example #4
0
    def load(cls, collection: Collection, finput: Path):
        """Load relations for ``collection`` from the companion
        ``output_b_*`` file that sits next to ``finput``.

        Keyphrases are loaded first via ``cls.load_keyphrases``; each line
        of the relation file is ``label<TAB>src<TAB>dst`` with keyphrase
        ids. Relations crossing sentence boundaries are warned about and
        skipped. Returns the same ``collection`` with relations appended.
        """
        # output_a_<name> -> output_b_<name>: relations live in a sibling file.
        input_b_file = finput.parent / ("output_b_" + finput.name.split("_")[1])

        sentence_by_id = cls.load_keyphrases(collection, finput)

        # FIX: use a context manager so the file handle is always closed
        # (the original opened it and never closed it); iterating the
        # handle also avoids materializing all lines with readlines().
        with input_b_file.open(encoding="utf8") as relations_file:
            for line in relations_file:
                label, src, dst = line.strip().split("\t")
                src, dst = int(src), int(dst)

                the_sentence = sentence_by_id[src]

                if the_sentence != sentence_by_id[dst]:
                    warnings.warn(
                        "In file '%s' relation '%s' between %i and %i crosses sentence boundaries and has been ignored."
                        % (finput, label, src, dst)
                    )
                    continue

                # The guard above already ensures src and dst share a
                # sentence, so the old redundant assert was dropped.
                the_sentence.relations.append(
                    Relation(the_sentence, src, dst, label.lower())
                )

        return collection
Example #5
0
    def predict_relations(self, sentence):
        """Look up every ordered keyphrase pair of ``sentence`` in the
        learned relation table and append each hit to
        ``sentence.relations``."""
        keyphrases = sentence.keyphrases
        for source in keyphrases:
            source_text = source.text.lower()
            for target in keyphrases:
                lookup = (source_text, source.label,
                          target.text.lower(), target.label)
                try:
                    label = self.relations[lookup]
                except KeyError:
                    # No memorized relation for this pair.
                    continue
                sentence.relations.append(
                    Relation(sentence, source.id, target.id, label))
Example #6
0
    def predict_relation_single(self, doc, sentence):
        """Predict the most probable relation for every ordered pair of
        keyphrases in ``sentence`` and record it together with an
        entropy-based uncertainty score."""
        for head in sentence.keyphrases:
            for tail in sentence.keyphrases:
                if head == tail:
                    continue

                # head and tail are Keyphrases; turn the pair into a
                # feature vector (None means it cannot be featurized).
                feats = self.relation_features(None, head, tail, doc)
                if feats is None:
                    continue

                label = self.relation_classifier.predict([feats])[0]
                if not label:
                    # Falsy label -> classifier predicts "no relation".
                    continue

                rel = Relation(sentence, head.id, tail.id, label)
                # Entropy of the predicted class distribution measures how
                # unsure the classifier is about this relation.
                dist = self.relation_classifier.predict_proba([feats])[0]
                rel.uncertainty = scipy.stats.entropy(list(dist), base=2)
                sentence.relations.append(rel)
Example #7
0
    def run(self, collection, taskA, taskB):
        """Annotate ``collection`` from the memorized gold data and
        return it.

        When ``taskA`` is true, every gold keyphrase is matched as a
        whole word in each sentence and added as a ``Keyphrase``. When
        ``taskB`` is true, every ordered keyphrase pair present in the
        gold relation table becomes a ``Relation``.
        """
        gold_keyphrases, gold_relations = self.model

        if taskA:
            next_id = 0
            for phrase, phrase_label in gold_keyphrases.items():
                # The pattern depends only on the phrase, so build it once.
                pattern = r"\b" + phrase + r"\b"
                for sentence in collection.sentences:
                    lowered = sentence.text.lower()
                    for match in re.finditer(pattern, lowered):
                        keyphrase = Keyphrase(sentence, phrase_label,
                                              next_id, [match.span()])
                        keyphrase.split()
                        next_id += 1
                        sentence.keyphrases.append(keyphrase)

        if taskB:
            for sentence in collection.sentences:
                for origin in sentence.keyphrases:
                    origin_text = origin.text.lower()
                    for destination in sentence.keyphrases:
                        lookup = (origin_text, origin.label,
                                  destination.text.lower(),
                                  destination.label)
                        try:
                            label = gold_relations[lookup]
                        except KeyError:
                            continue
                        sentence.relations.append(
                            Relation(sentence, origin.id, destination.id,
                                     label))

                sentence.remove_dup_relations()

        return collection