def predict(self, sample: InputSample) -> List[str]:
    """
    Predict the tags of a sample using the analyzer engine.

    :param sample: InputSample providing ``full_text`` and ``tokens``
    :return: list of tags built by ``span_to_tag`` under ``self.labeling_scheme``
    """
    detections = self.analyzer_engine.analyze(
        text=sample.full_text,
        entities=self.entities,
        language="en",
        score_threshold=self.score_threshold,
    )

    # Walk the detections once, then split the fields into the parallel
    # lists that span_to_tag takes.
    fields = [
        (det.start, det.end, det.entity_type, det.score) for det in detections
    ]
    span_starts = [item[0] for item in fields]
    span_ends = [item[1] for item in fields]
    span_labels = [item[2] for item in fields]
    span_scores = [item[3] for item in fields]

    return span_to_tag(
        scheme=self.labeling_scheme,
        text=sample.full_text,
        start=span_starts,
        end=span_ends,
        tokens=sample.tokens,
        scores=span_scores,
        tag=span_labels,
    )
def predict(self, sample: InputSample) -> List[str]:
    """
    Predict the tags using a Flair model.

    :param sample: InputSample with text; ``sample.tokens`` is filled in
        with ``tokenize(sample.full_text)`` when it is empty or unset
    :return: list of tags (label per token)
    """
    sentence = Sentence(text=sample.full_text, use_tokenizer=self.spacy_tokenizer)
    self.model.predict(sentence)

    # Flair tokens might not be consistent with spaCy's tokens (even when
    # using the spaCy tokenizer), so the shared tokenize() helper is used to
    # stay consistent with the other models.
    # This is done before branching: the no-entity branch below needs the
    # tokens as well, and would otherwise fail on len() of missing tokens.
    if not sample.tokens:
        sample.tokens = tokenize(sample.full_text)

    ents = sentence.get_spans("ner")
    if ents:
        labels, starts, ends = zip(
            *[(ent.tag, ent.start_pos, ent.end_pos) for ent in ents]
        )

        # Flair's tag for PERSON is PER
        labels = [label if label != "PER" else "PERSON" for label in labels]

        # Create tags (label per token) based on Flair spans and the sample's tokens
        tags = span_to_tag(
            scheme="IO",
            text=sample.full_text,
            starts=starts,
            ends=ends,
            tags=labels,
            tokens=sample.tokens,
        )
    else:
        # Nothing detected: every token is outside any entity.
        tags = ["O"] * len(sample.tokens)

    if len(tags) != len(sample.tokens):
        print("mismatch between input tokens and new tokens")

    return tags
def predict(self, sample: InputSample) -> List[str]:
    """
    Predict the tags using a stanza model.

    :param sample: InputSample with text; ``sample.tokens`` is filled in
        with ``tokenize(sample.full_text)`` when it is empty or unset
    :return: list of tags
    """
    doc = self.model(sample.full_text)

    # Stanza tokens might not be consistent with spaCy's tokens.
    # Use the shared tokenize() helper and not stanza to maintain
    # consistency with other models.
    # This is done before branching: the no-entity branch below needs the
    # tokens as well, and would otherwise fail on len() of missing tokens.
    if not sample.tokens:
        sample.tokens = tokenize(sample.full_text)

    if doc.ents:
        labels, starts, ends = zip(
            *[(ent.label_, ent.start_char, ent.end_char) for ent in doc.ents]
        )

        # Create tags (label per token) based on stanza spans and the sample's tokens
        tags = span_to_tag(
            scheme=self.labeling_scheme,
            text=sample.full_text,
            starts=starts,
            ends=ends,
            tags=labels,
            tokens=sample.tokens,
        )
    else:
        # Nothing detected: every token is outside any entity.
        tags = ["O"] * len(sample.tokens)

    if len(tags) != len(sample.tokens):
        print("mismatch between input tokens and new tokens")

    return tags
