def test_unnest(self):
  """Unnesting yields one ContextualMention per mention, in input order."""

  def build_mentions():
    # Return fresh Mention objects on every call so that the nested input and
    # the expectation never share instances.
    return [
        schema.Mention(
            example_id="12fe",
            mention_span=schema.TextSpan(start=12, end=16, text="here"),
            entity_id="Q1"),
        schema.Mention(
            example_id="30ba",
            mention_span=schema.TextSpan(start=30, end=35, text="World"),
            entity_id="Q1"),
    ]

  nested = schema.ContextualMentions(
      context=self._CONTEXT_A, mentions=build_mentions())
  nested.validate()

  expected = [
      schema.ContextualMention(context=self._CONTEXT_A, mention=mention)
      for mention in build_mentions()
  ]
  unnested = list(nested.unnest_to_single_mention_per_context())
  self.assertEqual(unnested, expected)
def test_sentences(self):
  """TEST_ENTITY.sentences yields one TextSpan per expected sentence."""
  expected_sentences = [
      schema.TextSpan(start=0, end=12, text="Large place."),
      schema.TextSpan(start=13, end=35, text="Inhabited by everyone."),
  ]
  # Materialize the sentences so they can be compared against a plain list.
  actual_sentences = list(self.TEST_ENTITY.sentences)
  self.assertEqual(expected_sentences, actual_sentences)
def test_simple_multibyte_single(self):
  """A single mention of a multibyte character validates and round-trips."""
  cactus_mention = schema.Mention(
      example_id="12fe",
      mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
      entity_id="Q1")
  original = schema.ContextualMention(
      context=self._CONTEXT_B, mention=cactus_mention)
  original.validate()

  self.assertEqual(
      list(original.context.sentences),
      [
          schema.TextSpan(0, 14, "We all live 🌵."),
          schema.TextSpan(15, 33, "Here in the World."),
      ])

  # Serialize to a JSON string and parse it back into a ContextualMention.
  serialized = json.dumps(original.to_json())
  restored = schema.ContextualMention.from_json(json.loads(serialized))
  self.assertEqual(original, restored)

  # Keep only every other character of the text, which invalidates the
  # sentence spans of the copy.
  corrupted = copy.deepcopy(original)
  corrupted.context.text = corrupted.context.text[::2]
  with self.assertRaises(ValueError):
    corrupted.validate()
def test_round_trip(self):
  """ContextualMentions validates, exposes sentences and round-trips JSON."""
  here_mention = schema.Mention(
      example_id="12fe",
      mention_span=schema.TextSpan(start=12, end=16, text="here"),
      entity_id="Q1")
  world_mention = schema.Mention(
      example_id="30ba",
      mention_span=schema.TextSpan(start=30, end=35, text="World"),
      entity_id="Q1",
      metadata={"bin_name": "bin03_100-1000"})
  original = schema.ContextualMentions(
      context=self._CONTEXT_A, mentions=[here_mention, world_mention])
  original.validate()

  self.assertEqual(
      list(original.context.sentences),
      [
          schema.TextSpan(0, 17, "We all live here."),
          schema.TextSpan(18, 36, "Here in the World."),
      ])

  # Serialize to a JSON string and parse it back into ContextualMentions.
  serialized = json.dumps(original.to_json())
  restored = schema.ContextualMentions.from_json(json.loads(serialized))
  self.assertEqual(original, restored)

  # Keep only every other character of the text, which invalidates the
  # sentence spans of the copy.
  corrupted = copy.deepcopy(original)
  corrupted.context.text = corrupted.context.text[::2]
  with self.assertRaises(ValueError):
    corrupted.validate()
def test_truncate_skips_mention_that_crosses_boundary(self):
  """truncate() returns None when the mention spans two sentences."""
  # (0,6) covers the first two sentences (0,2), (3,6).
  straddling_span = schema.TextSpan(start=0, end=6, text="We all")
  straddling = schema.ContextualMention(
      context=self._CONTEXT_C,
      mention=schema.Mention(
          mention_span=straddling_span, entity_id="Qx", example_id="82e3"))
  result = straddling.truncate(window_size=1)
  self.assertIsNone(result)
def test_round_trip_including_section(self):
  """JSON round trip preserves a context that carries a section title.

  Uses _CONTEXT_B, which is constructed with section_title="Intro", so that
  the section is actually part of the serialized payload. The mention offsets
  refer to the _CONTEXT_B text "We all live 🌵. Here in the World.".
  """
  context_mentions = schema.ContextualMentions(
      context=self._CONTEXT_B,
      mentions=[
          schema.Mention(
              example_id="12fe",
              mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
              entity_id="Q1"),
          schema.Mention(
              example_id="30ba",
              mention_span=schema.TextSpan(start=27, end=32, text="World"),
              entity_id="Q1",
              metadata={"bin_name": "bin03_100-1000"},
          ),
      ])
  context_mentions.validate()

  json_string = json.dumps(context_mentions.to_json())
  got = schema.ContextualMentions.from_json(json.loads(json_string))
  self.assertEqual(context_mentions, got)
  # Check the section explicitly so that the test does not rely solely on the
  # equality operator covering it.
  self.assertEqual(got.context.section_title, "Intro")
