Esempio n. 1
0
  def test_round_trip(self):
    """Checks validation, sentence extraction and the JSON round trip."""
    here_mention = schema.Mention(
        example_id="12fe",
        mention_span=schema.TextSpan(start=12, end=16, text="here"),
        entity_id="Q1")
    world_mention = schema.Mention(
        example_id="30ba",
        mention_span=schema.TextSpan(start=30, end=35, text="World"),
        entity_id="Q1",
        metadata={"bin_name": "bin03_100-1000"})
    original = schema.ContextualMentions(
        context=self._CONTEXT_A, mentions=[here_mention, world_mention])
    original.validate()

    # The context is expected to expose one TextSpan per sentence.
    expected_sentences = [
        schema.TextSpan(0, 17, "We all live here."),
        schema.TextSpan(18, 36, "Here in the World."),
    ]
    self.assertEqual(list(original.context.sentences), expected_sentences)

    # Serialize to an actual JSON string and parse it back.
    serialized = json.dumps(original.to_json())
    restored = schema.ContextualMentions.from_json(json.loads(serialized))
    self.assertEqual(original, restored)

    # Keeping only every other character invalidates the sentence spans.
    corrupted = copy.deepcopy(original)
    corrupted.context.text = corrupted.context.text[::2]
    with self.assertRaises(ValueError):
      corrupted.validate()
Esempio n. 2
0
 def test_unnest(self):
   """Checks that unnesting yields one ContextualMention per mention."""

   def build_mentions():
     # Fresh instances on every call, so the input and the expectation never
     # share Mention objects.
     return [
         schema.Mention(
             example_id="12fe",
             mention_span=schema.TextSpan(start=12, end=16, text="here"),
             entity_id="Q1"),
         schema.Mention(
             example_id="30ba",
             mention_span=schema.TextSpan(start=30, end=35, text="World"),
             entity_id="Q1"),
     ]

   nested = schema.ContextualMentions(
       context=self._CONTEXT_A, mentions=build_mentions())
   nested.validate()

   # Each mention is expected to be paired with the shared context, in order.
   expected = [
       schema.ContextualMention(context=self._CONTEXT_A, mention=mention)
       for mention in build_mentions()
   ]
   unnested = list(nested.unnest_to_single_mention_per_context())
   self.assertEqual(unnested, expected)
Esempio n. 3
0
  def test_simple_multibyte_single(self):
    """Checks a single-mention context whose text contains an emoji."""
    cactus_mention = schema.Mention(
        example_id="12fe",
        mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
        entity_id="Q1")
    original = schema.ContextualMention(
        context=self._CONTEXT_B, mention=cactus_mention)
    original.validate()

    # The emoji occupies exactly one offset in the expected sentence spans.
    expected_sentences = [
        schema.TextSpan(0, 14, "We all live 🌵."),
        schema.TextSpan(15, 33, "Here in the World."),
    ]
    self.assertEqual(list(original.context.sentences), expected_sentences)

    # Serialize to an actual JSON string and parse it back.
    serialized = json.dumps(original.to_json())
    restored = schema.ContextualMention.from_json(json.loads(serialized))
    self.assertEqual(original, restored)

    # Keeping only every other character invalidates the sentence spans.
    corrupted = copy.deepcopy(original)
    corrupted.context.text = corrupted.context.text[::2]
    with self.assertRaises(ValueError):
      corrupted.validate()
Esempio n. 4
0
 def test_truncate_skips_mention_that_crosses_boundary(self):
   """Checks that truncate() returns None for a multi-sentence mention."""
   # Span (0, 6) straddles the first two sentences, (0, 2) and (3, 6).
   straddling_span = schema.TextSpan(start=0, end=6, text="We all")
   straddling_mention = schema.Mention(
       mention_span=straddling_span, entity_id="Qx", example_id="82e3")
   contextual_mention = schema.ContextualMention(
       context=self._CONTEXT_C, mention=straddling_mention)

   truncated = contextual_mention.truncate(window_size=1)
   self.assertIsNone(truncated)
Esempio n. 5
0
 def test_round_trip_including_section(self):
   """Checks the JSON round trip for a context that has a section title.

   _CONTEXT_B is the fixture that sets `section_title`, so it is the one
   exercised here; the mention spans are offsets into its text.
   """
   context_mentions = schema.ContextualMentions(
       context=self._CONTEXT_B,
       mentions=[
           schema.Mention(
               example_id="12fe",
               mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
               entity_id="Q1"),
           schema.Mention(
               example_id="30ba",
               mention_span=schema.TextSpan(start=27, end=32, text="World"),
               entity_id="Q1",
               metadata={"bin_name": "bin03_100-1000"},
           ),
       ])
   context_mentions.validate()

   # Go through an actual JSON string so that serialization is covered too.
   json_string = json.dumps(context_mentions.to_json())
   got = schema.ContextualMentions.from_json(json.loads(json_string))
   self.assertEqual(context_mentions, got)
