예제 #1
0
 def test_unnest(self):
   """Checks that unnesting yields one `ContextualMention` per mention."""

   def build_mentions():
     # Returns fresh objects on every call so that the expectation below does
     # not share any object with the input under test.
     return [
         schema.Mention(
             example_id="12fe",
             mention_span=schema.TextSpan(start=12, end=16, text="here"),
             entity_id="Q1",
         ),
         schema.Mention(
             example_id="30ba",
             mention_span=schema.TextSpan(start=30, end=35, text="World"),
             entity_id="Q1",
         ),
     ]

   nested = schema.ContextualMentions(
       context=self._CONTEXT_A, mentions=build_mentions())
   nested.validate()

   # Every expected item pairs the shared context with exactly one mention,
   # in the same order the mentions were supplied.
   expected = [
       schema.ContextualMention(context=self._CONTEXT_A, mention=mention)
       for mention in build_mentions()
   ]
   unnested = list(nested.unnest_to_single_mention_per_context())
   self.assertEqual(unnested, expected)
예제 #2
0
  def test_round_trip(self):
    """Checks sentence spans, a JSON round trip, and validation failure.

    Builds a two-mention `ContextualMentions` over `_CONTEXT_A`, verifies the
    context's sentence spans, serializes it to a JSON string and back, and
    finally confirms that `validate()` rejects a copy whose text was altered.
    """
    mentions = [
        schema.Mention(
            example_id="12fe",
            mention_span=schema.TextSpan(start=12, end=16, text="here"),
            entity_id="Q1",
        ),
        schema.Mention(
            example_id="30ba",
            mention_span=schema.TextSpan(start=30, end=35, text="World"),
            entity_id="Q1",
            metadata={"bin_name": "bin03_100-1000"},
        ),
    ]
    original = schema.ContextualMentions(
        context=self._CONTEXT_A, mentions=mentions)
    original.validate()

    actual_sentences = list(original.context.sentences)
    expected_sentences = [
        schema.TextSpan(0, 17, "We all live here."),
        schema.TextSpan(18, 36, "Here in the World."),
    ]
    self.assertEqual(actual_sentences, expected_sentences)

    # Go through an actual JSON string, not just the intermediate structure.
    serialized = json.dumps(original.to_json())
    restored = schema.ContextualMentions.from_json(json.loads(serialized))
    self.assertEqual(original, restored)

    # Keeping only every other character of the text invalidates the
    # sentence spans, so validation of the altered copy must fail.
    corrupted = copy.deepcopy(original)
    corrupted.context.text = corrupted.context.text[::2]
    with self.assertRaises(ValueError):
      corrupted.validate()
예제 #3
0
 def test_invalid_raises(self):
   """Checks that building an inconsistent input raises `ValueError`."""
   # All construction happens inside the context manager so that a
   # `ValueError` raised at any step of building the objects is accepted.
   with self.assertRaises(ValueError):
     # Start and end do not refer to the substring 'here'.
     bad_span = schema.TextSpan(start=0, end=5, text="here")
     bad_mention = schema.Mention(
         example_id="12fe", mention_span=bad_span, entity_id="Q1")
     schema.ContextualMentions(
         context=self._CONTEXT_A, mentions=[bad_mention])
예제 #4
0
 def test_round_trip_including_section(self):
   """Checks that a validated `ContextualMentions` survives a JSON round trip.

   NOTE(review): the test name refers to a section, but nothing in this method
   sets one explicitly; presumably `_CONTEXT_A` carries it -- confirm against
   the fixture definition.
   """
   mentions = [
       schema.Mention(
           example_id="12fe",
           mention_span=schema.TextSpan(start=12, end=16, text="here"),
           entity_id="Q1",
       ),
       schema.Mention(
           example_id="30ba",
           mention_span=schema.TextSpan(start=30, end=35, text="World"),
           entity_id="Q1",
           metadata={"bin_name": "bin03_100-1000"},
       ),
   ]
   original = schema.ContextualMentions(
       context=self._CONTEXT_A, mentions=mentions)
   original.validate()

   # Serialize to a JSON string and parse it back before comparing.
   serialized = json.dumps(original.to_json())
   restored = schema.ContextualMentions.from_json(json.loads(serialized))
   self.assertEqual(original, restored)
예제 #5
0
  def test_simple_multibyte(self):
    """Checks validation, sentence spans and JSON round trip with an emoji.

    The mention and sentence offsets used here count the emoji as a single
    position (the mention span is start=12, end=13).
    """
    emoji_mention = schema.Mention(
        example_id="12fe",
        mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
        entity_id="Q1",
    )
    original = schema.ContextualMentions(
        context=self._CONTEXT_B, mentions=[emoji_mention])
    original.validate()

    actual_sentences = list(original.context.sentences)
    expected_sentences = [
        schema.TextSpan(0, 14, "We all live 🌵."),
        schema.TextSpan(15, 33, "Here in the World."),
    ]
    self.assertEqual(actual_sentences, expected_sentences)

    # Serialize to a JSON string and parse it back before comparing.
    serialized = json.dumps(original.to_json())
    restored = schema.ContextualMentions.from_json(json.loads(serialized))
    self.assertEqual(original, restored)
예제 #6
0
 def test_unnest_discards_input_without_mentions(self):
   """Checks that unnesting a context with no mentions yields nothing."""
   without_mentions = schema.ContextualMentions(
       mentions=[], context=self._CONTEXT_A)
   unnested = list(without_mentions.unnest_to_single_mention_per_context())
   self.assertEmpty(unnested)