# Imports required by the snippets below. The standard-library and spaCy imports
# are certain; the project-local import paths are assumptions inferred from usage
# (Rule is assumed to be defined elsewhere in the same module, so it is not imported).
import copy
import itertools
import sys
from typing import Dict, List

from spacy.matcher import Matcher
from spacy.tokens import Doc, Span

from etk.extractor import Extractor, InputType  # assumed project-local path
from etk.extraction import Extraction           # assumed project-local path
from etk.tokenizer import Tokenizer             # assumed project-local path


class SpacyRuleExtractor(Extractor):
    def __init__(self,
                 nlp,
                 rules: Dict,
                 extractor_name: str) -> None:
        """
        Initialize the extractor, storing the rule information and construct spacy rules
        Args:
            nlp
            rules: Dict
            extractor_name: str

        Returns:
        """

        Extractor.__init__(self,
                           input_type=InputType.TEXT,
                           category="spacy_rule_extractor",
                           name=extractor_name)
        self.rules = rules["rules"]
        self.nlp = copy.deepcopy(nlp)
        self.tokenizer = Tokenizer(self.nlp)
        self.matcher = Matcher(self.nlp.vocab)
        self.field_name = rules["field_name"]
        self.rule_lst = []
        for a_rule in self.rules:
            this_rule = Rule(a_rule, self.nlp)
            self.rule_lst.append(this_rule)

    def extract(self, text: str) -> List[Extraction]:
        """
        Extract from text
        Args:
            text: str, the text to extract from

        Returns: List[Extraction]
        """
        doc = self.tokenizer.tokenize_to_spacy_doc(text)
        self.load_matcher()
        # TODO: add callback functions to filter custom constraints:
        # 1. prefix, suffix
        # 2. min, max
        # 3. full shape
        # 4. is in output
        # 5. filter overlap
        for idx, start, end in self.matcher(doc):
            print(idx, doc[start:end])  # debug output only
        return []  # placeholder: this early version does not build Extractions yet

    def load_matcher(self) -> None:
        """
        Add the constructed spaCy rules to the Matcher, one entry per
        combination of token-level pattern alternatives.
        """
        for idx, a_rule in enumerate(self.rule_lst):
            pattern_flat_lst = [a_pattern.spacy_token_lst for a_pattern in a_rule.patterns]
            # each element of the cartesian product is one concrete spaCy pattern
            for element in itertools.product(*pattern_flat_lst):
                self.matcher.add(idx, None, list(element))
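
# Usage sketch (illustrative only; the rule dict schema is defined by the Rule
# class, which is not shown in this listing):
#
#     import spacy
#
#     nlp = spacy.load("en_core_web_sm")
#     rules = {"field_name": "test_field", "rules": [...]}  # schema assumed
#     extractor = SpacyRuleExtractor(nlp, rules, "sample_extractor")
#     extractor.extract("some input text")  # this version only prints the matched spans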
Example #2
# This example assumes the same imports as the snippet above.
class SpacyRuleExtractor(Extractor):
    def __init__(self, nlp, rules: Dict, extractor_name: str) -> None:
        """
        Initialize the extractor, storing the rule information and construct spacy rules
        Args:
            nlp
            rules (Dict): spacy rules
            extractor_name: str

        Returns:
        """

        Extractor.__init__(self,
                           input_type=InputType.TEXT,
                           category="spacy_rule_extractor",
                           name=extractor_name)
        self.rules = rules["rules"]
        self.nlp = copy.deepcopy(nlp)
        self.tokenizer = Tokenizer(self.nlp)
        self.matcher = Matcher(self.nlp.vocab)
        self.field_name = rules[
            "field_name"] if "field_name" in rules else extractor_name
        self.rule_lst = {}
        self.hash_map = {}
        for idx, a_rule in enumerate(self.rules):
            this_rule = Rule(a_rule, self.nlp)
            self.rule_lst[this_rule.identifier + "rule_id##" +
                          str(idx)] = this_rule

    def extract(self, text: str) -> List[Extraction]:
        """
        Extract from text
        Args:
            text: str

        Returns: List[Extraction]
        """

        doc = self.tokenizer.tokenize_to_spacy_doc(text)
        self.load_matcher()

        # Keep only non-empty matches; the filtered survivors are split below
        # by rule polarity (positive vs. negative rules).
        matches = [x for x in self.matcher(doc) if x[1] != x[2]]
        pos_filtered_matches = []
        neg_filtered_matches = []
        for idx, start, end in matches:
            span_doc = self.tokenizer.tokenize_to_spacy_doc(
                doc[start:end].text)
            this_spacy_rule = self.matcher.get(idx)
            relations = self.find_relation(span_doc, this_spacy_rule)
            rule_id, _ = self.hash_map[idx]
            this_rule = self.rule_lst[rule_id]
            if self.filter_match(doc[start:end], relations,
                                 this_rule.patterns):
                value = self.form_output(doc[start:end],
                                         this_rule.output_format, relations,
                                         this_rule.patterns)
                if this_rule.polarity:
                    pos_filtered_matches.append(
                        (start, end, value, rule_id, relations))
                else:
                    neg_filtered_matches.append(
                        (start, end, value, rule_id, relations))

        # Resolve overlaps: keep the longest positive matches, then drop any
        # that overlap a longest negative match.
        return_lst = []
        if pos_filtered_matches:
            longest_lst_pos = self.get_longest(pos_filtered_matches)
            if neg_filtered_matches:
                longest_lst_neg = self.get_longest(neg_filtered_matches)
                return_lst = self.reject_neg(longest_lst_pos, longest_lst_neg)
            else:
                return_lst = longest_lst_pos

        # Convert the surviving matches into Extraction objects with both token
        # and character offsets.
        extractions = []
        for (start, end, value, rule_id, relation) in return_lst:
            this_extraction = Extraction(value=value,
                                         extractor_name=self.name,
                                         start_token=start,
                                         end_token=end,
                                         start_char=doc[start].idx,
                                         end_char=doc[end - 1].idx +
                                         len(doc[end - 1]),
                                         rule_id=rule_id.split("rule_id##")[0],
                                         match_mapping=relation)
            extractions.append(this_extraction)

        return extractions

    def load_matcher(self) -> None:
        """
        Add constructed spacy rule to Matcher
        """
        for id_key in self.rule_lst:
            if self.rule_lst[id_key].active:
                pattern_lst = [
                    a_pattern.spacy_token_lst
                    for a_pattern in self.rule_lst[id_key].patterns
                ]

                for spacy_rule_id, spacy_rule in enumerate(
                        itertools.product(*pattern_lst)):
                    self.matcher.add(self.construct_key(id_key, spacy_rule_id),
                                     None, list(spacy_rule))
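
    # For example (illustrative, not in the original source): a rule whose two
    # patterns have 2 and 3 spaCy token alternatives yields 2 * 3 = 6 Matcher
    # entries, one per combination.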

    def filter_match(self, span: Span, relations: Dict,
                     patterns: List) -> bool:
        """
        Filter a match according to each pattern's constraints (prefix, suffix, min, max, full shape)
        Args:
            span: Span
            relations: Dict
            patterns: List of patterns

        Returns: bool
        """

        for pattern_id, a_pattern in enumerate(patterns):
            token_range = relations[pattern_id]
            if token_range:
                tokens = list(span[token_range[0]:token_range[1]])
                if a_pattern.type == "word":
                    if not self.pre_suf_fix_filter(tokens, a_pattern.prefix,
                                                   a_pattern.suffix):
                        return False
                if a_pattern.type == "shape":
                    if not (self.full_shape_filter(tokens,
                                                   a_pattern.full_shape)
                            and self.pre_suf_fix_filter(
                                tokens, a_pattern.prefix, a_pattern.suffix)):
                        return False
                if a_pattern.type == "number":
                    if not self.min_max_filter(tokens, a_pattern.min,
                                               a_pattern.max):
                        return False
        return True

    @staticmethod
    def get_longest(value_lst: List) -> List:
        """
        Get the longest match for overlap
        Args:
            value_lst: List

        Returns: List
        """

        value_lst.sort()
        result = []
        pivot = value_lst[0]
        start, end = pivot[0], pivot[1]
        pivot_e = end
        pivot_s = start
        for idx, (s, e, v, rule_id, _) in enumerate(value_lst):
            if s == pivot_s and pivot_e < e:
                pivot_e = e
                pivot = value_lst[idx]
            elif s != pivot_s and pivot_e < e:
                result.append(pivot)
                pivot = value_lst[idx]
                pivot_e = e
                pivot_s = s
        result.append(pivot)
        return result
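
    # For example (illustrative, not in the original source): given matches with
    # (start, end) spans (0, 2), (0, 4), (1, 3), and (3, 5), get_longest keeps
    # (0, 4) (the longest span starting at 0) and (3, 5), and drops (1, 3)
    # because it ends inside the current pivot.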

    @staticmethod
    def reject_neg(pos_lst: List, neg_lst: List) -> List:
        """
        Reject some positive matches according to negative matches
        Args:
            pos_lst: List
            neg_lst: List

        Returns: List
        """

        pos_lst.sort()
        neg_lst.sort()
        result = []
        pivot_pos = pos_lst[0]
        pivot_neg = neg_lst[0]
        while pos_lst:
            if pivot_pos[1] <= pivot_neg[0]:
                result.append(pivot_pos)
                pos_lst.pop(0)
                if pos_lst:
                    pivot_pos = pos_lst[0]
            elif pivot_pos[0] >= pivot_neg[1]:
                neg_lst.pop(0)
                if not neg_lst:
                    result += pos_lst
                    break
                else:
                    pivot_neg = neg_lst[0]
            else:
                pos_lst.pop(0)
                if pos_lst:
                    pivot_pos = pos_lst[0]
        return result
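
    # For example (illustrative, not in the original source): with positive spans
    # (0, 2), (3, 5), (6, 8) and negative span (4, 6), the overlapping positive
    # match (3, 5) is rejected and [(0, 2), (6, 8)] is returned.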

    @staticmethod
    def pre_suf_fix_filter(t: List, prefix: str, suffix: str) -> bool:
        """
        Prefix and Suffix filter
        Args:
            t: List, list of tokens
            prefix: str
            suffix: str

        Returns: bool
        """

        if prefix:
            for a_token in t:
                if a_token._.n_prefix(len(prefix)) != prefix:
                    return False
        if suffix:
            for a_token in t:
                if a_token._.n_suffix(len(suffix)) != suffix:
                    return False

        return True
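
    # Note: the token attributes n_prefix, n_suffix (above) and full_shape (in
    # full_shape_filter below) are custom spaCy extension attributes, assumed to
    # be registered elsewhere in the project (e.g. by the Tokenizer).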

    @staticmethod
    def min_max_filter(t: List, min_v: str, max_v: str) -> bool:
        """
        Min and Max filter
        Args:
            t: List, list of tokens
            min_v: str
            max_v: str

        Returns: bool
        """
        def tofloat(value):
            # Return the parsed float, or None if value is not numeric.
            # (Returning False here, as the original did, would also reject a
            # legitimate value of 0.)
            try:
                return float(value)
            except ValueError:
                return None

        for a_token in t:
            this_v = tofloat(a_token.text)
            if this_v is None:
                return False
            min_f = tofloat(min_v) if min_v else None
            max_f = tofloat(max_v) if max_v else None
            if min_f is not None and this_v < min_f:
                return False
            if max_f is not None and this_v > max_f:
                return False

        return True
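
    # For example (illustrative, not in the original source): with min_v = "10"
    # and max_v = "20", the token texts ["12", "15"] pass, while ["9"] fails the
    # minimum and ["abc"] fails because it is not numeric.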

    @staticmethod
    def full_shape_filter(t: List, shapes: List) -> bool:
        """
        Shape filter
        Args:
            t: List, list of tokens
            shapes: List

        Returns: bool
        """

        if shapes:
            for a_token in t:
                if a_token._.full_shape not in shapes:
                    return False

        return True

    @staticmethod
    def form_output(span_doc: Span, output_format: str, relations: Dict,
                    patterns: List) -> str:
        """
        Form an output value according to the user-specified output_format
        Args:
            span_doc: Span
            output_format: str, e.g. "{1}, {2}" to join the first two captured values
            relations: Dict
            patterns: List

        Returns: str
        """

        format_value = []
        output_inf = [a_pattern.in_output for a_pattern in patterns]
        for i in range(len(output_inf)):
            token_range = relations[i]
            if token_range and output_inf[i]:
                format_value.append(
                    span_doc[token_range[0]:token_range[1]].text)

        if not output_format:
            return " ".join(format_value)

        # Scan output_format with a three-character window (t1, t2, t3) and
        # substitute each single-digit placeholder "{N}" with the N-th captured value.
        if len(output_format) < 3:
            # too short to contain a "{N}" placeholder; the original code would
            # raise IndexError here
            return output_format
        result_str = ""
        s = list(output_format)
        t1 = s.pop(0)
        t2 = s.pop(0)
        while True:
            t3 = s.pop(0)
            if t1 == '{' and t2.isdigit() and t3 == '}':
                if int(t2) > len(format_value):
                    # placeholder index out of range: leave the rest verbatim
                    return result_str + t1 + t2 + t3 + "".join(s)
                result_str += format_value[int(t2) - 1]
                if not s:
                    break
                t1 = s.pop(0)
                if not s:
                    result_str += t1
                    break
                t2 = s.pop(0)
                if not s:
                    result_str += t2
                    break
            else:
                result_str += t1
                t1 = t2
                t2 = t3
                if not s:
                    result_str += t1
                    result_str += t2
                    break
        return result_str
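
    # For example (illustrative, not in the original source): with captured values
    # ["New York", "NY"] and output_format "{1}, {2}", form_output returns
    # "New York, NY"; a placeholder whose index exceeds the number of captured
    # values (e.g. "{3}") is left in the output verbatim.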

    def construct_key(self, rule_id: str, spacy_rule_id: int) -> int:
        """
        Build the hash key under which self.hash_map records which rule produced each match
        Args:
            rule_id: str
            spacy_rule_id: int

        Returns: int
        """

        hash_key = (rule_id, spacy_rule_id)
        hash_v = hash(hash_key) + sys.maxsize + 1
        self.hash_map[hash_v] = hash_key
        return hash_v
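
    # The offset by sys.maxsize + 1 keeps the Matcher key non-negative; the
    # original (rule_id, spacy_rule_id) pair is recovered via self.hash_map.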

    def find_relation(self, span_doc: Doc, r: List) -> Dict:
        """
        Get the relation between each pattern in the spaCy rule and the matched tokens
        Args:
            span_doc: Doc
            r: List, the Matcher entry for this rule: (on_match callback, patterns)

        Returns: Dict
        """

        # Walk the pattern's token elements left to right: an element without
        # "OP" consumes exactly one token; a variable-length element (with "OP")
        # is measured by re-matching it, and the element after it, against the
        # remaining tokens with a temporary Matcher.
        rule = r[1][0]
        span_pivot = 0
        relation = {}
        for e_id, element in enumerate(rule):
            if not span_doc[span_pivot:]:
                for extra_id, _ in enumerate(rule[e_id:]):
                    relation[e_id + extra_id] = None
                break
            new_doc = self.tokenizer.tokenize_to_spacy_doc(
                span_doc[span_pivot:].text)
            if "OP" not in element:
                relation[e_id] = (span_pivot, span_pivot + 1)
                span_pivot += 1
            else:
                if e_id < len(rule) - 1:
                    tmp_rule_1 = [rule[e_id]]
                    tmp_rule_2 = [rule[e_id + 1]]
                    tmp_matcher = Matcher(self.nlp.vocab)
                    tmp_matcher.add(0, None, tmp_rule_1)
                    tmp_matcher.add(1, None, tmp_rule_2)
                    tmp_matches = sorted(
                        [x for x in tmp_matcher(new_doc) if x[1] != x[2]],
                        key=lambda a: a[1])

                    if not tmp_matches:
                        relation[e_id] = None
                    else:
                        matches_1 = [
                            x for x in tmp_matches if x[0] == 0 and x[1] == 0
                        ]
                        if not matches_1:
                            relation[e_id] = None
                        else:
                            _, s1, e1 = matches_1[0]
                            matches_2 = [x for x in tmp_matches if x[0] == 1]
                            if not matches_2:
                                relation[e_id] = (span_pivot, span_pivot + e1)
                                span_pivot += e1
                            else:
                                _, s2, e2 = matches_2[0]
                                if e1 <= s2:
                                    relation[e_id] = (span_pivot,
                                                      span_pivot + e1)
                                    span_pivot += e1
                                else:
                                    relation[e_id] = (span_pivot,
                                                      span_pivot + s2)
                                    span_pivot += s2
                else:
                    relation[e_id] = (span_pivot, len(span_doc))

        return relation
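
# Usage sketch (illustrative only): the exact rule schema is defined by the Rule
# class, which is not shown in this listing, so the dict below is an assumption.
#
#     import spacy
#
#     nlp = spacy.load("en_core_web_sm")
#     rules = {
#         "field_name": "test_field",
#         "rules": [
#             # one dict per rule, in whatever format Rule() expects
#         ],
#     }
#     extractor = SpacyRuleExtractor(nlp, rules, "sample_extractor")
#     for extraction in extractor.extract("some input text"):
#         print(extraction.value)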