Esempio n. 6
0
  def test_round_trip(self):
    """Checks that a MentionEntityPair validates and survives JSON."""
    cactus_mention = schema.Mention(
        example_id="12fe",
        mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
        entity_id="Q1")
    pair = schema.MentionEntityPair(
        contextual_mention=schema.ContextualMention(
            context=ContextualMentionsTest._CONTEXT_B, mention=cactus_mention),
        entity=EntityTest.TEST_ENTITY)
    pair.validate()

    # Serialize to an actual JSON string and parse it back.
    serialized = json.dumps(pair.to_json())
    restored = schema.MentionEntityPair.from_json(json.loads(serialized))
    self.assertEqual(pair, restored)
Esempio n. 7
0
 def test_invalid_raises(self):
   """Checks that an inconsistent mention span is rejected at construction."""
   # Everything is built inside the assertRaises block, so the ValueError is
   # caught no matter which of the constructors raises it.
   with self.assertRaises(ValueError):
     # Offsets (0, 5) do not refer to the substring 'here'.
     mismatched_span = schema.TextSpan(start=0, end=5, text="here")
     bad_mention = schema.Mention(
         example_id="12fe", mention_span=mismatched_span, entity_id="Q1")
     schema.ContextualMentions(
         context=self._CONTEXT_A, mentions=[bad_mention])
Esempio n. 8
0
  def test_simple_multibyte(self):
    """Checks a multi-mention container whose text contains an emoji."""
    cactus_mention = schema.Mention(
        example_id="12fe",
        mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
        entity_id="Q1")
    original = schema.ContextualMentions(
        context=self._CONTEXT_B, mentions=[cactus_mention])
    original.validate()

    # The emoji occupies exactly one offset in the expected sentence spans.
    expected_sentences = [
        schema.TextSpan(0, 14, "We all live 🌵."),
        schema.TextSpan(15, 33, "Here in the World."),
    ]
    self.assertEqual(list(original.context.sentences), expected_sentences)

    # Serialize to an actual JSON string and parse it back.
    serialized = json.dumps(original.to_json())
    restored = schema.ContextualMentions.from_json(json.loads(serialized))
    self.assertEqual(original, restored)