def test_simple_multibyte(self):
  """ContextualMentions with a multibyte mention validates and round-trips."""
  cactus_mention = schema.Mention(
      example_id="12fe",
      mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
      entity_id="Q1")
  original = schema.ContextualMentions(
      context=self._CONTEXT_B, mentions=[cactus_mention])
  original.validate()

  self.assertEqual(
      list(original.context.sentences),
      [
          schema.TextSpan(0, 14, "We all live 🌵."),
          schema.TextSpan(15, 33, "Here in the World."),
      ])

  # Serialize to a JSON string and parse it back into ContextualMentions.
  serialized = json.dumps(original.to_json())
  restored = schema.ContextualMentions.from_json(json.loads(serialized))
  self.assertEqual(original, restored)
def test_round_trip(self):
  """A MentionEntityPair validates and survives a JSON round trip."""
  cactus_mention = schema.Mention(
      example_id="12fe",
      mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
      entity_id="Q1")
  pair = schema.MentionEntityPair(
      contextual_mention=schema.ContextualMention(
          context=ContextualMentionsTest._CONTEXT_B, mention=cactus_mention),
      entity=EntityTest.TEST_ENTITY)
  pair.validate()

  # Serialize to a JSON string and parse it back into a MentionEntityPair.
  serialized = json.dumps(pair.to_json())
  restored = schema.MentionEntityPair.from_json(json.loads(serialized))
  self.assertEqual(pair, restored)
def test_invalid_raises(self):
  """Building mentions from an inconsistent span raises ValueError."""
  # Every construction step stays inside the assertRaises block so that the
  # error is accepted regardless of which constructor reports it.
  with self.assertRaises(ValueError):
    # Start and end do not refer to the substring 'here'.
    inconsistent_span = schema.TextSpan(start=0, end=5, text="here")
    inconsistent_mention = schema.Mention(
        example_id="12fe", mention_span=inconsistent_span, entity_id="Q1")
    schema.ContextualMentions(
        context=self._CONTEXT_A, mentions=[inconsistent_mention])
def test_text_span_from(self):
  """from_elements() with offsets into a text equals the explicit TextSpan."""
  source_text = "Large place. Inhabited by everyone."
  built = schema.TextSpan.from_elements(start=13, end=35, context=source_text)
  wanted = schema.TextSpan(start=13, end=35, text="Inhabited by everyone.")
  self.assertEqual(wanted, built)
