def test_issue10685(en_tokenizer):
    """Test `SpanGroups` de/serialization"""
    # Begin with a Doc that carries no span groups at all.
    doc = en_tokenizer("Will it blend?")

    # Round-trip the empty container first.
    assert len(doc.spans) == 0
    doc.spans.from_bytes(doc.spans.to_bytes())
    assert len(doc.spans) == 0

    # Populate two groups; "test2" deliberately reuses the name "test"
    # so key and group name disagree.
    doc.spans["test"] = SpanGroup(doc, name="test", spans=[doc[0:1]])
    doc.spans["test2"] = SpanGroup(doc, name="test", spans=[doc[1:2]])

    def check_groups():
        assert len(doc.spans) == 2
        assert doc.spans["test"].name == "test"
        assert doc.spans["test2"].name == "test"
        assert list(doc.spans["test"]) == [doc[0:1]]
        assert list(doc.spans["test2"]) == [doc[1:2]]

    # Sanity-check the expected state before serializing.
    check_groups()

    # Round-trip the populated container and re-check everything.
    doc.spans.from_bytes(doc.spans.to_bytes())
    check_groups()
def doc(en_tokenizer):
    """Build a Doc whose "SPANS" group holds shuffled, overlapping matches."""
    doc = en_tokenizer("0 1 2 3 4 5 6")
    matcher = Matcher(en_tokenizer.vocab, validate=True)
    # Wildcard patterns of length 4, 2 and 1 yield many overlapping matches.
    # fmt: off
    matcher.add("4", [[{}, {}, {}, {}]])
    matcher.add("2", [[{}, {}, ]])
    matcher.add("1", [[{}, ]])
    # fmt: on
    spans = [
        Span(doc, start, end, en_tokenizer.vocab.strings[match_id])
        for match_id, start, end in matcher(doc)
    ]
    # Fixed seed keeps the shuffled ordering deterministic across runs.
    Random(42).shuffle(spans)
    doc.spans["SPANS"] = SpanGroup(
        doc, name="SPANS", attrs={"key": "value"}, spans=spans
    )
    return doc
def test_span_group_extend(doc):
    """`SpanGroup.extend` with another group vs. with a plain span list."""
    group_a = doc.spans["SPANS"].copy()
    extra_spans = [doc[0:5], doc[0:6]]
    group_b = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=extra_spans,
    )
    expected = group_a._concat(group_b)

    # Extending with a SpanGroup merges both the spans and the attrs.
    group_a.extend(group_b)
    assert len(group_a) == len(expected)
    assert group_a.attrs == {"key": "value", "new_key": "new_value"}
    assert list(group_a) == list(expected)

    # Extending with a bare list adds spans but leaves attrs untouched.
    group_a = doc.spans["SPANS"]
    group_a.extend(extra_spans)
    assert len(group_a) == len(expected)
    assert group_a.attrs == {"key": "value"}
    assert list(group_a) == list(expected)
def test_span_group_concat(doc, other_doc):
    """`SpanGroup._concat`: copying, in-place, and cross-doc rejection."""
    base_group = doc.spans["SPANS"]
    extra_spans = [doc[0:5], doc[0:6]]
    other_group = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=extra_spans,
    )

    # Default concat returns a new group named after the left operand,
    # with attrs merged and spans appended in order.
    combined = base_group._concat(other_group)
    assert combined.name == base_group.name
    assert combined.attrs == {"key": "value", "new_key": "new_value"}
    expected_spans = list(base_group) + list(other_group)
    assert list(combined) == list(expected_spans)

    # Inplace concat mutates and returns the left operand itself.
    expected_spans = list(base_group) + list(other_group)
    combined = base_group._concat(other_group, inplace=True)
    assert combined == base_group
    assert combined.name == base_group.name
    assert combined.attrs == {"key": "value", "new_key": "new_value"}
    assert list(combined) == list(expected_spans)

    # Concatenating groups from different docs is an error.
    foreign_group = other_doc.spans["SPANS"]
    with pytest.raises(ValueError):
        base_group._concat(foreign_group)
