  def test_skip_gram_sample_random_skips_default_seed(self):
    """Tests outputs are still random when no op-level seed is specified."""
    # This is needed since tests set a graph-level seed by default. We want to
    # explicitly avoid setting both graph-level seed and op-level seed, to
    # simulate behavior under non-test settings when the user doesn't provide a
    # seed to us. This results in random_seed.get_seed() returning None for both
    # seeds, forcing the C++ kernel to execute its default seed logic.
    random_seed.set_random_seed(None)

    # Uses an input tensor with 10 words, with possible skip ranges in [1,
    # 5]. Thus, the probability that two random samplings would result in the
    # same outputs is 1/5^10 ~ 1e-7 (aka the probability of this test being
    # flaky).
    input_tensor = constant_op.constant([str(x) for x in range(10)])

    # Do not provide an op-level seed here!
    tokens_1, labels_1 = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=5)
    tokens_2, labels_2 = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=5)

    with self.test_session() as sess:
      tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
          [tokens_1, labels_1, tokens_2, labels_2])

    if len(tokens_1_eval) == len(tokens_2_eval):
      self.assertNotEqual(tokens_1_eval.tolist(), tokens_2_eval.tolist())
    if len(labels_1_eval) == len(labels_2_eval):
      self.assertNotEqual(labels_1_eval.tolist(), labels_2_eval.tolist())
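
# Illustrative sketch, not part of the test file: the seed bookkeeping the
# comment above describes can be checked directly. Assuming TF1 graph mode and
# the same random_seed module the test uses, clearing the graph-level seed and
# passing no op-level seed makes get_seed() resolve to (None, None), which is
# what pushes the kernel onto its default seed logic.
from tensorflow.python.framework import random_seed

random_seed.set_random_seed(None)  # clear any graph-level seed
graph_seed, op_seed = random_seed.get_seed(None)  # and supply no op-level seed
assert graph_seed is None and op_seed is None  # kernel picks its own seed
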
  def test_skip_gram_sample_emit_self(self):
    """Tests skip-gram with emit_self_as_target = True."""
    input_tensor = constant_op.constant(
        [b"the", b"quick", b"brown", b"fox", b"jumps"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=2, max_skips=2, emit_self_as_target=True)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"the"),
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"quick"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"brown"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"quick"),
        (b"fox", b"brown"),
        (b"fox", b"fox"),
        (b"fox", b"jumps"),
        (b"jumps", b"brown"),
        (b"jumps", b"fox"),
        (b"jumps", b"jumps"),
    ])
    with self.test_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())
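
# Illustrative sketch, not part of the test file: the expected pairs above
# follow from a simple rule. Each token is paired with every neighbor inside
# its skip window, clamped to the bounds of the input, and emit_self_as_target
# additionally pairs the token with itself. Below is a pure-Python reference
# for the deterministic min_skips == max_skips case; the skip_gram_pairs
# helper is hypothetical, not part of the TF API.
def skip_gram_pairs(tokens, skips, emit_self=False):
  pairs = []
  for i, token in enumerate(tokens):
    lo, hi = max(0, i - skips), min(len(tokens), i + skips + 1)
    for j in range(lo, hi):  # window clamped at both ends of the input
      if j == i and not emit_self:
        continue
      pairs.append((token, tokens[j]))
  return pairs

# Reproduces the first rows of the fixture above.
assert skip_gram_pairs(
    [b"the", b"quick", b"brown", b"fox", b"jumps"], skips=2,
    emit_self=True)[:3] == [(b"the", b"the"), (b"the", b"quick"),
                            (b"the", b"brown")]
# The same clamping explains test_skip_gram_sample_skips_exceed_length below:
# with skips=100 on a 3-token input, each token pairs with every other token,
# and with skips=0 the output is empty unless emit_self_as_target is set.
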
  def test_skip_gram_sample_batch(self):
    """Tests skip-gram with batching."""
    input_tensor = constant_op.constant([b"the", b"quick", b"brown", b"fox"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=1, batch_size=3)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"fox", b"brown"),
    ])
    with self.test_session() as sess:
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)

      tokens_eval, labels_eval = sess.run([tokens, labels])
      self.assertAllEqual(expected_tokens[:3], tokens_eval)
      self.assertAllEqual(expected_labels[:3], labels_eval)
      tokens_eval, labels_eval = sess.run([tokens, labels])
      self.assertAllEqual(expected_tokens[3:6], tokens_eval)
      self.assertAllEqual(expected_labels[3:6], labels_eval)

      coord.request_stop()
      coord.join(threads)
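
# Illustrative sketch, not part of the test file: with batch_size set,
# skip_gram_sample routes its output through a queue, so each sess.run above
# dequeues one fixed-size slice of the pair stream. That is why the test
# starts queue runners and asserts on consecutive slices. The batches helper
# below is illustrative only.
def batches(pairs, batch_size):
  for i in range(0, len(pairs), batch_size):
    yield pairs[i:i + batch_size]

pairs = [(b"the", b"quick"), (b"quick", b"the"), (b"quick", b"brown"),
         (b"brown", b"quick"), (b"brown", b"fox"), (b"fox", b"brown")]
first, second = batches(pairs, batch_size=3)
assert first == pairs[:3] and second == pairs[3:6]
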
  def test_skip_gram_sample_random_skips(self):
    """Tests skip-gram with min_skips != max_skips, with random output."""
    # The number of outputs is non-deterministic in this case, so set random
    # seed to help ensure the outputs remain constant for this test case.
    random_seed.set_random_seed(42)

    input_tensor = constant_op.constant(
        [b"the", b"quick", b"brown", b"fox", b"jumps", b"over"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=2, seed=9)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"quick", b"fox"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
        (b"brown", b"fox"),
        (b"brown", b"jumps"),
        (b"fox", b"brown"),
        (b"fox", b"jumps"),
        (b"jumps", b"fox"),
        (b"jumps", b"over"),
        (b"over", b"fox"),
        (b"over", b"jumps"),
    ])
    with self.test_session() as sess:
      tokens_eval, labels_eval = sess.run([tokens, labels])
      self.assertAllEqual(expected_tokens, tokens_eval)
      self.assertAllEqual(expected_labels, labels_eval)
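
# Illustrative sketch, not part of the test file: unlike the default-seed test
# at the top of this file, this test supplies both a graph-level seed (42) and
# an op-level seed (9), so the resolved seed pair is fixed and the sampled
# skips, and hence the fixture above, are reproducible. Assuming TF1 graph
# mode:
from tensorflow.python.framework import random_seed

random_seed.set_random_seed(42)
graph_seed, op_seed = random_seed.get_seed(9)
assert (graph_seed, op_seed) == (42, 9)  # deterministic kernel seeding
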
  def test_skip_gram_sample_skips_0(self):
    """Tests skip-gram with min_skips = max_skips = 0."""
    input_tensor = constant_op.constant([b"the", b"quick", b"brown"])

    # If emit_self_as_target is False (default), output will be empty.
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
    with self.test_session():
      self.assertEqual(0, tokens.eval().size)
      self.assertEqual(0, labels.eval().size)

    # If emit_self_as_target is True, each token will be its own label.
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=0, max_skips=0, emit_self_as_target=True)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"the"),
        (b"quick", b"quick"),
        (b"brown", b"brown"),
    ])
    with self.test_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_non_string_input(self):
    """Tests skip-gram with non-string input."""
    input_tensor = constant_op.constant([1, 2, 3], dtype=dtypes.int16)
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=1)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (1, 2),
        (2, 1),
        (2, 3),
        (3, 2),
    ])
    with self.test_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_limit_exceeds(self):
    """Tests skip-gram when limit exceeds the length of the input."""
    input_tensor = constant_op.constant([b"foo", b"the", b"quick", b"brown"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=1, start=1, limit=100)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
    ])
    with self.test_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_start_limit(self):
    """Tests skip-gram over a limited portion of the input."""
    input_tensor = constant_op.constant(
        [b"foo", b"the", b"quick", b"brown", b"bar"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=1, max_skips=1, start=1, limit=3)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"quick"),
    ])
    with self.cached_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())
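
# Illustrative sketch, not part of the test file: start and limit select the
# sub-span of the input that is actually sampled, and, as the limit=100 test
# above shows, the span is clamped to the input length. The select_span helper
# below is hypothetical; the assumption is that a negative limit (the default)
# means "use everything from start onward".
def select_span(tokens, start=0, limit=-1):
  return tokens[start:] if limit < 0 else tokens[start:start + limit]

corpus = [b"foo", b"the", b"quick", b"brown", b"bar"]
assert select_span(corpus, start=1, limit=3) == [b"the", b"quick", b"brown"]
assert select_span(corpus, start=1, limit=100) == corpus[1:]
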
  def test_skip_gram_sample_skips_exceed_length(self):
    """Tests skip-gram when min/max_skips exceed length of input."""
    input_tensor = constant_op.constant([b"the", b"quick", b"brown"])
    tokens, labels = text.skip_gram_sample(
        input_tensor, min_skips=100, max_skips=100)
    expected_tokens, expected_labels = self._split_tokens_labels([
        (b"the", b"quick"),
        (b"the", b"brown"),
        (b"quick", b"the"),
        (b"quick", b"brown"),
        (b"brown", b"the"),
        (b"brown", b"quick"),
    ])
    with self.cached_session():
      self.assertAllEqual(expected_tokens, tokens.eval())
      self.assertAllEqual(expected_labels, labels.eval())

  def test_skip_gram_sample_errors(self):
    """Tests various errors raised by skip_gram_sample()."""
    input_tensor = constant_op.constant([b"the", b"quick", b"brown"])

    invalid_skips = (
        # min_skips and max_skips must be >= 0.
        (-1, 2),
        (1, -2),
        # min_skips must be <= max_skips.
        (2, 1))
    for min_skips, max_skips in invalid_skips:
      tokens, labels = text.skip_gram_sample(
          input_tensor, min_skips=min_skips, max_skips=max_skips)
      with self.test_session() as sess, self.assertRaises(
          errors.InvalidArgumentError):
        sess.run([tokens, labels])

    # input_tensor must be of rank 1.
    with self.assertRaises(ValueError):
      invalid_tensor = constant_op.constant([[b"the"], [b"quick"], [b"brown"]])
      text.skip_gram_sample(invalid_tensor)

    # vocab_freq_table must be provided if vocab_min_count, vocab_subsampling,
    # or corpus_size is specified.
    dummy_input = constant_op.constant([""])
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input, vocab_freq_table=None, vocab_min_count=1)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input, vocab_freq_table=None, vocab_subsampling=1e-5)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(dummy_input, vocab_freq_table=None, corpus_size=100)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=None,
          vocab_subsampling=1e-5,
          corpus_size=100)

    # vocab_subsampling and corpus_size must both be present or absent.
    dummy_table = lookup.HashTable(
        lookup.KeyValueTensorInitializer([b"foo"], [10]), -1)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=dummy_table,
          vocab_subsampling=None,
          corpus_size=100)
    with self.assertRaises(ValueError):
      text.skip_gram_sample(
          dummy_input,
          vocab_freq_table=dummy_table,
          vocab_subsampling=1e-5,
          corpus_size=None)
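
# Illustrative sketch, not part of the test file: the assertions above pin
# down three validation rules. Below is a condensed stand-in, not the
# library's actual code; note that the skips check runs in the C++ kernel, so
# it surfaces as errors.InvalidArgumentError at session-run time rather than
# as a ValueError at graph construction.
def check_args(min_skips, max_skips, vocab_freq_table=None,
               vocab_min_count=None, vocab_subsampling=None, corpus_size=None):
  # Rule 1: 0 <= min_skips <= max_skips.
  if min_skips < 0 or max_skips < 0 or min_skips > max_skips:
    raise ValueError("Require 0 <= min_skips <= max_skips.")
  # Rule 2: any vocab-filtering argument requires vocab_freq_table.
  if vocab_freq_table is None and not (vocab_min_count is None and
                                       vocab_subsampling is None and
                                       corpus_size is None):
    raise ValueError("vocab_freq_table is required for vocab filtering.")
  # Rule 3: vocab_subsampling and corpus_size must be given together.
  if (vocab_subsampling is None) != (corpus_size is None):
    raise ValueError("vocab_subsampling and corpus_size must be set together.")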