Esempio n. 9
0
class ContextualMentionsTest(parameterized.TestCase):
  """Tests for schema.Context, ContextualMention and ContextualMentions."""

  # Two-sentence context without a section title.
  _CONTEXT_A = schema.Context(
      document_title="Planet Earth",
      document_url="www.xyz.com",
      document_id="xyz-123",
      language="en",
      text="We all live here. Here in the World.",
      sentence_spans=(schema.Span(0, 17), schema.Span(18, 36)))
  # Two-sentence context with a section title and an emoji in the text.
  _CONTEXT_B = schema.Context(
      document_title="Planet Earth",
      document_url="www.xyz.com",
      document_id="xyz-123",
      section_title="Intro",
      language="en",
      text="We all live 🌵. Here in the World.",
      sentence_spans=(schema.Span(0, 14), schema.Span(15, 33)))
  # Same text as _CONTEXT_B, split into six short "sentences" for the
  # truncation tests.
  _CONTEXT_C = schema.Context(
      document_title="Planet Earth",
      document_url="www.xyz.com",
      document_id="xyz-123",
      section_title="Intro",
      language="en",
      text="We all live 🌵. Here in the World.",
      # For brevity, mimic sentences using phrases.
      sentence_spans=(
          schema.Span(0, 2),  # We
          schema.Span(3, 6),  # all
          schema.Span(7, 14),  # live 🌵.
          schema.Span(15, 19),  # Here
          schema.Span(20, 26),  # in the
          schema.Span(27, 33),  # World.
      ))
  # Two test mentions consistent with CONTEXT_C.
  _ALL_MENTION = schema.Mention(
      mention_span=schema.TextSpan(start=3, end=6, text="all"),
      entity_id="Qx",
      example_id="82e3")
  _WORLD_MENTION = schema.Mention(
      mention_span=schema.TextSpan(start=27, end=32, text="World"),
      entity_id="Q1",
      example_id="9024f")

  def test_round_trip(self):
    """Checks validation, sentence extraction and the JSON round trip."""
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_A,
        mentions=[
            schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=16, text="here"),
                entity_id="Q1"),
            schema.Mention(
                example_id="30ba",
                mention_span=schema.TextSpan(start=30, end=35, text="World"),
                entity_id="Q1",
                metadata={"bin_name": "bin03_100-1000"},
            ),
        ])
    context_mentions.validate()

    self.assertEqual(
        list(context_mentions.context.sentences), [
            schema.TextSpan(0, 17, "We all live here."),
            schema.TextSpan(18, 36, "Here in the World.")
        ])

    json_string = json.dumps(context_mentions.to_json())
    got = schema.ContextualMentions.from_json(json.loads(json_string))
    self.assertEqual(context_mentions, got)

    # Perturb text to invalidate the sentence spans.
    perturbed = copy.deepcopy(context_mentions)
    perturbed.context.text = perturbed.context.text[::2]
    with self.assertRaises(ValueError):
      perturbed.validate()

  def test_context_add_sentences(self):
    """Checks that add_sentence_spans() fills in a context's sentence spans."""
    test_context = schema.add_sentence_spans(
        schema.Context(
            document_title="Planet Earth",
            document_url="www.xyz.com",
            document_id="xyz-123",
            language="en",
            text="We all live here. Here in the World.",
            sentence_spans=()),
        # OK to pass list instead of tuple to add_sentence_spans.
        sentence_spans=[schema.Span(0, 17),
                        schema.Span(18, 36)])
    self.assertEqual(test_context, self._CONTEXT_A)

  def test_round_trip_including_section(self):
    """Checks the JSON round trip for a context that has a section title.

    _CONTEXT_B is the fixture that sets `section_title`, so it is the one
    exercised here; the mention spans are offsets into its text.
    """
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_B,
        mentions=[
            schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
                entity_id="Q1"),
            schema.Mention(
                example_id="30ba",
                mention_span=schema.TextSpan(start=27, end=32, text="World"),
                entity_id="Q1",
                metadata={"bin_name": "bin03_100-1000"},
            ),
        ])
    context_mentions.validate()
    json_string = json.dumps(context_mentions.to_json())
    got = schema.ContextualMentions.from_json(json.loads(json_string))
    self.assertEqual(context_mentions, got)

  def test_simple_multibyte(self):
    """Checks a multi-mention container whose text contains an emoji."""
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_B,
        mentions=[
            schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
                entity_id="Q1"),
        ])
    context_mentions.validate()

    self.assertEqual(
        list(context_mentions.context.sentences), [
            schema.TextSpan(0, 14, "We all live 🌵."),
            schema.TextSpan(15, 33, "Here in the World.")
        ])

    json_string = json.dumps(context_mentions.to_json())
    got = schema.ContextualMentions.from_json(json.loads(json_string))
    self.assertEqual(context_mentions, got)

  def test_simple_multibyte_single(self):
    """Checks a single-mention context whose text contains an emoji."""
    context_mention = schema.ContextualMention(
        context=self._CONTEXT_B,
        mention=schema.Mention(
            example_id="12fe",
            mention_span=schema.TextSpan(start=12, end=13, text="🌵"),
            entity_id="Q1"),
    )
    context_mention.validate()

    self.assertEqual(
        list(context_mention.context.sentences), [
            schema.TextSpan(0, 14, "We all live 🌵."),
            schema.TextSpan(15, 33, "Here in the World.")
        ])

    json_string = json.dumps(context_mention.to_json())
    got = schema.ContextualMention.from_json(json.loads(json_string))
    self.assertEqual(context_mention, got)

    # Perturb text to invalidate the sentence spans.
    perturbed = copy.deepcopy(context_mention)
    perturbed.context.text = perturbed.context.text[::2]
    with self.assertRaises(ValueError):
      perturbed.validate()

  def test_invalid_raises(self):
    """Checks that an inconsistent mention span is rejected at construction."""
    with self.assertRaises(ValueError):
      _ = schema.ContextualMentions(
          context=self._CONTEXT_A,
          mentions=[
              schema.Mention(
                  example_id="12fe",
                  mention_span=schema.TextSpan(
                      # Start and end do not refer to substring 'here'.
                      start=0,
                      end=5,
                      text="here"),
                  entity_id="Q1"),
          ])

  def test_unnest(self):
    """Checks that unnesting yields one ContextualMention per mention."""
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_A,
        mentions=[
            schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=16, text="here"),
                entity_id="Q1"),
            schema.Mention(
                example_id="30ba",
                mention_span=schema.TextSpan(start=30, end=35, text="World"),
                entity_id="Q1",
            ),
        ])
    context_mentions.validate()
    expected = [
        schema.ContextualMention(
            context=self._CONTEXT_A,
            mention=schema.Mention(
                example_id="12fe",
                mention_span=schema.TextSpan(start=12, end=16, text="here"),
                entity_id="Q1"),
        ),
        schema.ContextualMention(
            context=self._CONTEXT_A,
            mention=schema.Mention(
                example_id="30ba",
                mention_span=schema.TextSpan(start=30, end=35, text="World"),
                entity_id="Q1",
            ),
        )
    ]
    self.assertEqual(
        list(context_mentions.unnest_to_single_mention_per_context()), expected)

  def test_unnest_discards_input_without_mentions(self):
    """Checks that unnesting a container with no mentions yields nothing."""
    context_mentions = schema.ContextualMentions(
        context=self._CONTEXT_A, mentions=[])
    self.assertEmpty(
        list(context_mentions.unnest_to_single_mention_per_context()))

  @parameterized.named_parameters([
      dict(
          testcase_name="only_focus_sentence",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_ALL_MENTION),
          window_size=0,
          expected_sentences_text="all",
          expected_mention=schema.Mention(
              # Shifted mention span because first "sentence" gets dropped.
              mention_span=schema.TextSpan(start=0, end=3, text="all"),
              entity_id="Qx",
              example_id="82e3")),
      dict(
          testcase_name="window_1",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_ALL_MENTION),
          window_size=1,
          expected_sentences_text="We/all/live 🌵.",
          expected_mention=_ALL_MENTION),
      dict(
          testcase_name="window_exceeds_context",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_ALL_MENTION),
          window_size=10,
          expected_sentences_text="We/all/live 🌵./Here/in the/World.",
          expected_mention=_ALL_MENTION),
      dict(
          testcase_name="window_2_carryover_to_right",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_ALL_MENTION),
          window_size=2,
          expected_sentences_text="We/all/live 🌵./Here/in the",
          expected_mention=_ALL_MENTION),
      dict(
          testcase_name="window_2_carryover_to_left",
          contextual_mention=schema.ContextualMention(
              context=_CONTEXT_C, mention=_WORLD_MENTION),
          window_size=2,
          expected_sentences_text="all/live 🌵./Here/in the/World.",
          expected_mention=schema.Mention(
              # Shifted mention span because first "sentence" gets dropped.
              mention_span=schema.TextSpan(start=24, end=29, text="World"),
              entity_id="Q1",
              example_id="9024f",
          )),
  ])
  def test_truncate(self, contextual_mention, window_size,
                    expected_sentences_text, expected_mention):
    """Checks the sentences and mention kept by ContextualMention.truncate().

    Args:
      contextual_mention: The ContextualMention to truncate.
      window_size: Value passed to truncate().
      expected_sentences_text: Expected remaining sentence texts, joined with
        "/".
      expected_mention: Expected mention of the truncated result.
    """
    truncated = contextual_mention.truncate(window_size)

    # truncate() can return None (see the "skips" tests in this class); fail
    # with a clear assertion here rather than an AttributeError below.
    self.assertIsNotNone(truncated)

    # For brevity, the truncated ContextualMention is validated only in terms of
    # its concatenated sentences (delimited with "/" for readability).
    self.assertEqual(
        "/".join(s.text for s in truncated.context.sentences),
        expected_sentences_text,
        msg=f"In {truncated}")

    self.assertEqual(truncated.mention, expected_mention)

  def test_truncate_skips_mention_that_crosses_boundary(self):
    """Checks that truncate() returns None for a multi-sentence mention."""
    contextual_mention = schema.ContextualMention(
        context=self._CONTEXT_C,
        mention=schema.Mention(
            # (0,6) covers the first two sentences (0,2), (3,6).
            mention_span=schema.TextSpan(start=0, end=6, text="We all"),
            entity_id="Qx",
            example_id="82e3"))
    self.assertIsNone(contextual_mention.truncate(window_size=1))

  def test_truncate_skips_mention_if_context_has_no_sentences(self):
    """Checks that truncate() returns None when there are no sentence spans."""
    contextual_mention = schema.ContextualMention(
        context=schema.Context(
            document_title="Planet Earth",
            document_url="www.xyz.com",
            document_id="xyz-123",
            section_title="Intro",
            language="en",
            text="We all live 🌵. Here in the World.",
            sentence_spans=()),
        mention=self._ALL_MENTION)
    self.assertIsNone(contextual_mention.truncate(window_size=1))

  def test_truncate_raises(self):
    """Checks the errors Context.truncate() raises for bad arguments."""
    # A negative window size is rejected with ValueError.
    with self.assertRaises(ValueError):
      self._CONTEXT_A.truncate(focus=0, window_size=-1)
    # A focus index outside the available sentences raises IndexError.
    with self.assertRaises(IndexError):
      self._CONTEXT_B.truncate(focus=-1, window_size=0)
    with self.assertRaises(IndexError):
      self._CONTEXT_C.truncate(focus=50, window_size=0)