def predict(self, sample: InputSample) -> List[str]:
    """
    Predict the tags of a sample using ``self.analyzer``.

    When no entities are configured, the analyzer is called with
    ``all_fields=True``; otherwise ``all_fields`` is passed as ``None``.

    :param sample: InputSample providing ``full_text`` and ``tokens``
    :return: list of tags built by ``span_to_tag`` under ``self.labeling_scheme``
    """
    no_entities_requested = self.entities is None or len(self.entities) == 0
    all_fields = True if no_entities_requested else None

    detections = self.analyzer.analyze(
        sample.full_text, self.entities, language='en', all_fields=all_fields
    )

    # Every returned result is kept; no filtering on self.score_threshold
    # is applied in this method.
    fields = [
        (det.start, det.end, det.entity_type, det.score) for det in detections
    ]
    span_starts = [item[0] for item in fields]
    span_ends = [item[1] for item in fields]
    span_labels = [item[2] for item in fields]
    span_scores = [item[3] for item in fields]

    return span_to_tag(
        scheme=self.labeling_scheme,
        text=sample.full_text,
        start=span_starts,
        end=span_ends,
        tokens=sample.tokens,
        scores=span_scores,
        tag=span_labels,
    )
def test_span_to_biluo_adjecent_identical_entities():
    """A PERSON span over "Jessica Gump" becomes B-PERSON, L-PERSON in BILUO."""
    sentence = "May I get access to Jessica Gump's account?"
    span_start, span_end = 20, 32  # character offsets of "Jessica Gump"
    labels = ["PERSON"]

    actual = span_to_tag(BILUO_SCHEME, sentence, [span_start], [span_end], labels)

    assert actual == [
        'O', 'O', 'O', 'O', 'O',
        'B-PERSON', 'L-PERSON',
        'O', 'O', 'O',
    ]
def test_overlapping_entities_pyramid():
    """Three nested spans with rising scores: the innermost tag wins per token."""
    sentence = "My new phone number is 1 705 999 774 8720. Thanks, cya"
    # Outer span A1 (23-41) contains B2 (25-36), which contains C3 (29-32).
    span_starts = [23, 25, 29]
    span_ends = [41, 36, 32]
    span_scores = [0.6, 0.7, 0.8]
    labels = ["A1", "B2", "C3"]

    actual = span_to_tag(
        scheme=IO_SCHEME,
        text=sentence,
        start=span_starts,
        end=span_ends,
        tag=labels,
        scores=span_scores,
    )

    assert actual == [
        'O', 'O', 'O', 'O', 'O',
        'A1', 'B2', 'C3', 'B2', 'A1',
        'O', 'O', 'O', 'O',
    ]
def test_span_to_bilou_specific_input():
    """A PERSON span at the very end of a two-sentence text gets B-/L- tags."""
    sentence = (
        "Someone stole my credit card. The number is 5277716201469117 and "
        "the my name is Mary Anguiano"
    )
    span_start, span_end = 80, 93  # character offsets of "Mary Anguiano"
    labels = ["PERSON"]

    actual = span_to_tag(BILOU_SCHEME, sentence, [span_start], [span_end], labels)

    assert actual == [
        'O', 'O', 'O', 'O', 'O',
        'O', 'O', 'O', 'O', 'O',
        'O', 'O', 'O', 'O', 'O',
        'B-PERSON', 'L-PERSON',
    ]
def get_tags(self, scheme="IOB"):
    """
    Tokenize ``self.full_text`` and label each token from ``self.spans``.

    :param scheme: labeling scheme passed through to ``span_to_tag``
    :return: tuple of (tokens, labels)
    """
    span_starts = []
    span_ends = []
    span_types = []
    for span in self.spans:
        span_starts.append(span.start_position)
        span_ends.append(span.end_position)
        span_types.append(span.entity_type)

    tokens = tokenize(self.full_text)

    labels = span_to_tag(
        scheme=scheme,
        text=self.full_text,
        tag=span_types,
        start=span_starts,
        end=span_ends,
        tokens=tokens,
    )
    return tokens, labels