def test_doc_spans_setdefault(en_tokenizer):
    """`doc.spans.setdefault` accepts no default, a span list, or a SpanGroup."""
    doc = en_tokenizer("Some text about Colombia and the Czech Republic")

    # No default: an empty group is created under the key.
    doc.spans.setdefault("key1")
    assert len(doc.spans["key1"]) == 0

    # Default given as a plain list of spans.
    doc.spans.setdefault("key2", default=[doc[0:1]])
    assert len(doc.spans["key2"]) == 1

    # Default given as a ready-made SpanGroup.
    doc.spans.setdefault(
        "key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]])
    )
    assert len(doc.spans["key3"]) == 2
def test_span_groups_serialization_mismatches(en_tokenizer):
    """Test the serialization of multiple mismatching `SpanGroups` keys and `SpanGroup.name`s"""
    doc = en_tokenizer("How now, brown cow?")
    groups = doc.spans

    # 1 SpanGroup whose name matches its key:
    groups["key1"] = SpanGroup(doc, name="key1", spans=[doc[0:1], doc[1:2]])
    # 2 SpanGroups sharing a name that is not a key:
    groups["key2"] = SpanGroup(doc, name="too", spans=[doc[3:4], doc[4:5]])
    groups["key3"] = SpanGroup(doc, name="too", spans=[doc[1:2], doc[0:1]])
    # 2 SpanGroups sharing a name that *is* a key:
    groups["key4"] = SpanGroup(doc, name="key4", spans=[doc[0:1]])
    groups["key5"] = SpanGroup(doc, name="key4", spans=[doc[0:1]])
    # 1 SpanGroup stored under 2 keys, its name being one of those keys
    # (the same object on purpose, to exercise shared-value handling):
    shared_named = SpanGroup(doc, name="key6", spans=[doc[0:1]])
    groups["key6"] = shared_named
    groups["key7"] = shared_named
    # 1 SpanGroup stored under 2 keys, its name not being a key:
    shared_other = SpanGroup(doc, name="also", spans=[doc[1:2]])
    groups["key8"] = shared_other
    groups["key9"] = shared_other

    regroups = SpanGroups(doc).from_bytes(groups.to_bytes())

    # The reloaded container must mirror the original key-by-key.
    assert regroups.keys() == groups.keys()
    for key, regroup in regroups.items():
        assert regroup.name == groups[key].name
        assert list(regroup) == list(groups[key])
def test_span_groups_serialization(en_tokenizer):
    """Round-trip a `SpanGroups` container through to_bytes/from_bytes.

    Covers a group stored under two keys and an empty group under a third.
    """
    doc = en_tokenizer("0 1 2 3 4 5 6")
    span_groups = SpanGroups(doc)
    spans = [doc[0:2], doc[1:3]]
    sg1 = SpanGroup(doc, spans=spans)
    span_groups["key1"] = sg1
    span_groups["key2"] = sg1
    span_groups["key3"] = []
    reloaded_span_groups = SpanGroups(doc).from_bytes(span_groups.to_bytes())
    assert span_groups.keys() == reloaded_span_groups.keys()
    for key, value in span_groups.items():
        # Compare as full lists: the previous zip()-based check silently
        # truncated at the shorter sequence, so a reloaded group that lost
        # spans would still have passed. List equality checks length too.
        assert list(value) == list(reloaded_span_groups[key])
def test_span_group_add(doc):
    """`SpanGroup.__add__` behaves like a non-inplace `_concat`."""
    left_group = doc.spans["SPANS"]
    extra_spans = [doc[0:5], doc[0:6]]
    right_group = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=extra_spans,
    )
    # The + operator must match the result of _concat on the same operands.
    expected = left_group._concat(right_group)
    summed = left_group + right_group
    assert len(summed) == len(expected)
    assert summed.attrs == {"key": "value", "new_key": "new_value"}
    assert list(summed) == list(expected)