class ContextualMentionsTest(parameterized.TestCase):
  """Tests for schema.ContextualMentions, ContextualMention and Context.

  The class-level fixtures are shared by the tests below:
    _CONTEXT_A: two-sentence ASCII context without a section title.
    _CONTEXT_B: two-sentence context containing a multibyte character and a
      section title.
    _CONTEXT_C: same text as _CONTEXT_B, split into six short "sentences" so
      that truncation windows are easy to exercise.
  """

  _CONTEXT_A = schema.Context(
      document_title="Planet Earth",
      document_url="www.xyz.com",
      document_id="xyz-123",
      language="en",
      text="We all live here. Here in the World.",
      sentence_spans=(schema.Span(0, 17), schema.Span(18, 36)))

  _CONTEXT_B = schema.Context(
      document_title="Planet Earth",
      document_url="www.xyz.com",
      document_id="xyz-123",
      section_title="Intro",
      language="en",
      text="We all live 🌵. Here in the World.",
      sentence_spans=(schema.Span(0, 14), schema.Span(15, 33)))

  _CONTEXT_C = schema.Context(
      document_title="Planet Earth",
      document_url="www.xyz.com",
      document_id="xyz-123",
      section_title="Intro",
      language="en",
      text="We all live 🌵. Here in the World.",
      # For brevity, mimic sentences using phrases.
      sentence_spans=(
          schema.Span(0, 2),  # We
          schema.Span(3, 6),  # all
          schema.Span(7, 14),  # live 🌵.
          schema.Span(15, 19),  # Here
          schema.Span(20, 26),  # in the
          schema.Span(27, 33),  # World.
      ))

  # Two test mentions consistent with CONTEXT_C.
  _ALL_MENTION = schema.Mention(
      mention_span=schema.TextSpan(start=3, end=6, text="all"),
      entity_id="Qx",
      example_id="82e3")
  _WORLD_MENTION = schema.Mention(
      mention_span=schema.TextSpan(start=27, end=32, text="World"),
      entity_id="Q1",
      example_id="9024f")

  def test_round_trip(self):
    """ContextualMentions validates, exposes sentences, round-trips JSON."""
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_A,
        mentions=[
            schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=16, text="here"),
                entity_id="Q1"),
            schema.Mention(
                example_id="30ba",
                mention_span=schema.TextSpan(start=30, end=35, text="World"),
                entity_id="Q1",
                metadata={"bin_name": "bin03_100-1000"},
            ),
        ])
    context_mentions.validate()

    self.assertEqual(
        list(context_mentions.context.sentences), [
            schema.TextSpan(0, 17, "We all live here."),
            schema.TextSpan(18, 36, "Here in the World.")
        ])

    json_string = json.dumps(context_mentions.to_json())
    got = schema.ContextualMentions.from_json(json.loads(json_string))
    self.assertEqual(context_mentions, got)

    # Perturb text to invalidate the sentence spans.
    perturbed = copy.deepcopy(context_mentions)
    perturbed.context.text = perturbed.context.text[::2]
    with self.assertRaises(ValueError):
      perturbed.validate()

  def test_context_add_sentences(self):
    """add_sentence_spans() on a span-less context reproduces _CONTEXT_A."""
    test_context = schema.add_sentence_spans(
        schema.Context(
            document_title="Planet Earth",
            document_url="www.xyz.com",
            document_id="xyz-123",
            language="en",
            text="We all live here. Here in the World.",
            sentence_spans=()),
        # OK to pass list instead of tuple to add_sentence_spans.
        sentence_spans=[schema.Span(0, 17), schema.Span(18, 36)])
    self.assertEqual(test_context, self._CONTEXT_A)

  def test_round_trip_including_section(self):
    """JSON round trip preserves a context that carries a section title.

    Uses _CONTEXT_B, which is constructed with section_title="Intro", so that
    the section is actually part of the serialized payload. The mention
    offsets refer to the _CONTEXT_B text.
    """
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_B,
        mentions=[
            schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
                entity_id="Q1"),
            schema.Mention(
                example_id="30ba",
                mention_span=schema.TextSpan(start=27, end=32, text="World"),
                entity_id="Q1",
                metadata={"bin_name": "bin03_100-1000"},
            ),
        ])
    context_mentions.validate()

    json_string = json.dumps(context_mentions.to_json())
    got = schema.ContextualMentions.from_json(json.loads(json_string))
    self.assertEqual(context_mentions, got)
    # Check the section explicitly so that the test does not rely solely on
    # the equality operator covering it.
    self.assertEqual(got.context.section_title, "Intro")

  def test_simple_multibyte(self):
    """ContextualMentions with a multibyte mention validates, round-trips."""
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_B,
        mentions=[
            schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
                entity_id="Q1"),
        ])
    context_mentions.validate()

    self.assertEqual(
        list(context_mentions.context.sentences), [
            schema.TextSpan(0, 14, "We all live 🌵."),
            schema.TextSpan(15, 33, "Here in the World.")
        ])

    json_string = json.dumps(context_mentions.to_json())
    got = schema.ContextualMentions.from_json(json.loads(json_string))
    self.assertEqual(context_mentions, got)

  def test_simple_multibyte_single(self):
    """A single multibyte ContextualMention validates and round-trips."""
    context_mention = schema.ContextualMention(
        context=self._CONTEXT_B,
        mention=schema.Mention(
            example_id="12fe",
            mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
            entity_id="Q1"),
    )
    context_mention.validate()

    self.assertEqual(
        list(context_mention.context.sentences), [
            schema.TextSpan(0, 14, "We all live 🌵."),
            schema.TextSpan(15, 33, "Here in the World.")
        ])

    json_string = json.dumps(context_mention.to_json())
    got = schema.ContextualMention.from_json(json.loads(json_string))
    self.assertEqual(context_mention, got)

    # Perturb text to invalidate the sentence spans.
    perturbed = copy.deepcopy(context_mention)
    perturbed.context.text = perturbed.context.text[::2]
    with self.assertRaises(ValueError):
      perturbed.validate()

  def test_invalid_raises(self):
    """Building mentions from an inconsistent span raises ValueError."""
    with self.assertRaises(ValueError):
      _ = schema.ContextualMentions(
          context=self._CONTEXT_A,
          mentions=[
              schema.Mention(
                  example_id="12fe",
                  mention_span=schema.TextSpan(
                      # Start and end do not refer to substring 'here'.
                      start=0,
                      end=5,
                      text="here"),
                  entity_id="Q1"),
          ])

  def test_unnest(self):
    """Unnesting yields one ContextualMention per mention, in input order."""
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_A,
        mentions=[
            schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=16, text="here"),
                entity_id="Q1"),
            schema.Mention(
                example_id="30ba",
                mention_span=schema.TextSpan(start=30, end=35, text="World"),
                entity_id="Q1",
            ),
        ])
    context_mentions.validate()

    expected = [
        schema.ContextualMention(
            context=self._CONTEXT_A,
            mention=schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=16, text="here"),
                entity_id="Q1"),
        ),
        schema.ContextualMention(
            context=self._CONTEXT_A,
            mention=schema.Mention(
                example_id="30ba",
                mention_span=schema.TextSpan(start=30, end=35, text="World"),
                entity_id="Q1",
            ),
        )
    ]
    self.assertEqual(
        list(context_mentions.unnest_to_single_mention_per_context()),
        expected)

  def test_unnest_discards_input_without_mentions(self):
    """Unnesting a ContextualMentions with no mentions yields nothing."""
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_A, mentions=[])
    self.assertEmpty(
        list(context_mentions.unnest_to_single_mention_per_context()))

  @parameterized.named_parameters([
      dict(
          testcase_name="only_focus_sentence",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_ALL_MENTION),
          window_size=0,
          expected_sentences_text="all",
          expected_mention=schema.Mention(
              # Shifted mention span because first "sentence" gets dropped.
              mention_span=schema.TextSpan(start=0, end=3, text="all"),
              entity_id="Qx",
              example_id="82e3")),
      dict(
          testcase_name="window_1",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_ALL_MENTION),
          window_size=1,
          expected_sentences_text="We/all/live 🌵.",
          expected_mention=_ALL_MENTION),
      dict(
          testcase_name="window_exceeds_context",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_ALL_MENTION),
          window_size=10,
          expected_sentences_text="We/all/live 🌵./Here/in the/World.",
          expected_mention=_ALL_MENTION),
      dict(
          testcase_name="window_2_carryover_to_right",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_ALL_MENTION),
          window_size=2,
          expected_sentences_text="We/all/live 🌵./Here/in the",
          expected_mention=_ALL_MENTION),
      dict(
          testcase_name="window_2_carryover_to_left",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_WORLD_MENTION),
          window_size=2,
          expected_sentences_text="all/live 🌵./Here/in the/World.",
          expected_mention=schema.Mention(
              # Shifted mention span because first "sentence" gets dropped.
              mention_span=schema.TextSpan(start=24, end=29, text="World"),
              entity_id="Q1",
              example_id="9024f",
          )),
  ])
  def test_truncate(self, contextual_mention, window_size,
                    expected_sentences_text, expected_mention):
    """Checks the sentences and mention kept by ContextualMention.truncate().

    Args:
      contextual_mention: The ContextualMention to truncate.
      window_size: Window size passed to truncate().
      expected_sentences_text: "/"-joined text of the sentences expected in
        the truncated context.
      expected_mention: Mention expected on the truncated result.
    """
    truncated = contextual_mention.truncate(window_size)
    # truncate() may return None (see the "skips" tests below); fail with a
    # clear assertion instead of an AttributeError on the lines that follow.
    self.assertIsNotNone(truncated)
    # For brevity, the truncated ContextualMention is validated only in terms
    # of its concatenated sentences (delimited with "/" for readability).
    self.assertEqual(
        "/".join(s.text for s in truncated.context.sentences),
        expected_sentences_text,
        msg=f"In {truncated}")
    self.assertEqual(truncated.mention, expected_mention)

  def test_truncate_skips_mention_that_crosses_boundary(self):
    """truncate() returns None when the mention spans two sentences."""
    contextual_mention = schema.ContextualMention(
        context=self._CONTEXT_C,
        mention=schema.Mention(
            # (0,6) covers the first two sentences (0,2), (3,6).
            mention_span=schema.TextSpan(start=0, end=6, text="We all"),
            entity_id="Qx",
            example_id="82e3"))
    self.assertIsNone(contextual_mention.truncate(window_size=1))

  def test_truncate_skips_mention_if_context_has_no_sentences(self):
    """truncate() returns None when the context has no sentence spans."""
    contextual_mention = schema.ContextualMention(
        context=schema.Context(
            document_title="Planet Earth",
            document_url="www.xyz.com",
            document_id="xyz-123",
            section_title="Intro",
            language="en",
            text="We all live 🌵. Here in the World.",
            sentence_spans=()),
        mention=self._ALL_MENTION)
    self.assertIsNone(contextual_mention.truncate(window_size=1))

  def test_truncate_raises(self):
    """Context.truncate() rejects a negative window and bad focus indices."""
    with self.assertRaises(ValueError):
      self._CONTEXT_A.truncate(focus=0, window_size=-1)
    with self.assertRaises(IndexError):
      self._CONTEXT_B.truncate(focus=-1, window_size=0)
    with self.assertRaises(IndexError):
      self._CONTEXT_C.truncate(focus=50, window_size=0)