def test_unnest(self):
    # Unnesting should yield one ContextualMention per input mention, each
    # paired with the same context, in the order the mentions were given.

    def build_mentions():
        # Return fresh instances on every call so the expected values never
        # alias the objects handed to the code under test.
        return [
            schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=16, text="here"),
                entity_id="Q1",
            ),
            schema.Mention(
                example_id="30ba",
                mention_span=schema.TextSpan(start=30, end=35, text="World"),
                entity_id="Q1",
            ),
        ]

    nested = schema.ContextualMentions(
        context=self._CONTEXT_A,
        mentions=build_mentions(),
    )
    nested.validate()

    expected = [
        schema.ContextualMention(context=self._CONTEXT_A, mention=mention)
        for mention in build_mentions()
    ]

    unnested = list(nested.unnest_to_single_mention_per_context())
    self.assertEqual(unnested, expected)
def test_round_trip(self):
    # Checks three things on a two-mention example:
    #   1. the context exposes the expected sentence spans,
    #   2. to_json -> JSON text -> from_json reproduces an equal object,
    #   3. validate() raises ValueError once the context text is altered.
    mention_here = schema.Mention(
        example_id="12fe",
        mention_span=schema.TextSpan(start=12, end=16, text="here"),
        entity_id="Q1",
    )
    mention_world = schema.Mention(
        example_id="30ba",
        mention_span=schema.TextSpan(start=30, end=35, text="World"),
        entity_id="Q1",
        metadata={"bin_name": "bin03_100-1000"},
    )
    original = schema.ContextualMentions(
        context=self._CONTEXT_A,
        mentions=[mention_here, mention_world],
    )
    original.validate()

    actual_sentences = list(original.context.sentences)
    expected_sentences = [
        schema.TextSpan(0, 17, "We all live here."),
        schema.TextSpan(18, 36, "Here in the World."),
    ]
    self.assertEqual(actual_sentences, expected_sentences)

    # Go through an actual JSON string, not just the intermediate object,
    # so the encoding/decoding step is exercised as well.
    serialized = json.dumps(original.to_json())
    restored = schema.ContextualMentions.from_json(json.loads(serialized))
    self.assertEqual(original, restored)

    # Perturb text to invalidate the sentence spans: keeping every second
    # character changes the text while the stored spans stay as they were.
    corrupted = copy.deepcopy(original)
    corrupted.context.text = corrupted.context.text[::2]
    self.assertRaises(ValueError, corrupted.validate)
def test_invalid_raises(self):
    # Building a ContextualMentions whose mention span does not match the
    # context text is expected to fail with ValueError.
    #
    # Every constructor call is kept inside the assertRaises block: this test
    # only requires that a ValueError surfaces somewhere during construction,
    # not which of the nested objects raises it.
    with self.assertRaises(ValueError):
        # Start and end does not refer to substring 'here'.
        mismatched_span = schema.TextSpan(start=0, end=5, text="here")
        mismatched_mention = schema.Mention(
            example_id="12fe",
            mention_span=mismatched_span,
            entity_id="Q1",
        )
        schema.ContextualMentions(
            context=self._CONTEXT_A,
            mentions=[mismatched_mention],
        )
def test_round_trip_including_section(self):
    # A validated two-mention example must survive
    # to_json -> JSON text -> from_json and compare equal to the original.
    mention_here = schema.Mention(
        example_id="12fe",
        mention_span=schema.TextSpan(start=12, end=16, text="here"),
        entity_id="Q1",
    )
    mention_world = schema.Mention(
        example_id="30ba",
        mention_span=schema.TextSpan(start=30, end=35, text="World"),
        entity_id="Q1",
        metadata={"bin_name": "bin03_100-1000"},
    )
    original = schema.ContextualMentions(
        context=self._CONTEXT_A,
        mentions=[mention_here, mention_world],
    )
    original.validate()

    serialized = json.dumps(original.to_json())
    restored = schema.ContextualMentions.from_json(json.loads(serialized))
    self.assertEqual(original, restored)
def test_simple_multibyte(self):
    # Same checks as the plain round-trip, but on a context containing an
    # emoji: the mention span (12, 13) covers the single character "🌵", and
    # the expected sentence spans count it as one position.
    cactus_mention = schema.Mention(
        example_id="12fe",
        mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
        entity_id="Q1",
    )
    original = schema.ContextualMentions(
        context=self._CONTEXT_B,
        mentions=[cactus_mention],
    )
    original.validate()

    actual_sentences = list(original.context.sentences)
    expected_sentences = [
        schema.TextSpan(0, 14, "We all live 🌵."),
        schema.TextSpan(15, 33, "Here in the World."),
    ]
    self.assertEqual(actual_sentences, expected_sentences)

    # The emoji must also survive JSON encoding and decoding unchanged.
    serialized = json.dumps(original.to_json())
    restored = schema.ContextualMentions.from_json(json.loads(serialized))
    self.assertEqual(original, restored)
def test_unnest_discards_input_without_mentions(self):
    # A context that carries no mentions should unnest to nothing at all.
    mentionless = schema.ContextualMentions(
        context=self._CONTEXT_A,
        mentions=[],
    )
    unnested = list(mentionless.unnest_to_single_mention_per_context())
    self.assertEmpty(unnested)