def test_overlapping_entities_second_embedded_in_first_has_higher_score():
    """An inner span with a higher score overrides the outer tag on its token."""
    sentence = "My new phone number is 1 705 774 8720. Thanks, man"
    # PHONE_NUMBER covers 23-37; US_PHONE_NUMBER covers only 25-28 inside it.
    span_starts = [23, 25]
    span_ends = [37, 28]
    span_scores = [0.6, 0.7]
    labels = ["PHONE_NUMBER", "US_PHONE_NUMBER"]

    actual = span_to_tag(
        scheme=IO_SCHEME,
        text=sentence,
        start=span_starts,
        end=span_ends,
        tag=labels,
        scores=span_scores,
    )

    assert actual == [
        'O', 'O', 'O', 'O', 'O',
        'PHONE_NUMBER', 'US_PHONE_NUMBER', 'PHONE_NUMBER', 'PHONE_NUMBER',
        'O', 'O', 'O', 'O',
    ]
def test_overlapping_entities_first_ends_in_mid_second():
    """Two equally scored spans sharing an end offset but starting apart."""
    sentence = "My new phone number is 1 705 774 8720. Thanks, man"
    span_starts = [22, 25]
    span_ends = [37, 37]
    span_scores = [0.6, 0.6]
    labels = ["PHONE_NUMBER", "US_PHONE_NUMBER"]

    actual = span_to_tag(
        IO_SCHEME, sentence, span_starts, span_ends, labels, span_scores
    )

    assert actual == [
        'O', 'O', 'O', 'O', 'O',
        'PHONE_NUMBER', 'US_PHONE_NUMBER', 'US_PHONE_NUMBER', 'US_PHONE_NUMBER',
        'O', 'O', 'O', 'O',
    ]
def test_span_to_biluo_single_at_end():
    """A single-token span ending the text is tagged as a unit (U-) in BILUO."""
    sentence = "My name is Josh"
    span_start, span_end = 11, 15  # character offsets of "Josh"
    label = "NAME"

    actual = span_to_tag(BILUO_SCHEME, sentence, [span_start], [span_end], [label])
    print(actual)

    assert actual == ['O', 'O', 'O', 'U-NAME']
def test_span_to_biluo_multiple_tokens():
    """A span over several tokens yields B-, a run of I-, then L- in BILUO."""
    sentence = "My Address is 409 Bob st. Manhattan NY. I just moved in"
    span_start, span_end = 14, 38  # character offsets of "409 Bob st. Manhattan NY"
    label = "ADDRESS"

    actual = span_to_tag(BILUO_SCHEME, sentence, [span_start], [span_end], [label])

    assert actual == [
        'O', 'O', 'O',
        'B-ADDRESS',
        'I-ADDRESS', 'I-ADDRESS', 'I-ADDRESS', 'I-ADDRESS',
        'L-ADDRESS',
        'O', 'O', 'O', 'O', 'O',
    ]
def get_tags(self, scheme="IOB", model_version="en_core_web_sm"):
    """
    Tokenize ``self.full_text`` and label each token from ``self.spans``.

    :param scheme: labeling scheme passed through to ``span_to_tag``
    :param model_version: passed through to ``tokenize`` as its second argument
    :return: tuple of (tokens, labels)
    """
    span_starts = []
    span_ends = []
    span_types = []
    for span in self.spans:
        span_starts.append(span.start_position)
        span_ends.append(span.end_position)
        span_types.append(span.entity_type)

    tokens = tokenize(self.full_text, model_version)

    labels = span_to_tag(
        scheme=scheme,
        text=self.full_text,
        tags=span_types,
        starts=span_starts,
        ends=span_ends,
        tokens=tokens,
    )
    return tokens, labels
