Example #1
0
    def _extract_features(self,
                          doc: Document,
                          ent: Entity,
                          include_labels=False):
        """Build the feature dict for a single entity.

        Token, attention and classifier features are collected over the
        sentence that contains *ent*; span indices are made relative to
        that sentence's first token.
        """
        sentence = doc.sentences[doc.get_entity_sent_idx(ent)]
        first_token = sentence.start_token
        last_token = sentence.end_token

        features = {}
        features.update(self.token_feature_extractor.extract_features_from_doc(
            doc, first_token, last_token))
        features.update(
            self._get_attention_features(doc, ent, first_token, last_token))
        features.update(self._get_classifier_features(doc, ent))

        if include_labels:
            # Training target: the converted named-entity label.
            features["labels"] = self.label_converter[_map_ne(doc, ent)]

        # Mask allowing only the labels that are valid for this entity type.
        mask = [0] * len(self.label_converter)
        for label_name in self.types_mapping[ent.type]:
            mask[self.label_converter[label_name]] = 1
        features["labels_mask"] = mask

        # Entity span as token offsets relative to the sentence start.
        features["indices"] = [[
            ent.start_token - first_token, ent.end_token - first_token
        ]]

        return features
Example #2
0
    def _get_attention_features(self, doc: Document, ent: Entity,
                                start_token: int, end_token: int) -> dict:
        """Collect attention-related features for *ent* over [start_token, end_token)."""
        # Token-position features computed with respect to the entity span.
        span = (ent.start_token, ent.end_token, doc.get_entity_sent_idx(ent))
        features = dict(self.token_position_fe.extract_features_from_doc(
            doc, start_token, end_token, span))

        # Entity-level features converted for the "attention" namespace.
        features.update(self._get_entity_features(
            doc, ent, self.attention_features_converters, "attention"))

        return features
Example #3
0
    def _get_attention_features(self, doc: Document, e1: Entity, e2: Entity,
                                start_token: int, end_token: int) -> dict:
        """Collect attention features for an entity pair over [start_token, end_token)."""
        features = {}

        # Per-entity token-position features, namespaced as 'e1' / 'e2'.
        for prefix, entity in (('e1', e1), ('e2', e2)):
            span = (entity.start_token, entity.end_token,
                    doc.get_entity_sent_idx(entity))
            positional = self.token_position_fe.extract_features_from_doc(
                doc, start_token, end_token, span)
            features.update(namespaced(positional, prefix))

        # Pairwise relation features converted for the "attention" namespace.
        features.update(self._get_relation_features(
            doc, e1, e2, self.attention_features_converters, "attention"))

        return features
Example #4
0
    def _extract_features(self,
                          doc: Document,
                          e1: Entity,
                          e2: Entity,
                          rel_type: str,
                          *,
                          include_labels=False):
        """Build the feature dict for a candidate relation between *e1* and *e2*.

        The token span covers every sentence from the first entity's
        sentence to the second entity's sentence; span indices are made
        relative to that span's first token.
        """
        sent1 = doc.get_entity_sent_idx(e1)
        sent2 = doc.get_entity_sent_idx(e2)

        first_token = doc.sentences[min(sent1, sent2)].start_token
        last_token = doc.sentences[max(sent1, sent2)].end_token

        span1 = (e1.start_token, e1.end_token, sent1)
        span2 = (e2.start_token, e2.end_token, sent2)

        features = {}
        features.update(self.shared_feature_extractor.extract_features_from_doc(
            doc, first_token, last_token, span1, span2))
        features.update(
            self._get_entities_encoder_features(doc, first_token, last_token))
        features.update(
            self._get_attention_features(doc, e1, e2, first_token, last_token))
        features.update(self._get_classifier_features(doc, e1, e2))

        if include_labels:
            # Training target: the converted relation type.
            features["labels"] = self.rel_converter[rel_type]

        # Mask allowing "no relation" plus every relation type that is
        # valid for this ordered pair of entity types.
        mask = [0] * len(self.rel_converter)
        mask[self.rel_converter[None]] = 1
        for rel_name in self.valid_ent_rel_types[e1.type, e2.type]:
            mask[self.rel_converter[rel_name]] = 1
        features["labels_mask"] = mask

        # Both entity spans as token offsets relative to the span start.
        features["indices"] = [
            [e1.start_token - first_token, e1.end_token - first_token],
            [e2.start_token - first_token, e2.end_token - first_token],
        ]

        return features
Example #5
0
File: helper.py  Project: wayne9qiu/derek
def get_sentence_distance_between_entities(doc: Document, e1: Entity,
                                           e2: Entity):
    """Return how many sentences apart the two entities are (0 = same sentence)."""
    idx1 = doc.get_entity_sent_idx(e1)
    idx2 = doc.get_entity_sent_idx(e2)
    return idx1 - idx2 if idx1 >= idx2 else idx2 - idx1