def test_token_contains_span():
    """A token that fully contains a shorter span still receives the span's tag."""
    # The last token here (https://www.gmail.com/) contains the span
    # (www.gmail.com). In this case the token should be tagged with the span
    # tag, even though the span does not cover all of it.
    sentence = "My website is https://www.gmail.com/"
    span_starts, span_ends = [22], [35]
    span_scores = [1.0]
    labels = ["DOMAIN_NAME"]

    actual = span_to_tag(
        scheme=IO_SCHEME,
        text=sentence,
        starts=span_starts,
        ends=span_ends,
        tags=labels,
        scores=span_scores,
    )

    assert actual == ["O", "O", "O", "DOMAIN_NAME"]
def test_span_to_biluo_multiple_entities():
    """Two separate single-token NAME spans are each tagged U-NAME in BILUO."""
    sentence = "My name is Josh or David"
    # (start, end) character offsets for "Josh" and "David".
    # NOTE: the second end offset (26) lies past the end of the 24-character
    # text; it is kept exactly as the test was originally written.
    first_span = (11, 15)
    second_span = (19, 26)
    span_starts = [first_span[0], second_span[0]]
    span_ends = [first_span[1], second_span[1]]
    labels = ["NAME", "NAME"]

    actual = span_to_tag(BILUO_SCHEME, sentence, span_starts, span_ends, labels)
    print(actual)

    assert actual == ['O', 'O', 'O', 'U-NAME', 'O', 'U-NAME']
def test_span_to_biluo_adjacent_entities():
    """Two neighbouring spans with different tags each become their own U- tag."""
    sentence = "Mr. Tree"
    # (start, end) character offsets for the TITLE and NAME spans.
    title_span = (0, 2)
    name_span = (4, 8)
    span_starts = [title_span[0], name_span[0]]
    span_ends = [title_span[1], name_span[1]]
    labels = ["TITLE", "NAME"]

    actual = span_to_tag(BILUO_SCHEME, sentence, span_starts, span_ends, labels)
    print(actual)

    assert actual == ['U-TITLE', 'U-NAME']
def test_overlapping_entities_second_embedded_in_first_with_lower_score():
    """An inner span with a lower score does not override the outer span's tag."""
    sentence = "My new phone number is 1 705 774 8720. Thanks, man"
    # PHONE_NUMBER covers 22-37; the lower-scored US_PHONE_NUMBER covers 25-33.
    span_starts = [22, 25]
    span_ends = [37, 33]
    span_scores = [0.6, 0.5]
    labels = ["PHONE_NUMBER", "US_PHONE_NUMBER"]

    actual = span_to_tag(
        BIO_SCHEME,
        sentence,
        span_starts,
        span_ends,
        labels,
        span_scores,
        io_tags_only=True,
    )

    assert actual == [
        'O', 'O', 'O', 'O', 'O',
        'PHONE_NUMBER', 'PHONE_NUMBER', 'PHONE_NUMBER', 'PHONE_NUMBER',
        'O', 'O', 'O', 'O',
    ]
def test_span_to_bio_multiple_entities():
    """Two separate single-token NAME spans under BIO_SCHEME both yield I-NAME."""
    sentence = "My name is Josh or David"
    # (start, end) character offsets for "Josh" and "David".
    # NOTE: the second end offset (26) lies past the end of the 24-character
    # text; it is kept exactly as the test was originally written.
    first_span = (11, 15)
    second_span = (19, 26)
    span_starts = [first_span[0], second_span[0]]
    span_ends = [first_span[1], second_span[1]]
    labels = ["NAME", "NAME"]

    bio_tags = span_to_tag(
        scheme=BIO_SCHEME,
        text=sentence,
        start=span_starts,
        end=span_ends,
        tag=labels,
    )
    print(bio_tags)

    assert bio_tags == ['O', 'O', 'O', 'I-NAME', 'O', 'I-NAME']